
Mojo module

smem

Shared memory layout for depth=512 pair-CTA SM100 attention kernels.

Encapsulates the smem offset calculations used by all depth512 warp-specialized functions (kernel, softmax, correction, load, mma) so that each consumer derives pointers from a single source of truth instead of duplicating the arithmetic.

Memory layout (low to high address):

  [Q: BM * qk_depth elements of qkv_dtype] (reused as O output)
  [P: BM * BN elements of qkv_dtype] (SS MMA P@V buffer)
  [KV: num_kv_stages * (BN//2) * BK0 elements of qkv_dtype]
  [correction: BM elements of Float32]
  [barriers: (num_fixed + 2*num_kv_stages) SharedMemBarriers]
  [tmem_addr: 1 UInt32]
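The layout above can be sketched as a running sum of region sizes. The following Python snippet is a minimal illustration, not the module's Mojo API: the class name `SmemLayoutSketch`, the 2-byte `qkv_dtype`, the 8-byte `SharedMemBarrier`, the absence of alignment padding, and all numeric parameter values (including `num_fixed=8`) are assumptions made for the example.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SmemLayoutSketch:
    """Hypothetical stand-in for the layout described above; not the Mojo API."""

    BM: int
    BN: int
    BK0: int
    qk_depth: int
    num_kv_stages: int
    num_fixed: int           # count of fixed barriers (value is an assumption)
    elem_bytes: int = 2      # assumes a 2-byte qkv_dtype such as bfloat16
    barrier_bytes: int = 8   # assumes an 8-byte SharedMemBarrier

    def offsets(self) -> dict:
        """Byte offset of each region, packed low to high with no padding."""
        sizes = [
            ("Q", self.BM * self.qk_depth * self.elem_bytes),
            ("P", self.BM * self.BN * self.elem_bytes),
            ("KV", self.num_kv_stages * (self.BN // 2) * self.BK0 * self.elem_bytes),
            ("correction", self.BM * 4),  # BM Float32 values
            ("barriers", (self.num_fixed + 2 * self.num_kv_stages) * self.barrier_bytes),
            ("tmem_addr", 4),  # one UInt32
        ]
        out, cursor = {}, 0
        for name, size in sizes:
            out[name] = cursor  # region starts where the previous one ended
            cursor += size
        out["total"] = cursor
        return out


# Illustrative parameters only; the real kernel configuration may differ.
layout = SmemLayoutSketch(BM=64, BN=64, BK0=128, qk_depth=512,
                          num_kv_stages=4, num_fixed=8)
offsets = layout.offsets()
```

Deriving every offset from one function mirrors the single-source-of-truth idea: a consumer asks for `offsets["KV"]` rather than recomputing the sizes of Q and P itself.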

The P buffer is unique to this kernel: P@V uses SS MMA (both operands from SMEM), so softmax must write P to SMEM after computing exp(S). In FA4, by contrast, P lives in TMEM and P@V uses TS MMA.

KV sub-tiles are fused: each buffer slot holds (BN//2) * BK0 elements, interpreted either as a K sub-tile (BN//2 × BK0) during Q@K' or as a V half (BK1 × ov_depth//4) during P@V. V_lo and V_hi each occupy separate pipeline slots. The two interpretations have the same element count when (BN//2) * BK0 == BK1 * (ov_depth//4).
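The equal-count condition is easy to check numerically. The Python sketch below is illustrative only: the helper names `kv_slot_shapes` and `slot_is_shared` are hypothetical, and the parameter values are chosen for the example rather than taken from the kernel.

```python
def kv_slot_shapes(BN, BK0, BK1, ov_depth):
    """Return the two shapes a single KV slot is interpreted as."""
    k_shape = (BN // 2, BK0)             # K sub-tile during Q@K'
    v_half_shape = (BK1, ov_depth // 4)  # V half (V_lo or V_hi) during P@V
    return k_shape, v_half_shape


def slot_is_shared(BN, BK0, BK1, ov_depth):
    """True when both interpretations have the same element count."""
    (k_rows, k_cols), (v_rows, v_cols) = kv_slot_shapes(BN, BK0, BK1, ov_depth)
    return k_rows * k_cols == v_rows * v_cols


# Illustrative values: (64//2) * 128 == 32 * (512//4) == 4096 elements.
ok = slot_is_shared(BN=64, BK0=128, BK1=32, ov_depth=512)

# Doubling BK1 breaks the equality: 4096 != 8192.
mismatch = slot_is_shared(BN=64, BK0=128, BK1=64, ov_depth=512)
```

With the first set of values a slot can hold either a K sub-tile or a V half without resizing; with the second, the V half would need twice the space of the K sub-tile.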

Structs