Skip to main content

Mojo function

generic_flare_mla_prefill_kv_cache_ragged

generic_flare_mla_prefill_kv_cache_ragged[
    collection_t: KVCollectionT,
    input_dtype: DType,
    dtype: DType,
    //,
    mask_str: StringSlice[StaticConstantOrigin],
    target: StringSlice[StaticConstantOrigin],
    local_window_size: Int = -1
](
    q: TileTensor[input_dtype, q.LayoutType, q.origin, linear_idx_type=q.linear_idx_type, element_size=q.element_size],
    k: TileTensor[input_dtype, k.LayoutType, k.origin, linear_idx_type=k.linear_idx_type, element_size=k.element_size],
    v: TileTensor[input_dtype, v.LayoutType, v.origin, linear_idx_type=v.linear_idx_type, element_size=v.element_size],
    buffer_row_offsets: TileTensor[DType.uint32, buffer_row_offsets.LayoutType, buffer_row_offsets.origin, linear_idx_type=buffer_row_offsets.linear_idx_type, element_size=buffer_row_offsets.element_size],
    cache_offsets: TileTensor[DType.uint32, cache_offsets.LayoutType, cache_offsets.origin, linear_idx_type=cache_offsets.linear_idx_type, element_size=cache_offsets.element_size],
    input_row_offsets: TileTensor[DType.uint32, input_row_offsets.LayoutType, input_row_offsets.origin, linear_idx_type=input_row_offsets.linear_idx_type, element_size=input_row_offsets.element_size],
    kv_collection: collection_t,
    layer_idx: UInt32,
    scale: Float32,
    output: TileTensor[dtype, output.LayoutType, output.origin, linear_idx_type=output.linear_idx_type, element_size=output.element_size],
    context: DeviceContextPtr
)

Was this page helpful?