Skip to main content

Mojo function

generic_flare_mla_decode_kv_cache_ragged

generic_flare_mla_decode_kv_cache_ragged[collection_t: KVCollectionT, q_dtype: DType, //, mask_str: StringSlice[StaticConstantOrigin], target: StringSlice[StaticConstantOrigin], local_window_size: Int = -1, per_token_scale_rope_aware: Bool = False](q: LayoutTensor[q_dtype, q.layout, q.origin, element_layout=q.element_layout, layout_int_type=q.layout_int_type, linear_idx_type=q.linear_idx_type, masked=q.masked, alignment=q.alignment], input_row_offsets: LayoutTensor[DType.uint32, input_row_offsets.layout, input_row_offsets.origin, element_layout=input_row_offsets.element_layout, layout_int_type=input_row_offsets.layout_int_type, linear_idx_type=input_row_offsets.linear_idx_type, masked=input_row_offsets.masked, alignment=input_row_offsets.alignment], kv_collection: collection_t, layer_idx: UInt32, scale: Float32, output: LayoutTensor[output.dtype, output.layout, output.origin, element_layout=output.element_layout, layout_int_type=output.layout_int_type, linear_idx_type=output.linear_idx_type, masked=output.masked, alignment=output.alignment], scalar_args_buf: LayoutTensor[DType.int64, scalar_args_buf.layout, scalar_args_buf.origin, element_layout=scalar_args_buf.element_layout, layout_int_type=scalar_args_buf.layout_int_type, linear_idx_type=scalar_args_buf.linear_idx_type, masked=scalar_args_buf.masked, alignment=scalar_args_buf.alignment], context: DeviceContextPtr, q_scale_ptr: UnsafePointer[Float32, MutAnyOrigin] = UnsafePointer())

Was this page helpful?