Skip to main content

Mojo function

generic_flare_mla_prefill_kv_cache_ragged

generic_flare_mla_prefill_kv_cache_ragged[
    collection_t: KVCollectionT,
    input_dtype: DType,
    dtype: DType,
    //,
    mask_str: StringSlice[StaticConstantOrigin],
    target: StringSlice[StaticConstantOrigin],
    local_window_size: Int = -1
](
    q: LayoutTensor[input_dtype, q.layout, q.origin, element_layout=q.element_layout, layout_int_type=q.layout_int_type, linear_idx_type=q.linear_idx_type, masked=q.masked, alignment=q.alignment],
    k: LayoutTensor[input_dtype, k.layout, k.origin, element_layout=k.element_layout, layout_int_type=k.layout_int_type, linear_idx_type=k.linear_idx_type, masked=k.masked, alignment=k.alignment],
    v: LayoutTensor[input_dtype, v.layout, v.origin, element_layout=v.element_layout, layout_int_type=v.layout_int_type, linear_idx_type=v.linear_idx_type, masked=v.masked, alignment=v.alignment],
    buffer_row_offsets: LayoutTensor[DType.uint32, buffer_row_offsets.layout, buffer_row_offsets.origin, element_layout=buffer_row_offsets.element_layout, layout_int_type=buffer_row_offsets.layout_int_type, linear_idx_type=buffer_row_offsets.linear_idx_type, masked=buffer_row_offsets.masked, alignment=buffer_row_offsets.alignment],
    cache_offsets: LayoutTensor[DType.uint32, cache_offsets.layout, cache_offsets.origin, element_layout=cache_offsets.element_layout, layout_int_type=cache_offsets.layout_int_type, linear_idx_type=cache_offsets.linear_idx_type, masked=cache_offsets.masked, alignment=cache_offsets.alignment],
    input_row_offsets: LayoutTensor[DType.uint32, input_row_offsets.layout, input_row_offsets.origin, element_layout=input_row_offsets.element_layout, layout_int_type=input_row_offsets.layout_int_type, linear_idx_type=input_row_offsets.linear_idx_type, masked=input_row_offsets.masked, alignment=input_row_offsets.alignment],
    kv_collection: collection_t,
    layer_idx: UInt32,
    scale: Float32,
    output: LayoutTensor[dtype, output.layout, output.origin, element_layout=output.element_layout, layout_int_type=output.layout_int_type, linear_idx_type=output.linear_idx_type, masked=output.masked, alignment=output.alignment],
    context: DeviceContextPtr
)

Was this page helpful?