Mojo module
config
Configuration for depth=512 pair-CTA SM100 (Blackwell) MHA kernels.
This config drives a kernel design that differs fundamentally from FA4: two neighboring SMs cooperate as a CTA pair via cta_group=2, with cluster_shape=(2,1,1). Each CTA processes BM=64 Q rows, so the pair-CTA MMA instruction operates on MMA_M=128 combined rows (2 × 64). P@V uses an SS MMA, with P staged in SMEM rather than TMEM.
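The row arithmetic of the pair-CTA layout can be checked with a short Python sketch. It uses only the values stated above (cta_group=2, cluster_shape=(2,1,1), BM=64, MMA_M=128). The helper `pair_cta_rows` is hypothetical, and the assumption that the two CTAs own contiguous, adjacent 64-row halves of each 128-row tile is an illustrative guess, not the kernel's actual indexing.

```python
# Hypothetical sketch: how per-CTA Q rows combine into one pair-CTA MMA tile.
# Assumption: CTA rank 0 owns the first BM rows of the tile, rank 1 the next BM.
CTA_GROUP = 2                # cta_group=2: two neighboring SMs cooperate
CLUSTER_SHAPE = (2, 1, 1)    # cluster of two CTAs along x
BM = 64                      # Q rows processed by each CTA
MMA_M = 128                  # rows covered by one pair-CTA MMA instruction


def pair_cta_rows(cta_rank: int, tile_start: int) -> range:
    """Q rows owned by one CTA of the pair (assumed contiguous split)."""
    if not 0 <= cta_rank < CTA_GROUP:
        raise ValueError("cta_rank must be 0 or 1")
    start = tile_start + cta_rank * BM
    return range(start, start + BM)


# The cluster holds exactly one CTA pair, and the pair covers MMA_M rows.
assert CLUSTER_SHAPE[0] == CTA_GROUP
assert CTA_GROUP * BM == MMA_M

if __name__ == "__main__":
    rows = [pair_cta_rows(r, 0) for r in range(CTA_GROUP)]
    print(rows[0][0], rows[0][-1], rows[1][0], rows[1][-1])  # 0 63 64 127
```

Under this assumed split, the tile starting at Q row 128 maps to rows 128..191 on rank 0 and 192..255 on rank 1.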
K and V are extensively sub-staged to fit in SMEM:
- Q@K': K is sub-tiled along depth into num_qk_stages=4 chunks of BK0=128 (512 / 4)
- P@V: V is sub-tiled along BN (the reduction dimension) into num_pv_stages=4 chunks
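Both sub-stagings split a matmul along its reduction dimension (depth for Q@K', BN for P@V), so accumulating the per-stage partial products reproduces the unstaged result. The Python sketch below checks this on toy integer matrices. It is not the kernel's code: `matmul`, `staged_matmul`, `toy`, and `transpose` are hypothetical helpers, BM=4 is scaled down from 64 for speed, BN=16 is an assumed value (the page does not state BN), and the raw scores S stand in for the softmax probabilities P because only the reduction split is being tested. Integers keep the comparison exact; with floating-point data the staged sum can differ by rounding.

```python
# Hypothetical sketch (not the kernel's code): staging a matmul along its
# reduction dimension and accumulating partial products gives the same
# result as the unstaged product.
DEPTH = 512
NUM_QK_STAGES = 4
NUM_PV_STAGES = 4
BK0 = DEPTH // NUM_QK_STAGES  # 128 depth columns per Q@K' stage


def matmul(a, b):
    """Row-major matmul: a is m x k, b is k x n."""
    n = len(b[0])
    return [[sum(x * b[t][j] for t, x in enumerate(row)) for j in range(n)]
            for row in a]


def staged_matmul(a, b, num_stages):
    """Accumulate a @ b over num_stages equal chunks of the reduction dim."""
    k = len(b)
    if k % num_stages:
        raise ValueError("reduction dim must split evenly into stages")
    chunk = k // num_stages
    acc = [[0] * len(b[0]) for _ in a]
    for s in range(num_stages):
        lo, hi = s * chunk, (s + 1) * chunk
        # Column slice of the left operand times row slice of the right one.
        part = matmul([row[lo:hi] for row in a], b[lo:hi])
        for i, prow in enumerate(part):
            for j, v in enumerate(prow):
                acc[i][j] += v
    return acc


def toy(rows, cols, seed):
    """Deterministic small-integer matrix for testing."""
    return [[(seed * 31 + i * 7 + j * 3) % 11 - 5 for j in range(cols)]
            for i in range(rows)]


def transpose(m):
    return [list(col) for col in zip(*m)]


if __name__ == "__main__":
    BM, BN = 4, 16              # toy sizes; BN is an assumed value
    q = toy(BM, DEPTH, 1)       # BM x depth
    k = toy(BN, DEPTH, 2)       # BN x depth
    v = toy(BN, DEPTH, 3)       # BN x depth
    # Q@K': reduction over depth, split into NUM_QK_STAGES chunks of BK0.
    s_full = matmul(q, transpose(k))
    assert s_full == staged_matmul(q, transpose(k), NUM_QK_STAGES)
    # P@V: reduction over BN, split into NUM_PV_STAGES chunks of V rows.
    o_full = matmul(s_full, v)
    assert o_full == staged_matmul(s_full, v, NUM_PV_STAGES)
    print("BK0 =", BK0, "| V rows per P@V stage =", BN // NUM_PV_STAGES)
```

Run as a script, this prints `BK0 = 128 | V rows per P@V stage = 4` for the assumed BN=16.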
Structs