Mojo struct
Struct_msa_attention_ragged_paged
struct Struct_msa_attention_ragged_paged
Implemented traits
Methods
execute
static def execute[*, group: Int, topk: Int](output: ManagedTensorSlice[Output, static_spec=output.static_spec], q: ManagedTensorSlice[Input, static_spec=q.static_spec], input_row_offsets: ManagedTensorSlice[Input, static_spec=input_row_offsets.static_spec], kv_blocks: ManagedTensorSlice[MutableInput, static_spec=kv_blocks.static_spec], cache_lengths: ManagedTensorSlice[Input, static_spec=cache_lengths.static_spec], kv_lookup_table: ManagedTensorSlice[Input, static_spec=kv_lookup_table.static_spec], max_lengths: ManagedTensorSlice[Input, static_spec=max_lengths.static_spec], layer_idx: UInt32, d_indices: ManagedTensorSlice[Input, static_spec=d_indices.static_spec], scale: Float32, ctx: DeviceContext)
Block-sparse MHA for SM100 (BF16, head_dim 128).
Gathers topk KV blocks per (kv head, query token) using the block ids
in d_indices. Dispatches to the decode kernel when
kv_collection.max_seq_length == 1 (one query token per sequence) and to
the prefill kernel otherwise.
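The following minimal pure-Python reference model shows the gather for one (KV head, query token) pair. It assumes a simplified, already-unpaged layout for a single KV head: `k_blocks[b]` and `v_blocks[b]` hold the `page_size` key/value vectors of block `b`, and each id in `block_ids` (standing in for `d_indices[kv_head][row]`) indexes that list directly, skipping the `kv_lookup_table` indirection. It ignores BF16, tiling, and masking (decode uses `NullMask`; prefill masking is not described here). `block_sparse_attention_row` and `choose_kernel` are hypothetical names, not part of the MAX API.

```python
import math
from typing import List

Vector = List[float]


def block_sparse_attention_row(
    q_heads: List[Vector],         # [group][head_dim]: query heads sharing one KV head
    k_blocks: List[List[Vector]],  # [num_blocks][page_size][head_dim] keys, one KV head
    v_blocks: List[List[Vector]],  # [num_blocks][page_size][head_dim] values, one KV head
    block_ids: List[int],          # [topk]: assumed to index k_blocks/v_blocks directly
    scale: float,
) -> List[Vector]:
    """Hypothetical reference: softmax attention over the union of the selected blocks."""
    # Gather the tokens of the selected blocks (no page-table lookup in this model).
    keys = [tok for b in block_ids for tok in k_blocks[b]]
    values = [tok for b in block_ids for tok in v_blocks[b]]
    head_dim = len(values[0])

    out: List[Vector] = []
    for q in q_heads:  # every head in the group sees the same gathered tokens
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
        m = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[d] for w, v in zip(weights, values)) / total
                    for d in range(head_dim)])
    return out


def choose_kernel(max_seq_length: int) -> str:
    # Documented dispatch rule: one query token per sequence -> decode kernel.
    return "decode" if max_seq_length == 1 else "prefill"


# Two blocks (page_size=2, head_dim=2); the token selects only block 1,
# so block 0's values (all 9.0) never contribute.
k = [[[1.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [2.0, 0.0]]]
v = [[[9.0, 9.0], [9.0, 9.0]], [[1.0, 0.0], [0.0, 1.0]]]
print(block_sparse_attention_row([[0.0, 0.0]], k, v, [1], scale=1.0))  # [[0.5, 0.5]]
```

All query heads in a group share the same gathered blocks, which matches `d_indices` being indexed by KV head rather than by query head.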
Decode uses `NullMask` plus an SM-fill split-K heuristic (`get_mha_decoding_max_num_partitions`, clamped by `topk`). When `num_partitions > 1`, it runs the block-major forward pass over partitioned KV bands and then combines the partial results with the shared `mha_splitk_reduce`; when `num_partitions == 1`, it takes the no-combine `NoPartition` path.

Prefill uses the device-CSR plan/run path (`msa_sm100_prefill_plan` + `msa_sm100_prefill_run`). The run is pure-device, but the plan sizes its buffers from the per-batch cu-seqlens on the host, so one device-to-host (D2H) readback plus a sync per call is unavoidable as long as this remains a single stateless op.
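A small pure-Python sketch of the decode split-K idea follows. It assumes the standard log-sum-exp combine that split-K attention reductions generally use; it is not code taken from `mha_splitk_reduce`, and the partition-count rule `min(max_partitions, topk)` is only a guess at what "clamped by `topk`" means. `num_partitions`, `split_bands`, `partial_attention`, and `combine` are hypothetical helpers. Each partition covers a contiguous band of the selected blocks and keeps its running max, softmax denominator, and normalized output; the combine rescales those to a common max and reweights.

```python
import math
from typing import List, Tuple

# Per-partition result: (max score m, softmax denominator, normalized output o).
Partial = Tuple[float, float, List[float]]


def num_partitions(max_partitions: int, topk: int) -> int:
    # Assumed reading of "clamped by topk": never more partitions than selected
    # blocks, so no partition is left with an empty KV band.
    return max(1, min(max_partitions, topk))


def split_bands(n: int, parts: int) -> List[range]:
    """Split n selected blocks into `parts` contiguous, near-equal bands."""
    base, extra = divmod(n, parts)
    bands, start = [], 0
    for i in range(parts):
        size = base + (1 if i < extra else 0)
        bands.append(range(start, start + size))
        start += size
    return bands


def partial_attention(scores: List[float], values: List[List[float]]) -> Partial:
    """Softmax-weighted average over one band, plus the stats needed to merge it."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    o = [sum(w * v[d] for w, v in zip(weights, values)) / denom
         for d in range(len(values[0]))]
    return m, denom, o


def combine(partials: List[Partial]) -> List[float]:
    """Log-sum-exp merge: rescale each band to the global max, then reweight."""
    m_all = max(m for m, _, _ in partials)
    scaled = [denom * math.exp(m - m_all) for m, denom, _ in partials]
    total = sum(scaled)
    return [sum(s * o[d] for s, (_, _, o) in zip(scaled, partials)) / total
            for d in range(len(partials[0][2]))]


# Usage: topk = 3 selected blocks of page_size = 2 tokens (already-scaled scores).
page_size, topk = 2, 3
scores = [0.1, 2.0, -1.0, 0.5, 1.5, 0.0]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [1.0, 1.0], [3.0, 0.0], [0.0, 3.0]]
partials = []
for band in split_bands(topk, num_partitions(max_partitions=4, topk=topk)):
    toks = [b * page_size + t for b in band for t in range(page_size)]
    partials.append(partial_attention([scores[i] for i in toks],
                                      [values[i] for i in toks]))
merged = combine(partials)
reference = partial_attention(scores, values)[2]  # unpartitioned softmax
assert all(abs(a - b) < 1e-12 for a, b in zip(merged, reference))
```

With a single partition, `combine` returns that band's output unchanged, which is why the `num_partitions == 1` case can skip the reduction entirely.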
Parameters:
- `group` (`Int`): Query heads per KV head (`n_heads // n_kv_heads`); the kernel asserts `group <= MMA_M`.
- `topk` (`Int`): Number of gathered KV blocks per token (the `d_indices` stride).
Args:
- `output` (`ManagedTensorSlice[Output, static_spec=output.static_spec]`): Output `[num_rows, n_heads, head_dim]`, BF16.
- `q` (`ManagedTensorSlice[Input, static_spec=q.static_spec]`): Query `[num_rows, n_heads, head_dim]`, BF16 (`num_rows` is `total_q` on prefill and `batch` on decode).
- `input_row_offsets` (`ManagedTensorSlice[Input, static_spec=input_row_offsets.static_spec]`): Ragged query offsets `[batch + 1]`, uint32 (one token per sequence on decode).
- `kv_blocks` (`ManagedTensorSlice[MutableInput, static_spec=kv_blocks.static_spec]`): Main-KV paged blocks `[num_blocks, 2, num_layers, page_size, n_kv_heads, head_dim]`, BF16.
- `cache_lengths` (`ManagedTensorSlice[Input, static_spec=cache_lengths.static_spec]`): Main-KV cache lengths `[batch]`, uint32.
- `kv_lookup_table` (`ManagedTensorSlice[Input, static_spec=kv_lookup_table.static_spec]`): Main-KV page table `[batch, max_pages]`, uint32.
- `max_lengths` (`ManagedTensorSlice[Input, static_spec=max_lengths.static_spec]`): Main-KV max lengths `[1, 2]`, uint32.
- `layer_idx` (`UInt32`): Layer index into the main-KV cache.
- `d_indices` (`ManagedTensorSlice[Input, static_spec=d_indices.static_spec]`): Selected block ids `[n_kv_heads, num_rows, topk]`, int32.
- `scale` (`Float32`): QK scale.
- `ctx` (`DeviceContext`): Device context.
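The documented shapes constrain one another. The hypothetical host-side helper below (plain Python over shape tuples, not part of the MAX API) encodes those relationships: `output` matches `q`, `group * n_kv_heads == n_heads`, the last ragged offset equals `num_rows`, and `d_indices` is `[n_kv_heads, num_rows, topk]`. Reading the size-2 axis of `kv_blocks` as K/V is an assumption, and the `group <= MMA_M` limit is not checked because `MMA_M` is not specified here.

```python
from typing import Dict, Sequence, Tuple


def check_msa_shapes(
    shapes: Dict[str, Tuple[int, ...]],
    input_row_offsets: Sequence[int],
    group: int,
    topk: int,
) -> str:
    """Hypothetical host-side check of the documented shapes.

    Returns "decode" or "prefill" following the documented dispatch rule.
    """
    num_rows, n_heads, head_dim = shapes["q"]
    assert shapes["output"] == shapes["q"], "output must match q"
    assert head_dim == 128, "documented for head_dim 128"

    # kv_blocks: [num_blocks, 2 (presumably K and V), num_layers, page_size,
    #             n_kv_heads, head_dim]
    _, kv_pair, _, _, n_kv_heads, kv_head_dim = shapes["kv_blocks"]
    assert kv_pair == 2 and kv_head_dim == head_dim
    assert n_heads == group * n_kv_heads, "group must be n_heads // n_kv_heads"

    # input_row_offsets holds the cumulative query lengths (cu-seqlens).
    offs = list(input_row_offsets)
    batch = len(offs) - 1
    assert shapes["input_row_offsets"] == (batch + 1,)
    assert offs[0] == 0 and offs[-1] == num_rows
    assert shapes["cache_lengths"] == (batch,)
    assert shapes["kv_lookup_table"][0] == batch  # [batch, max_pages]
    assert shapes["max_lengths"] == (1, 2)
    assert shapes["d_indices"] == (n_kv_heads, num_rows, topk)

    # Per-sequence query lengths; one token per sequence means decode.
    seq_lens = [end - start for start, end in zip(offs, offs[1:])]
    return "decode" if max(seq_lens, default=0) == 1 else "prefill"


# Decode: 3 sequences, one query token each; 8 query heads over 2 KV heads.
shapes = {
    "q": (3, 8, 128), "output": (3, 8, 128),
    "kv_blocks": (64, 2, 4, 16, 2, 128),
    "input_row_offsets": (4,), "cache_lengths": (3,),
    "kv_lookup_table": (3, 10), "max_lengths": (1, 2),
    "d_indices": (2, 3, 8),
}
print(check_msa_shapes(shapes, [0, 1, 2, 3], group=4, topk=8))  # decode
```

The `seq_lens` list is derived from the same cu-seqlens that, per the description above, the prefill plan has to read back to the host before sizing its buffers.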