
Mojo module

config

Configuration for pair-CTA SM100 (Blackwell) MHA kernels (depth 256/512).

This config drives a kernel design that differs fundamentally from FA4: two neighboring SMs cooperate as a CTA pair via cta_group=2, with cluster_shape=(2,1,1). P@V uses an SS MMA, so P is read from SMEM rather than from TMEM.

Depth-dependent geometry (MMA_M, BM, BN, num_qk_stages):

  • depth=512: MMA_M=128, BM=64, BN=256, num_qk_stages=4; O is split into O_lo/O_hi
  • depth=256: MMA_M=256, BM=128, BN=128, num_qk_stages=2; single O
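The geometry table can be modeled as a small lookup. This is a minimal Python sketch, not the module's Mojo API. `PairGeometry`, `pair_geometry`, and `CTA_GROUP` are hypothetical names. The sketch derives num_qk_stages as depth / BK0 (512 / 128 = 4 and 256 / 128 = 2). It also checks that MMA_M equals 2 × BM in both rows. That check assumes each CTA of the pair contributes BM query rows to one cta_group=2 MMA.

```python
from dataclasses import dataclass

# Hypothetical Python model of the documented geometry; not the Mojo API.
BK0 = 128      # K sub-tile width along depth for Q@K' (from the text)
CTA_GROUP = 2  # two SMs cooperate (cta_group=2, cluster_shape=(2,1,1))


@dataclass(frozen=True)
class PairGeometry:
    depth: int
    mma_m: int
    bm: int
    bn: int
    num_qk_stages: int
    split_o: bool  # True when O is split into O_lo/O_hi


def pair_geometry(depth: int) -> PairGeometry:
    """Return the tile geometry listed for a supported head depth."""
    if depth == 512:
        mma_m, bm, bn, split_o = 128, 64, 256, True
    elif depth == 256:
        mma_m, bm, bn, split_o = 256, 128, 128, False
    else:
        raise ValueError(f"unsupported depth {depth}; expected 256 or 512")
    # num_qk_stages follows from cutting depth into BK0-wide chunks.
    num_qk_stages = depth // BK0
    # Assumption: each CTA of the pair contributes BM rows to the pair MMA.
    if mma_m != CTA_GROUP * bm:
        raise AssertionError("expected MMA_M == CTA_GROUP * BM")
    return PairGeometry(depth, mma_m, bm, bn, num_qk_stages, split_o)


for d in (256, 512):
    print(pair_geometry(d))
```

Both listed configurations also have BM × BN = 16384 (64 × 256 and 128 × 128). The sketch treats this only as an observation from the numbers. It is not a documented design rule.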

K and V are extensively sub-staged to fit in SMEM:

  • Q@K': K sub-tiled along depth into num_qk_stages chunks (BK0=128)
  • P@V: V sub-tiled along BN (reduction dim) into num_pv_stages=2 chunks
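The two sub-staging schemes can be illustrated on logical tile coordinates. This is a minimal Python sketch under stated assumptions. The helpers `k_substages`, `v_substages`, `matmul`, and `staged_pv` are hypothetical. The sketch models a whole logical [BN, depth] tile and does not describe how the kernel places each chunk in either CTA's SMEM. K is cut along depth into BK0-wide chunks. V is cut along BN, the reduction dimension of P@V. Because the V cut is along the reduction dimension, the partial products P[:, chunk] @ V[chunk, :] are summed into the same output, which the small integer example checks.

```python
# Hypothetical helpers; slices are (row_start, row_stop, col_start, col_stop).
BK0 = 128          # K sub-tile width along depth (from the text)
NUM_PV_STAGES = 2  # number of V sub-stages along BN (from the text)


def k_substages(bn: int, depth: int) -> list[tuple[int, int, int, int]]:
    """Q@K': cut a [bn, depth] K tile along depth into BK0-wide chunks."""
    if depth % BK0:
        raise ValueError("depth must be a multiple of BK0")
    return [(0, bn, s * BK0, (s + 1) * BK0) for s in range(depth // BK0)]


def v_substages(bn: int, depth: int) -> list[tuple[int, int, int, int]]:
    """P@V: cut a [bn, depth] V tile along bn (the reduction dimension)."""
    if bn % NUM_PV_STAGES:
        raise ValueError("bn must be divisible by NUM_PV_STAGES")
    rows = bn // NUM_PV_STAGES
    return [(s * rows, (s + 1) * rows, 0, depth) for s in range(NUM_PV_STAGES)]


def matmul(a, b):
    """Plain-Python matrix product a @ b for lists of lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def staged_pv(p, v, stages):
    """Accumulate P@V one V sub-stage at a time along the reduction dim."""
    out = [[0] * len(v[0]) for _ in p]
    for r0, r1, _, _ in stages:
        p_chunk = [row[r0:r1] for row in p]  # columns r0:r1 of P
        v_chunk = v[r0:r1]                   # rows r0:r1 of V
        for i, row in enumerate(matmul(p_chunk, v_chunk)):
            for j, x in enumerate(row):
                out[i][j] += x
    return out


for bn, depth in ((256, 512), (128, 256)):  # (BN, depth) pairs from the table
    print(depth, "K stages:", k_substages(bn, depth))
    print(depth, "V stages:", v_substages(bn, depth))

# Tiny integer check that staging along the reduction dim is exact.
p = [[1, 2, 3, 4], [5, 6, 7, 8]]
v = [[1, 0], [0, 1], [2, 3], [4, 5]]
assert staged_pv(p, v, v_substages(4, 2)) == matmul(p, v)
```

For depth=512, this yields four 256 × 128 K chunks and two 128 × 512 V chunks. For depth=256, it yields two 128 × 128 K chunks and two 64 × 256 V chunks.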

Structs
