
Mojo module

kernel

Kernel entry point for depth=512 pair-CTA SM100 (Blackwell) MHA prefill.

Two neighboring SMs cooperate via pair-CTA MMA (cta_group=2, cluster_shape=(2,1,1)). Each CTA processes BM=64 Q rows; the pair-CTA MMA instruction operates on MMA_M=128 combined rows.
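
To make the row arithmetic concrete, the sketch below computes which Q rows each CTA of a pair would own. It is written in Python purely for illustration and is not the module's Mojo code. It assumes, hypothetically, that each pair covers one contiguous MMA_M-row block of Q and that CTA rank 0 takes the first BM rows. The function name q_rows_for_cta and the pair_tile_idx parameter are illustrative and are not part of the module.

```python
# Hypothetical sketch of Q-row ownership inside a pair-CTA cluster.
# ASSUMPTION: a pair covers one contiguous MMA_M-row block and rank 0 takes the first BM rows.
BM = 64                 # Q rows per CTA (from the module description)
CTA_GROUP = 2           # CTAs per pair (cta_group=2)
MMA_M = BM * CTA_GROUP  # 128 combined rows per pair-CTA MMA


def q_rows_for_cta(pair_tile_idx: int, cta_rank: int) -> range:
    """Return the Q rows owned by one CTA of the pair (hypothetical layout)."""
    if cta_rank not in range(CTA_GROUP):
        raise ValueError("cta_rank must be 0 or 1")
    start = pair_tile_idx * MMA_M + cta_rank * BM
    return range(start, start + BM)


if __name__ == "__main__":
    for rank in range(CTA_GROUP):
        rows = q_rows_for_cta(pair_tile_idx=1, cta_rank=rank)
        print(f"CTA rank {rank}: rows {rows.start}..{rows.stop - 1}")
```

Under this assumed layout, the two 64-row halves are adjacent, so together they form the single 128-row M dimension that one pair-CTA MMA instruction consumes.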

Warp assignment (384 threads = 12 warps, organized as 3 warp groups of 128 threads):
Warps 0-3: Softmax (warp group 0)
Warps 4-7: Correction (warp group 1)
Warp 8: MMA (the leader CTA issues the pair-CTA MMA; the peer returns early)
Warp 9: Load (both CTAs issue TMA multicast; the leader calls expect_bytes)
Warps 10-11: Spare (no-op)
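
The warp assignment can be modeled as a simple mapping from a thread's index within the CTA to its role. The Python sketch below is illustrative only. The kernel itself is written in Mojo. The names warp_role and issues_mma and the role strings are hypothetical and are not module APIs. The warp ranges follow the table above.

```python
# Illustrative Python model of the documented 384-thread warp layout.
# warp_role, issues_mma and the role strings are hypothetical, not module APIs.
from collections import Counter

THREADS_PER_WARP = 32
THREADS_PER_WARP_GROUP = 128
NUM_THREADS = 384  # 12 warps = 3 warp groups


def warp_role(thread_idx: int) -> str:
    """Map a thread index within the CTA to its documented warp role."""
    if not 0 <= thread_idx < NUM_THREADS:
        raise ValueError(f"thread_idx {thread_idx} out of range")
    warp_id = thread_idx // THREADS_PER_WARP
    if warp_id < 4:
        return "softmax"      # warps 0-3, warp group 0
    if warp_id < 8:
        return "correction"   # warps 4-7, warp group 1
    if warp_id == 8:
        return "mma"          # leader CTA issues the pair-CTA MMA
    if warp_id == 9:
        return "load"         # TMA multicast loads
    return "spare"            # warps 10-11, no-op


def issues_mma(thread_idx: int, is_leader_cta: bool) -> bool:
    """Only the leader CTA's MMA warp issues the pair-CTA MMA; the peer's returns early."""
    return warp_role(thread_idx) == "mma" and is_leader_cta


if __name__ == "__main__":
    counts = Counter(warp_role(t) for t in range(NUM_THREADS))
    print(dict(counts))
    print("warp group of thread 200:", 200 // THREADS_PER_WARP_GROUP)
```

In this model, Softmax and Correction each occupy a full warp group. Warp group 2 (warps 8-11) is shared by the single-warp MMA and Load roles and the two spare warps.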

Structs
