Mojo function
msa_combine_dispatch
def msa_combine_dispatch[output_type: DType, //, depth: Int, max_splits: Int = Int(32)](o: DeviceBuffer[output_type], lse: DeviceBuffer[DType.float32], o_partial: DeviceBuffer[output_type], lse_partial: DeviceBuffer[DType.float32], split_counts: UnsafePointer[Int32, MutAnyOrigin], cu_seqlens_q: UnsafePointer[Int32, MutAnyOrigin], batch: Int, head_q: Int, head_kv: Int, max_seqlen_q: Int, total_q: Int, topk: Int, ctx: DeviceContext)
Launches the MSA combine kernel on a grid of (ceil(max_seqlen_q * head_q / 64), 1, batch). Each batch gets one CTA (thread block) of 256 threads per 64-row tile. The kernel reduces the per-split partial results in o_partial and lse_partial into the final o and lse. max_splits is the compile-time bound used to size shared memory (SMEM) and registers, and it must be >= the runtime topk. The default of 32 covers all currently dispatched topk values.
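A split-attention combine of this kind is usually a log-sum-exp merge. The Python sketch below is a minimal host-side reference of that math for a single (query, head) row. It is not the Mojo kernel. It rests on these assumptions, none of which this page confirms:

- Each o_partial slot is already normalized within its split.
- split_counts gives how many of the topk slots are valid for a row.
- A row with no valid split yields zeros and lse = -inf.

The helper names combine_grid and combine_row are hypothetical.

```python
import math

ROWS_PER_CTA = 64  # taken from the grid formula: one CTA covers 64 rows


def combine_grid(max_seqlen_q: int, head_q: int, batch: int) -> tuple:
    """Grid shape (ceil(max_seqlen_q * head_q / 64), 1, batch)."""
    rows = max_seqlen_q * head_q
    return ((rows + ROWS_PER_CTA - 1) // ROWS_PER_CTA, 1, batch)


def combine_row(o_partial: list, lse_partial: list, n_splits: int):
    """Merge the first n_splits partial results for one (query, head) row.

    o_partial[s]: split s's output vector (assumed normalized within the split).
    lse_partial[s]: split s's log-sum-exp of attention scores.
    Slots at index >= n_splits are ignored. Returns (o, lse).
    """
    depth = len(o_partial[0])
    # Subtract the max LSE before exponentiating for numerical stability.
    m = max(lse_partial[:n_splits], default=-math.inf)
    if m == -math.inf:
        # Hypothetical convention: no valid split -> zero output, lse = -inf.
        return [0.0] * depth, -math.inf
    weights = [math.exp(lse_partial[s] - m) for s in range(n_splits)]
    total = sum(weights)
    lse = m + math.log(total)  # log(sum_s exp(lse_s))
    o = [0.0] * depth
    for s in range(n_splits):
        w = weights[s] / total  # equals exp(lse_s - lse)
        for d in range(depth):
            o[d] += w * o_partial[s][d]
    return o, lse
```

For example, combine_grid(100, 16, 2) gives (25, 1, 2), because 1600 rows split into 25 tiles of 64. Rescaling each split by exp(lse_s - lse) recovers the same o and lse that a single softmax over all selected keys would produce. That equivalence is why only the partial outputs and their LSEs need to be stored per split.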