IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

Struct_msa_attention_ragged_paged

struct Struct_msa_attention_ragged_paged

Implemented traits

AnyType, ImplicitlyDeletable

Methods

execute

static def execute[*, group: Int, topk: Int](output: ManagedTensorSlice[Output, static_spec=output.static_spec], q: ManagedTensorSlice[Input, static_spec=q.static_spec], input_row_offsets: ManagedTensorSlice[Input, static_spec=input_row_offsets.static_spec], kv_blocks: ManagedTensorSlice[MutableInput, static_spec=kv_blocks.static_spec], cache_lengths: ManagedTensorSlice[Input, static_spec=cache_lengths.static_spec], kv_lookup_table: ManagedTensorSlice[Input, static_spec=kv_lookup_table.static_spec], max_lengths: ManagedTensorSlice[Input, static_spec=max_lengths.static_spec], layer_idx: UInt32, d_indices: ManagedTensorSlice[Input, static_spec=d_indices.static_spec], scale: Float32, ctx: DeviceContext)

Block-sparse MHA for SM100 (BF16, head_dim 128).

Gathers topk KV blocks per (kv head, query token) using the block ids in d_indices. Dispatches to the decode kernel when kv_collection.max_seq_length == 1 (one query token per sequence) and to the prefill kernel otherwise.
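
The gather-then-attend step and the dispatch rule can be sketched as a pure-Python reference for a single query head and a single query token. This is a minimal sketch, not the SM100 kernel. It assumes that each KV block is a list of key (or value) rows, that `block_ids` holds the topk ids read from d_indices for this (kv head, token) pair, and that no mask applies, as on the decode path. `dispatch_path` and `block_sparse_attention` are hypothetical helpers for illustration.

```python
import math


def dispatch_path(max_seq_length):
    """Mirror the documented rule: one query token per sequence selects decode."""
    return "decode" if max_seq_length == 1 else "prefill"


def block_sparse_attention(q, k_blocks, v_blocks, block_ids, scale):
    """Reference attention for one query head and one query token.

    q:         list[float] of length head_dim
    k_blocks:  list of KV blocks; each block is a list of key rows (list[float])
    v_blocks:  matching value blocks, same shape as k_blocks
    block_ids: the topk block ids selected for this (kv head, token)
    """
    # Gather only the selected blocks and flatten them into one key/value list.
    keys = [row for b in block_ids for row in k_blocks[b]]
    values = [row for b in block_ids for row in v_blocks[b]]

    # Scaled dot-product scores over the gathered rows (no mask, as in decode).
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]

    # Numerically stable softmax: subtract the max score before exponentiating.
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)

    # Softmax-weighted sum of the gathered value rows.
    head_dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / denom for d in range(head_dim)]
```

Selecting only block 1 returns that block's value row unchanged, because the unselected block never enters the softmax.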

Decode uses NullMask + an SM-fill split-K heuristic (get_mha_decoding_max_num_partitions clamped by topk): `num_partitions > 1` runs the block-major fwd over partitioned KV bands, then combines via the shared `mha_splitk_reduce`; `num_partitions == 1` takes the no-combine `NoPartition` path.

Prefill uses the device-CSR plan/run path (`msa_sm100_prefill_plan` + `msa_sm100_prefill_run`): the run is pure-device, but the plan sizes its buffers from the per-batch cu-seqlens on host, so one D2H readback + sync per call is unavoidable while this stays a single stateless op.
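
The split-K decode path relies on the fact that softmax attention over disjoint KV bands can be merged exactly by rescaling each band's partial sums to a common maximum. The sketch below is illustrative only. It uses a hypothetical `heuristic_partitions` input in place of get_mha_decoding_max_num_partitions, splits the topk block ids into contiguous bands, and merges per-band `(m, l, acc)` states (max score, sum of exponentials, unnormalized output). It does not reproduce the memory layout or the exact arithmetic of `mha_splitk_reduce`, and all function names are hypothetical.

```python
import math


def clamp_partitions(heuristic_partitions, topk):
    # Clamp the SM-fill heuristic by topk (never more bands than blocks), minimum 1.
    return max(1, min(heuristic_partitions, topk))


def split_blocks(block_ids, num_partitions):
    """Split the topk block ids into contiguous, non-empty bands."""
    n = len(block_ids)
    bounds = [n * p // num_partitions for p in range(num_partitions + 1)]
    return [block_ids[bounds[p]:bounds[p + 1]] for p in range(num_partitions)]


def partial_attention(q, keys, values, scale):
    """Attention over one KV band, returning the unnormalized state (m, l, acc)."""
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    l = sum(weights)
    acc = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]
    return m, l, acc


def splitk_combine(partials):
    """Merge per-band states by rescaling each one to the global max score."""
    m_global = max(m for m, _, _ in partials)
    l_total = 0.0
    out = [0.0] * len(partials[0][2])
    for m, l, acc in partials:
        alpha = math.exp(m - m_global)  # rescale factor for this band
        l_total += alpha * l
        for d, a in enumerate(acc):
            out[d] += alpha * a
    return [o / l_total for o in out]


def splitk_attention(q, k_blocks, v_blocks, block_ids, scale, heuristic_partitions):
    """Run attention band by band, then combine the partial states."""
    num_partitions = clamp_partitions(heuristic_partitions, len(block_ids))
    partials = []
    for band in split_blocks(block_ids, num_partitions):
        keys = [row for b in band for row in k_blocks[b]]
        values = [row for b in band for row in v_blocks[b]]
        partials.append(partial_attention(q, keys, values, scale))
    return splitk_combine(partials)
```

With a single partition the merge reduces to `acc / l`, which is why the `num_partitions == 1` case can skip the combine step entirely.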

Parameters:

  • group (Int): Query heads per kv-head (n_heads // n_kv_heads); asserts group <= MMA_M in the kernel.
  • topk (Int): Number of gathered KV blocks per token (d_indices stride).
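
The two compile-time parameters can be read as index arithmetic. This is a minimal sketch that assumes a row-major d_indices layout of [num_tokens, n_kv_heads, topk] (the actual layout is not stated on this page) and treats MMA_M as a caller-supplied value. `kv_head_for` and `block_ids_for` are hypothetical helpers.

```python
def kv_head_for(q_head, n_heads, n_kv_heads, mma_m):
    """Map a query head to the kv head it shares (grouped-query attention)."""
    if n_heads % n_kv_heads != 0:
        raise ValueError("n_heads must be a multiple of n_kv_heads")
    group = n_heads // n_kv_heads  # query heads per kv head
    if group > mma_m:
        # Mirrors the documented kernel assert group <= MMA_M.
        raise ValueError("group must not exceed MMA_M")
    return q_head // group


def block_ids_for(d_indices_flat, token, kv_head, n_kv_heads, topk):
    """Return the topk block ids for one (token, kv head) pair.

    Hypothetical layout: row-major [num_tokens, n_kv_heads, topk], so topk
    is the innermost stride and each pair owns a contiguous run of topk ids.
    """
    start = (token * n_kv_heads + kv_head) * topk
    return d_indices_flat[start:start + topk]
```

With 8 query heads and 2 kv heads, group is 4, so query heads 4 to 7 all read kv head 1.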

Args: