
Mojo module

k2q_csr_device

Device (GPU) reverse-CSR builder for KV-block-major sparse MHA.

GPU port of the host builder k2q_csr.build_k2q_csr, which serves as its oracle. It inverts the query-major selection q2k [head_kv, total_q, topK] into the KV-block-major CSR and schedule that the block-major forward/combine kernels consume, emitting the same contract tensors as the host builder directly into device buffers (no host round trip).

Five stages:

row_map: round-robin (batch, kv_block) -> row_linear + row_coords
hist: per-(CTA,warp) unit histograms -> tile_counts + row_counts
row_prefix: one block per head: row_counts -> row_ptr, emit scheduler_metadata
tile_prefix: scan tile_counts along the (CTA,warp) unit axis -> per-unit base
scatter: per-unit q-sequential write of qsplit / split_counts
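As a rough illustration of what the inversion computes, the following is a minimal sequential sketch for a single head: histogram, exclusive prefix sum, then scatter. It is not the host builder's code. The function name, the plain-list layout, and the convention that a negative entry marks an unused topK slot are assumptions made for this example, and the real contract tensors (row_linear, row_coords, scheduler_metadata, qsplit, split_counts) carry more information than the two lists returned here.

```python
def invert_q2k_sequential(q2k_head, num_rows):
    """Invert a query-major selection into a row-major CSR (one head).

    q2k_head: list of per-query lists of selected row ids.
              Assumption for this sketch: a negative id marks an unused slot.
    Returns (row_ptr, cols), where cols[row_ptr[r]:row_ptr[r + 1]] holds the
    query indices that selected row r, in ascending q order.
    """
    # Pass 1: count how many queries selected each row.
    row_counts = [0] * num_rows
    for selected in q2k_head:
        for r in selected:
            if r >= 0:
                row_counts[r] += 1

    # Pass 2: an exclusive prefix sum turns counts into row start offsets.
    row_ptr = [0] * (num_rows + 1)
    for r in range(num_rows):
        row_ptr[r + 1] = row_ptr[r] + row_counts[r]

    # Pass 3: walk queries in order and append each one to its rows.
    cursor = row_ptr[:-1]  # slicing copies, so row_ptr is left unchanged
    cols = [0] * row_ptr[-1]
    for q, selected in enumerate(q2k_head):
        for r in selected:
            if r >= 0:
                cols[cursor[r]] = q
                cursor[r] += 1
    return row_ptr, cols


row_ptr, cols = invert_q2k_sequential([[0, 2], [2, 1], [0, -1], [2, 0]], 3)
print(row_ptr, cols)
```

Because queries are visited in ascending order, each row's slice of `cols` comes out q-ascending, which is the ordering property the device stages have to reproduce.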

The hist/scatter grid is (g, head_kv): heads run as parallel CTAs (grid.y), and the q-range is tiled across g CTAs x kwarps warps (g_total units), each owning a contiguous q-sub-range. The resulting g*head_kv CTAs spread the q*topk edge stream across the SMs, whereas a single under-gridded CTA would serialize it on one SM. Per-row slots are reserved by an exclusive prefix scan over the units (PR + PT), so scatter writes without cross-CTA atomics, and the per-unit ranges concatenate to a globally q-ascending row that is byte-identical to the output of the host's sequential writer.
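The slot-reservation idea can be sketched on the host in plain Python. Each unit owns a contiguous q-sub-range, an exclusive scan of the per-unit counts along the unit axis gives every unit a private base offset inside each row, and units can then scatter in any order without touching each other's slots. The function name, the list-based layout, and the negative-id convention for unused slots are assumptions made for this example. The units run in a Python loop here, whereas the module describes them as (CTA, warp) pairs on the GPU.

```python
def invert_q2k_by_units(q2k_head, num_rows, num_units):
    """Unit-partitioned inversion for one head (illustrative sketch only)."""
    total_q = len(q2k_head)
    chunk = -(-total_q // num_units)  # ceil division: queries per unit
    ranges = [(min(u * chunk, total_q), min((u + 1) * chunk, total_q))
              for u in range(num_units)]

    # hist: each unit counts row hits inside its own q-sub-range.
    tile_counts = [[0] * num_rows for _ in range(num_units)]
    for u, (lo, hi) in enumerate(ranges):
        for q in range(lo, hi):
            for r in q2k_head[q]:
                if r >= 0:  # assumption: negative id means unused slot
                    tile_counts[u][r] += 1

    # row_prefix: total count per row, then row start offsets.
    row_ptr = [0]
    for r in range(num_rows):
        row_ptr.append(row_ptr[-1] + sum(tc[r] for tc in tile_counts))

    # tile_prefix: exclusive scan along the unit axis gives per-unit bases.
    base = [[0] * num_rows for _ in range(num_units)]
    for r in range(num_rows):
        running = row_ptr[r]
        for u in range(num_units):
            base[u][r] = running
            running += tile_counts[u][r]

    # scatter: units run in reverse here to show that order does not matter.
    cols = [0] * row_ptr[-1]
    for u in reversed(range(num_units)):
        lo, hi = ranges[u]
        cursor = list(base[u])
        for q in range(lo, hi):
            for r in q2k_head[q]:
                if r >= 0:
                    cols[cursor[r]] = q
                    cursor[r] += 1
    return row_ptr, cols


print(invert_q2k_by_units([[0, 2], [2, 1], [0, -1], [2, 0]], 3, 3))
```

Since unit u's slots in a row always precede unit u+1's, and each unit writes its own queries in ascending order, the concatenated row is q-ascending regardless of how the units are scheduled.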

SMEM histogram/cursor entries are one Int32 per (warp,row) (no int16 bit-pack): this removes the per-warp count cap at the cost of 2x the per-warp SMEM. kwarps is picked so two CTAs still fit per SM at the BF16/non-paged row counts.

q_per_cta chunking: each non-empty row yields ceil(row_count/q_per_cta) work items. The default is 128, which equals the fwd CTA query cap (BM).
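The chunking arithmetic can be illustrated with a short Python sketch. The function name and the (row, start, length) tuple are hypothetical and chosen only for this example; the actual scheduler_metadata layout is not described here.

```python
def chunk_rows(row_counts, q_per_cta=128):
    """Split each non-empty row into work items of at most q_per_cta queries.

    Returns a list of (row, start, length) tuples. This tuple layout is a
    made-up stand-in for the real schedule format.
    """
    items = []
    for row, count in enumerate(row_counts):
        if count == 0:
            continue  # empty rows produce no work items
        n_chunks = (count + q_per_cta - 1) // q_per_cta  # ceil(count / q_per_cta)
        for c in range(n_chunks):
            start = c * q_per_cta
            items.append((row, start, min(q_per_cta, count - start)))
    return items


for item in chunk_rows([0, 300, 128, 1]):
    print(item)
```

With the default of 128, a row holding 300 queries splits into three work items of 128, 128, and 44 queries, a row holding exactly 128 stays a single item, and the empty row contributes nothing.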

Structs

Functions