Mojo module
smem
Shared memory layout for depth=512 pair-CTA SM100 attention kernels.
Encapsulates the smem offset calculations used by all depth512 warp-specialized functions (kernel, softmax, correction, load, mma) so that each consumer derives pointers from a single source of truth instead of duplicating the arithmetic.
Memory layout (low to high address):
- Q: BM * qk_depth elements of qkv_dtype (reused as the O output buffer)
- P: BM * BN elements of qkv_dtype (SS MMA P@V buffer)
- KV: num_kv_stages * (BN//2) * BK0 elements of qkv_dtype
- correction: BM elements of Float32
- barriers: (8 + 2*num_kv_stages) SharedMemBarriers
- tmem_addr: 1 UInt32
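The offset arithmetic this layout implies can be sketched as follows. This is an illustrative Python model, not the Mojo implementation: the function name `smem_offsets`, the dictionary keys, and the byte sizes (2-byte qkv_dtype, 8-byte SharedMemBarrier) are assumptions for demonstration, not taken from the module.

```python
# Hypothetical sketch of the single-source-of-truth offset arithmetic the
# module centralizes. BM, BN, BK0, qk_depth, num_kv_stages follow the layout
# description above; the byte sizes are illustrative assumptions.
def smem_offsets(BM, BN, BK0, qk_depth, num_kv_stages,
                 qkv_dtype_bytes=2, barrier_bytes=8):
    """Return (byte offsets of each region low-to-high, total bytes)."""
    offsets = {}
    cursor = 0
    offsets["q"] = cursor                        # Q buffer, reused as O output
    cursor += BM * qk_depth * qkv_dtype_bytes
    offsets["p"] = cursor                        # P buffer for SS MMA P@V
    cursor += BM * BN * qkv_dtype_bytes
    offsets["kv"] = cursor                       # fused K / V-half pipeline slots
    cursor += num_kv_stages * (BN // 2) * BK0 * qkv_dtype_bytes
    offsets["correction"] = cursor               # BM Float32 correction factors
    cursor += BM * 4
    offsets["barriers"] = cursor                 # (8 + 2*num_kv_stages) barriers
    cursor += (8 + 2 * num_kv_stages) * barrier_bytes
    offsets["tmem_addr"] = cursor                # single UInt32
    cursor += 4
    return offsets, cursor                       # cursor == total smem bytes
```

Each warp-specialized consumer would derive its pointer as a base address plus one of these offsets, so the arithmetic lives in one place.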
The P buffer is unique to this kernel: P@V uses SS MMA (both operands from SMEM), so softmax must write P to SMEM after computing exp(S). In FA4, P lives in TMEM and P@V uses TS MMA.
KV sub-tiles are fused: each buffer slot holds (BN//2)*BK0 elements, interpreted as a K tile ((BN//2) × BK0) during Q@K' or as a V half (BK1 × (ov_depth//4)) during P@V. V_lo and V_hi each occupy separate pipeline slots. The two interpretations have the same element count when (BN//2)*BK0 == BK1*(ov_depth//4).
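The fusion condition above is plain tile arithmetic and can be checked directly. The tile values below (BN=128, BK0=64, BK1=32, ov_depth=512) are illustrative assumptions, not parameters confirmed by the module:

```python
# Element count of one fused KV slot under each of its two interpretations.
def kv_slot_elems_as_k(BN, BK0):
    return (BN // 2) * BK0          # K tile: (BN//2) x BK0

def kv_slot_elems_as_v(BK1, ov_depth):
    return BK1 * (ov_depth // 4)    # V half: BK1 x (ov_depth//4)

# Example (assumed tile sizes): (128//2)*64 == 32*(512//4) == 4096,
# so the same slot can serve both roles without padding.
```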
Structs
- Depth512AttentionSMem: Shared memory layout manager for depth=512 pair-CTA attention kernels.