Mojo function
generic_flare_mla_decode_kv_cache_ragged
generic_flare_mla_decode_kv_cache_ragged[collection_t: KVCollectionT, q_dtype: DType, //, mask_str: StringSlice[StaticConstantOrigin], target: StringSlice[StaticConstantOrigin], local_window_size: Int = -1, per_token_scale_rope_aware: Bool = False, sparse_mla: Bool = False](q: TileTensor[q_dtype, linear_idx_type=q.linear_idx_type, element_size=q.element_size], input_row_offsets: TileTensor[DType.uint32, linear_idx_type=input_row_offsets.linear_idx_type, element_size=input_row_offsets.element_size], kv_collection: collection_t, layer_idx: UInt32, scale: Float32, output: TileTensor[linear_idx_type=output.linear_idx_type, element_size=output.element_size], scalar_args_buf: TileTensor[DType.int64, linear_idx_type=scalar_args_buf.linear_idx_type, element_size=scalar_args_buf.element_size], context: DeviceContextPtr, q_scale_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]] = None, d_indices: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, indices_stride: Int = 0, topk_lengths: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, attn_sink_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]] = None, extra_k: OptionalReg[collection_t.CacheType] = None, extra_d_indices: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, extra_indices_stride: Int = 0, extra_topk_lengths: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, extra_scales_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]] = None)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!