
Mojo module

smem

Shared memory layout for SM100 attention kernels.

Encapsulates the smem offset calculations used by all FA4 warp-specialized functions (kernel, softmax, correction, load, mma) so that each consumer derives pointers from a single source of truth instead of duplicating the arithmetic.

Split-mode memory layout (low to high address):
[Q: q_nope_bytes + q_rope_bytes]
[K: num_kv_stages * (padded_ov_depth * BN * qkv_dt + rope_depth * BN * rope_dt)]
[V: num_kv_stages * padded_ov_depth * BN elements of qkv_dtype]
[correction: BM elements of Float32]
[q_scale: BM * scale_dtype (0 when scale_dtype is invalid)]
[k_scale: num_k_scale_bufs * BN * scale_dtype (0 when invalid)]
[mbars: FA4MiscMBars.size SharedMemBarriers]
[tmem_addr: 1 UInt32]

All K stages are stored contiguously, followed by all V stages, which are also contiguous.
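As a rough illustration of the split-mode arithmetic, the sketch below accumulates region sizes in layout order to get byte offsets. It is a plain Python model, not the module's API: the function name, the parameter names ending in `_bytes`, the 8-byte barrier size, and the absence of alignment padding between regions are all assumptions made for the example. The K term is read as `padded_ov_depth * BN * qkv_bytes + rope_depth * BN * rope_bytes` per stage.

```python
def split_mode_offsets(*, q_nope_bytes, q_rope_bytes, num_kv_stages,
                       padded_ov_depth, rope_depth, BM, BN,
                       qkv_bytes, rope_bytes, scale_bytes,
                       num_k_scale_bufs, num_mbars, mbar_bytes=8):
    """Byte offset of each region, low to high address (hypothetical helper).

    scale_bytes == 0 models an invalid scale_dtype (zero-sized scale regions).
    mbar_bytes == 8 is an assumed size for one shared-memory barrier.
    No padding between regions is assumed.
    """
    k_stage = padded_ov_depth * BN * qkv_bytes + rope_depth * BN * rope_bytes
    regions = [
        ("Q", q_nope_bytes + q_rope_bytes),
        ("K", num_kv_stages * k_stage),
        ("V", num_kv_stages * padded_ov_depth * BN * qkv_bytes),
        ("correction", BM * 4),  # Float32 is 4 bytes
        ("q_scale", BM * scale_bytes),
        ("k_scale", num_k_scale_bufs * BN * scale_bytes),
        ("mbars", num_mbars * mbar_bytes),
        ("tmem_addr", 4),  # one UInt32
    ]
    offsets, cursor = {}, 0
    for name, size in regions:
        offsets[name] = cursor
        cursor += size
    offsets["total"] = cursor
    return offsets


# Illustrative values only; they are not taken from any real kernel config.
split = split_mode_offsets(
    q_nope_bytes=64 * 512 * 1, q_rope_bytes=64 * 64 * 2,
    num_kv_stages=2, padded_ov_depth=512, rope_depth=64,
    BM=64, BN=64, qkv_bytes=1, rope_bytes=2,
    scale_bytes=0, num_k_scale_bufs=0, num_mbars=16,
)
for name, off in split.items():
    print(f"{name:>10}: {off}")
```

Because every consumer would read its pointer from the same table, a change to one region's size shifts all later offsets consistently, which is the "single source of truth" property the module description refers to.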

Fused-mode memory layout (low to high address):
[Q: BM * padded_qk_depth elements of qkv_dtype]
[KV_fused: num_kv_stages * padded_ov_depth * BN elements of qkv_dtype]
[Rope: ceil(num_kv_stages/2) * BN * rope_depth elements of qkv_dtype]
[correction: BM elements of Float32]
[q_scale: BM * scale_dtype (0 when scale_dtype is invalid)]
[k_scale: num_k_scale_bufs * BN * scale_dtype (0 when invalid)]
[mbars: FA4MiscMBars.size SharedMemBarriers]
[tmem_addr: 1 UInt32]

In fused mode, K_nope and V alternate in the same buffer (padded_ov_depth wide), and rope data is stored separately at half the staging rate. k_smem_base() and v_smem_base() return the same pointer.
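The fused-mode layout can be modeled the same way. The sketch below is again plain Python rather than the module's API; the function name, the `_bytes` parameters, the 8-byte barrier size, and the lack of alignment padding are assumptions. It shows two points from the description: the rope region holds `ceil(num_kv_stages/2)` stages, and K and V share a single base offset.

```python
def fused_mode_offsets(*, BM, BN, padded_qk_depth, padded_ov_depth,
                       rope_depth, num_kv_stages, qkv_bytes, scale_bytes,
                       num_k_scale_bufs, num_mbars, mbar_bytes=8):
    """Byte offset of each fused-mode region (hypothetical helper).

    scale_bytes == 0 models an invalid scale_dtype.
    mbar_bytes == 8 is an assumed size for one shared-memory barrier.
    No padding between regions is assumed.
    """
    rope_stages = (num_kv_stages + 1) // 2  # ceil(num_kv_stages / 2)
    regions = [
        ("Q", BM * padded_qk_depth * qkv_bytes),
        ("KV_fused", num_kv_stages * padded_ov_depth * BN * qkv_bytes),
        ("Rope", rope_stages * BN * rope_depth * qkv_bytes),
        ("correction", BM * 4),  # Float32 is 4 bytes
        ("q_scale", BM * scale_bytes),
        ("k_scale", num_k_scale_bufs * BN * scale_bytes),
        ("mbars", num_mbars * mbar_bytes),
        ("tmem_addr", 4),  # one UInt32
    ]
    offsets, cursor = {}, 0
    for name, size in regions:
        offsets[name] = cursor
        cursor += size
    offsets["total"] = cursor
    return offsets


# Illustrative values only; they are not taken from any real kernel config.
fused = fused_mode_offsets(
    BM=64, BN=64, padded_qk_depth=576, padded_ov_depth=512, rope_depth=64,
    num_kv_stages=3, qkv_bytes=1, scale_bytes=0,
    num_k_scale_bufs=0, num_mbars=16,
)

# K_nope and V alternate in one buffer, so both bases map to KV_fused.
k_base = fused["KV_fused"]
v_base = fused["KV_fused"]
print(k_base == v_base, fused)
```

With three KV stages the rope region is sized for two stages, which reflects the "half the staging rate" note above.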

Structs
