IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo module

smem

Shared memory layout for SM100 attention kernels.

Encapsulates the smem offset calculations used by all FA4 warp-specialized functions (kernel, softmax, correction, load, mma) so that each consumer derives pointers from a single source of truth instead of duplicating the arithmetic.

Split-mode memory layout (low to high address):
[Q: q_nope_bytes + q_rope_bytes]
[K: num_kv_stages * (padded_ov_depth * BN * qkv_dt + rope_depth * BN * rope_dt)]
[V: num_kv_stages * padded_ov_depth * BN elements of qkv_dtype]
[correction: 2 * WARPGROUP_SIZE Float32 entries (= BM in 2Q; doubled to 2*BM in 1Q so each softmax thread tid in [0, 255] has a dedicated slot)]
[q_scale: BM * scale_dtype (0 when scale_dtype is invalid)]
[k_scale: num_k_scale_bufs * BN * scale_dtype (0 when invalid)]
[mbars: FA4MiscMBars.size SharedMemBarriers]
[tmem_addr: 1 UInt32]

All K stages are contiguous, followed by all V stages contiguous.
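To make the split-mode arithmetic concrete, here is a Python sketch (not the Mojo API) that accumulates the region sizes in the order listed to get each region's starting byte offset, then indexes K and V stages from the start of their contiguous blocks. Assumptions: `qkv_dt` and `rope_dt` are read as per-element byte sizes, each SharedMemBarrier is taken as 8 bytes (one 64-bit mbarrier), any alignment padding the real kernel may insert between regions is ignored, and the class name, the method names (`offsets`, `k_stage_offset`, `v_stage_offset`), and the example parameter values are all hypothetical.

```python
from dataclasses import dataclass

FLOAT32_BYTES = 4
UINT32_BYTES = 4
MBARRIER_BYTES = 8  # assumption: one 64-bit mbarrier per SharedMemBarrier
WARPGROUP_SIZE = 128


@dataclass(frozen=True)
class SplitSmemLayout:
    """Hypothetical split-mode layout: byte offsets only, no alignment padding."""
    q_nope_bytes: int
    q_rope_bytes: int
    num_kv_stages: int
    padded_ov_depth: int
    rope_depth: int
    BN: int
    BM: int
    qkv_bytes: int         # bytes per qkv_dtype element (qkv_dt)
    rope_bytes: int        # bytes per rope element (rope_dt)
    scale_bytes: int       # 0 when scale_dtype is invalid
    num_k_scale_bufs: int
    num_misc_mbars: int    # stands in for FA4MiscMBars.size

    def k_stage_bytes(self) -> int:
        # One K stage holds the nope part plus the rope part.
        return (self.padded_ov_depth * self.BN * self.qkv_bytes
                + self.rope_depth * self.BN * self.rope_bytes)

    def v_stage_bytes(self) -> int:
        return self.padded_ov_depth * self.BN * self.qkv_bytes

    def offsets(self) -> dict[str, int]:
        # Regions in low-to-high address order, each paired with its size.
        sizes = [
            ("q", self.q_nope_bytes + self.q_rope_bytes),
            ("k", self.num_kv_stages * self.k_stage_bytes()),
            ("v", self.num_kv_stages * self.v_stage_bytes()),
            ("correction", 2 * WARPGROUP_SIZE * FLOAT32_BYTES),
            ("q_scale", self.BM * self.scale_bytes),
            ("k_scale", self.num_k_scale_bufs * self.BN * self.scale_bytes),
            ("mbars", self.num_misc_mbars * MBARRIER_BYTES),
            ("tmem_addr", UINT32_BYTES),
        ]
        out, cursor = {}, 0
        for name, size in sizes:
            out[name] = cursor  # each region starts where the previous one ends
            cursor += size
        out["end"] = cursor
        return out

    def k_stage_offset(self, stage: int) -> int:
        if not 0 <= stage < self.num_kv_stages:
            raise IndexError(stage)
        # All K stages are contiguous, so stage i is a fixed stride from the K base.
        return self.offsets()["k"] + stage * self.k_stage_bytes()

    def v_stage_offset(self, stage: int) -> int:
        if not 0 <= stage < self.num_kv_stages:
            raise IndexError(stage)
        return self.offsets()["v"] + stage * self.v_stage_bytes()


# Illustrative parameters only (2-byte elements, no scales, 10 misc barriers).
layout = SplitSmemLayout(
    q_nope_bytes=128 * 128 * 2, q_rope_bytes=128 * 64 * 2,
    num_kv_stages=2, padded_ov_depth=128, rope_depth=64, BN=64, BM=128,
    qkv_bytes=2, rope_bytes=2, scale_bytes=0, num_k_scale_bufs=2,
    num_misc_mbars=10,
)
for name, offset in layout.offsets().items():
    print(f"{name:>10}: {offset}")
```

Deriving every region from one cumulative pass reflects the single-source-of-truth goal stated at the top of this module: changing one region's size shifts every later offset automatically instead of requiring edits in each consumer.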

Fused-mode memory layout (low to high address):
[Q: BM * padded_qk_depth elements of qkv_dtype]
[KV_fused: num_kv_stages * padded_ov_depth * BN elements of qkv_dtype]
[Rope: ceil(num_kv_stages/2) * BN * rope_depth elements of qkv_dtype]
[correction: 2 * WARPGROUP_SIZE Float32 entries (= BM in 2Q; doubled to 2*BM in 1Q so each softmax thread tid in [0, 255] has a dedicated slot)]
[q_scale: BM * scale_dtype (0 when scale_dtype is invalid)]
[k_scale: num_k_scale_bufs * BN * scale_dtype (0 when invalid)]
[mbars: FA4MiscMBars.size SharedMemBarriers]
[tmem_addr: 1 UInt32]

In fused mode, K_nope and V alternate in the same buffer (padded_ov_depth wide), and rope data is stored separately at half the staging rate. k_smem_base() and v_smem_base() return the same pointer.
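The fused-mode arithmetic can be sketched in Python as well (not the Mojo API), with hypothetical names and illustrative parameters, ignoring any alignment padding and treating each SharedMemBarrier as 8 bytes. K and V resolve to the same base offset because both live in the KV_fused buffer, and the rope region holds ceil(num_kv_stages / 2) buffers, computed with integer arithmetic as `(num_kv_stages + 1) // 2`. Reading the halved rope count as "only the K half of the alternating stages needs rope" is an interpretation, not something this page states.

```python
from dataclasses import dataclass

FLOAT32_BYTES = 4
UINT32_BYTES = 4
MBARRIER_BYTES = 8  # assumption: one 64-bit mbarrier per SharedMemBarrier
WARPGROUP_SIZE = 128


@dataclass(frozen=True)
class FusedSmemLayout:
    """Hypothetical fused-mode layout: byte offsets only, no alignment padding."""
    BM: int
    BN: int
    padded_qk_depth: int
    padded_ov_depth: int
    rope_depth: int
    num_kv_stages: int
    qkv_bytes: int         # bytes per qkv_dtype element (rope uses it too)
    scale_bytes: int       # 0 when scale_dtype is invalid
    num_k_scale_bufs: int
    num_misc_mbars: int    # stands in for FA4MiscMBars.size

    def num_rope_bufs(self) -> int:
        # ceil(num_kv_stages / 2) without floating point
        return (self.num_kv_stages + 1) // 2

    def offsets(self) -> dict[str, int]:
        sizes = [
            ("q", self.BM * self.padded_qk_depth * self.qkv_bytes),
            ("kv_fused",
             self.num_kv_stages * self.padded_ov_depth * self.BN * self.qkv_bytes),
            ("rope",
             self.num_rope_bufs() * self.BN * self.rope_depth * self.qkv_bytes),
            ("correction", 2 * WARPGROUP_SIZE * FLOAT32_BYTES),
            ("q_scale", self.BM * self.scale_bytes),
            ("k_scale", self.num_k_scale_bufs * self.BN * self.scale_bytes),
            ("mbars", self.num_misc_mbars * MBARRIER_BYTES),
            ("tmem_addr", UINT32_BYTES),
        ]
        out, cursor = {}, 0
        for name, size in sizes:
            out[name] = cursor  # each region starts where the previous one ends
            cursor += size
        out["end"] = cursor
        return out

    # K_nope and V share the KV_fused buffer, so both bases coincide.
    def k_region_offset(self) -> int:
        return self.offsets()["kv_fused"]

    def v_region_offset(self) -> int:
        return self.offsets()["kv_fused"]


# Illustrative parameters only: an odd stage count to exercise the ceiling.
fused = FusedSmemLayout(
    BM=128, BN=64, padded_qk_depth=192, padded_ov_depth=128, rope_depth=64,
    num_kv_stages=3, qkv_bytes=2, scale_bytes=0, num_k_scale_bufs=0,
    num_misc_mbars=10,
)
print(fused.offsets())
```

With an odd stage count such as 3, the rope region still needs 2 buffers, which is why the layout uses a ceiling rather than plain halving.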

In num_qo == 1 mode, the cross-WG LSE exchange runs through the (now-dead) s TMEM slot rather than through smem, so no additional smem region is needed. Both warpgroups still write the combined, LSE-reduced output to the single q-aliased o_smem region and then TMA-store it to gmem. Output partials remain in TMEM throughout the combine.

Structs