Mojo module
smem
Shared memory layout for SM100 attention kernels.
Encapsulates the smem offset calculations used by all FA4 warp-specialized functions (kernel, softmax, correction, load, mma) so that each consumer derives pointers from a single source of truth instead of duplicating the arithmetic.
Split-mode memory layout (low to high address):
- Q: q_nope_bytes + q_rope_bytes
- K: num_kv_stages * (padded_ov_depth * BN * qkv_dt + rope_depth * BN * rope_dt)
- V: num_kv_stages * padded_ov_depth * BN elements of qkv_dtype
- correction: BM elements of Float32
- q_scale: BM * scale_dtype (0 when scale_dtype is invalid)
- k_scale: num_k_scale_bufs * BN * scale_dtype (0 when invalid)
- mbars: FA4MiscMBars.size SharedMemBarriers
- tmem_addr: 1 UInt32
All K stages are contiguous, followed by all V stages contiguous.
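The split-mode regions above can be sketched as a running-offset computation. This is an illustrative Python sketch, not the actual Mojo API: the function name, parameter names, and the byte widths chosen for the dtypes are assumptions made for the example.

```python
# Hypothetical sketch of the split-mode smem offset arithmetic.
# All names and dtype sizes here are assumptions for illustration.
def split_mode_offsets(
    BM, BN, padded_ov_depth, rope_depth,
    num_kv_stages, num_k_scale_bufs, num_misc_mbars,
    qkv_bytes=2,    # e.g. a 16-bit qkv_dtype
    scale_bytes=4,  # 0 when scale_dtype is invalid
    mbar_bytes=8,   # assumed size of one SharedMemBarrier
    q_nope_bytes=0, q_rope_bytes=0,
):
    offsets = {}
    off = 0
    offsets["q"] = off
    off += q_nope_bytes + q_rope_bytes
    offsets["k"] = off  # all K stages contiguous
    off += num_kv_stages * (padded_ov_depth + rope_depth) * BN * qkv_bytes
    offsets["v"] = off  # then all V stages contiguous
    off += num_kv_stages * padded_ov_depth * BN * qkv_bytes
    offsets["correction"] = off  # BM elements of Float32
    off += BM * 4
    offsets["q_scale"] = off
    off += BM * scale_bytes
    offsets["k_scale"] = off
    off += num_k_scale_bufs * BN * scale_bytes
    offsets["mbars"] = off
    off += num_misc_mbars * mbar_bytes
    offsets["tmem_addr"] = off  # 1 UInt32
    off += 4
    return offsets, off  # per-region offsets and total smem bytes
```

Because every consumer derives its pointer from offsets like these, changing one region's size automatically shifts everything above it, which is the point of keeping the arithmetic in a single source of truth.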
Fused-mode memory layout (low to high address):
- Q: BM * padded_qk_depth elements of qkv_dtype
- KV_fused: num_kv_stages * padded_ov_depth * BN elements of qkv_dtype
- Rope: ceil(num_kv_stages/2) * BN * rope_depth elements of qkv_dtype
- correction: BM elements of Float32
- q_scale: BM * scale_dtype (0 when scale_dtype is invalid)
- k_scale: num_k_scale_bufs * BN * scale_dtype (0 when invalid)
- mbars: FA4MiscMBars.size SharedMemBarriers
- tmem_addr: 1 UInt32
In fused mode, K_nope and V alternate in the same buffer (padded_ov_depth wide), and rope data is stored separately at half the staging rate. k_smem_base() and v_smem_base() return the same pointer.
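The fused-mode element counts, including the half-rate rope staging, can be written out directly. This is a hedged sketch in Python; the function and parameter names are illustrative only and do not come from the Mojo source.

```python
# Hypothetical sketch of the fused-mode region sizes (in elements of
# qkv_dtype). Names are illustrative, not the real struct API.
import math

def fused_mode_element_counts(BM, BN, padded_qk_depth, padded_ov_depth,
                              rope_depth, num_kv_stages):
    q = BM * padded_qk_depth
    # K_nope and V alternate within this one buffer, so k_smem_base()
    # and v_smem_base() would point at the same address.
    kv_fused = num_kv_stages * padded_ov_depth * BN
    # Rope data is staged at half the rate of the fused KV buffer.
    rope = math.ceil(num_kv_stages / 2) * BN * rope_depth
    return q, kv_fused, rope
```

Note that with an odd num_kv_stages the ceil rounds the rope buffer count up, so two consecutive KV stages share one rope buffer except possibly the last.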
Structs
- SM100AttentionSMem: Shared memory layout manager for SM100 Flash Attention kernels.