Mojo function
msa_sm100_block_major_dispatch
def msa_sm100_block_major_dispatch[q_type: DType, KVType: MHAOperand, output_type: DType, //, config: MHAConfig[config.dtype], group: Int](o_partial: DeviceBuffer[output_type], lse_partial: DeviceBuffer[DType.float32], q_arg: DeviceBuffer[q_type], k: KVType, v: KVType, scheduler_metadata: UnsafePointer[Int32, MutAnyOrigin], grid_work: Int, work_count_ptr: UnsafePointer[Int32, MutAnyOrigin], k2q_row_ptr: UnsafePointer[Int32, MutAnyOrigin], qsplit_indices: UnsafePointer[Int32, MutAnyOrigin], cu_seqlens_k: UnsafePointer[Int32, MutAnyOrigin], cu_seqlens_q: UnsafePointer[Int32, MutAnyOrigin], total_q: Int, total_rows: Int, nnz: Int, head_q: Int, scale: Float32, ctx: DeviceContext, q_positions: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None, seqused_k: OptionalReg[UnsafePointer[Int32, MutAnyOrigin]] = None)
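Dispatches the block-major MSA forward kernel. The grid is (grid_work, 1, 1). grid_work is a capacity known on the host and is always >= the work_count stored on the device, so the launch does not need a host readback or sync. CTAs whose index is at or past work_count_ptr[0] have no work and exit early under PDL (programmatic dependent launch). The kernel produces O_partial and LSE_partial. PDL (OVERLAP_AT_END) then chains this forward kernel with the combine kernel, which merges the partial results into O.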
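The grid-sizing policy above can be sketched as a small CPU-only model. This is not the real kernel and does not model PDL. cta_does_work and count_active_ctas are hypothetical names used only for illustration. The model assumes each CTA compares its own index with the device-side work count (work_count_ptr[0] in the real kernel) and exits immediately when there is no work for it.

```mojo
# Hypothetical CPU-side model of the launch policy described above.
# Not the real kernel: cta_does_work and count_active_ctas are illustrative names.


def cta_does_work(block_id: Int, work_count: Int) -> Bool:
    # Device side: each CTA reads work_count (work_count_ptr[0] in the
    # real kernel) and exits immediately if its index is at or past it.
    if block_id >= work_count:
        return False  # idle CTA: early exit, no partial output written
    return True


def count_active_ctas(grid_work: Int, work_count: Int) -> Int:
    # Host side: the grid is sized to grid_work, an upper bound known on
    # the host, so the launch never waits for work_count to be read back.
    var active = 0
    for block_id in range(grid_work):
        if cta_does_work(block_id, work_count):
            active += 1
    return active


def main():
    var grid_work = 8  # host-known capacity (upper bound)
    var work_count = 5  # known only on the device at run time
    print(count_active_ctas(grid_work, work_count))  # 5
```

With a capacity of 8 and a device-side work count of 5, CTAs 0 through 4 do work and CTAs 5 through 7 exit early. Over-provisioning the grid costs only a few short-lived CTAs, which is usually cheaper than a host round trip to read the exact count.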