
Mojo module

config

Configuration for depth=512 pair-CTA SM100 (Blackwell) MHA kernels.

This config drives a kernel design that differs fundamentally from FA4: two neighboring SMs cooperate as a CTA pair (cta_group=2, cluster_shape=(2,1,1)). Each CTA processes BM=64 Q rows, so the pair-CTA MMA instruction operates on MMA_M=128 combined rows. P@V uses an SS MMA, with P held in SMEM rather than TMEM.

K and V are each split into sub-stages so that they fit in SMEM:

  • Q@K': K sub-tiled along depth into num_qk_stages=4 chunks (BK0=128)
  • P@V: V sub-tiled along BN (reduction dim) into num_pv_stages=4 chunks
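
The tile arithmetic described above can be checked with a short sketch. This is plain Python, not the Mojo API of this module: the function names `qk_stage_ranges` and `pv_stage_ranges` are hypothetical, the even split of each dimension across stages is an assumption, and `bn` is a caller-supplied parameter because this page does not state the value of BN.

```python
# Values stated in the module description.
DEPTH = 512          # head depth
BM = 64              # Q rows handled by each CTA
CTA_GROUP = 2        # two CTAs cooperate on one MMA
NUM_QK_STAGES = 4    # K sub-stages along depth for Q@K'
NUM_PV_STAGES = 4    # V sub-stages along BN for P@V

# Derived values; both match the numbers quoted in the description.
MMA_M = BM * CTA_GROUP           # combined rows seen by the pair-CTA MMA
BK0 = DEPTH // NUM_QK_STAGES     # depth covered by one Q@K' sub-stage


def _even_ranges(extent, num_stages):
    """Split [0, extent) into num_stages equal half-open ranges (assumed even split)."""
    if extent % num_stages != 0:
        raise ValueError("extent must be divisible by num_stages")
    chunk = extent // num_stages
    return [(s * chunk, (s + 1) * chunk) for s in range(num_stages)]


def qk_stage_ranges(depth=DEPTH, num_stages=NUM_QK_STAGES):
    """Depth ranges of K consumed by each Q@K' sub-stage (hypothetical helper)."""
    return _even_ranges(depth, num_stages)


def pv_stage_ranges(bn, num_stages=NUM_PV_STAGES):
    """BN ranges of V consumed by each P@V sub-stage (hypothetical helper).

    bn is not given on this page, so the caller must supply it.
    """
    return _even_ranges(bn, num_stages)


if __name__ == "__main__":
    print("MMA_M =", MMA_M, "BK0 =", BK0)
    print("Q@K' depth chunks:", qk_stage_ranges())
```

Each Q@K' sub-stage covers one `BK0`-wide slice of the depth dimension, and the partial products from the four slices accumulate into the same score tile. The P@V split works the same way along the reduction dimension BN, once a value for `bn` is chosen.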

Structs
