Skip to main content

Mojo function

flare_mla_decoding_dispatch

```mojo
flare_mla_decoding_dispatch[
    k_t: MHAOperand,
    mask_t: MHAMask,
    dtype: DType,
    //,
    kv_num_heads: Int,
    config: MHAConfig[dtype],
    ragged: Bool = False,
    _is_cache_length_accurate: Bool = False,
    _use_valid_length: Bool = True,
    decoding_warp_split_k: Bool = False,
    per_token_scale_rope_aware: Bool = False,
    sparse: Bool = False,
    rope_aware_kv_sparse: Bool = False
](
    output: TileTensor[linear_idx_type=output.linear_idx_type, element_size=output.element_size],
    q: TileTensor[dtype, linear_idx_type=q.linear_idx_type, element_size=q.element_size],
    k: k_t,
    mask_functor: mask_t,
    valid_length: TileTensor[DType.uint32, linear_idx_type=valid_length.linear_idx_type, element_size=valid_length.element_size],
    max_prompt_len: Int,
    max_cache_valid_length: Int,
    scale: Float32,
    ctx: DeviceContext,
    scalar_args_buf: NullableTileTensor[DType.int64, linear_idx_type=scalar_args_buf.linear_idx_type, element_size=scalar_args_buf.element_size],
    kv_input_row_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), ImmutAnyOrigin]] = None,
    num_partitions: Optional[Int] = None,
    q_scale_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]] = None,
    d_indices: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None,
    indices_stride: Int = 0,
    topk_lengths: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None,
    attn_sink_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]] = None,
    extra_k: OptionalReg[k_t] = None,
    extra_d_indices: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None,
    extra_indices_stride: Int = 0,
    extra_topk_lengths: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None,
    extra_scales_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]] = None
)
```