Skip to main content

Mojo function

generic_flare_mla_decode_kv_cache_ragged

generic_flare_mla_decode_kv_cache_ragged[collection_t: KVCollectionT, q_dtype: DType, //, mask_str: StringSlice[StaticConstantOrigin], target: StringSlice[StaticConstantOrigin], local_window_size: Int = -1, per_token_scale_rope_aware: Bool = False](q: TileTensor[q_dtype, q.LayoutType, q.origin, linear_idx_type=q.linear_idx_type, element_size=q.element_size], input_row_offsets: TileTensor[DType.uint32, input_row_offsets.LayoutType, input_row_offsets.origin, linear_idx_type=input_row_offsets.linear_idx_type, element_size=input_row_offsets.element_size], kv_collection: collection_t, layer_idx: UInt32, scale: Float32, output: TileTensor[output.dtype, output.LayoutType, output.origin, linear_idx_type=output.linear_idx_type, element_size=output.element_size], scalar_args_buf: TileTensor[DType.int64, scalar_args_buf.LayoutType, scalar_args_buf.origin, linear_idx_type=scalar_args_buf.linear_idx_type, element_size=scalar_args_buf.element_size], context: DeviceContextPtr, q_scale_ptr: UnsafePointer[Float32, MutAnyOrigin] = UnsafePointer())

Was this page helpful?