Mojo module
kernel
Kernel entry point for depth=512 pair-CTA SM100 (Blackwell) MHA prefill.
Two neighboring SMs cooperate via pair-CTA MMA (cta_group=2, cluster_shape=(2,1,1)). Each CTA processes BM=64 Q rows; the pair-CTA MMA instruction operates on MMA_M=128 combined rows.
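The sketch below is a minimal illustration of the row arithmetic in the paragraph above. It assumes a layout that this page does not state: the two CTAs of a cluster are adjacent in x, rank 0 is the leader, and each pair covers one contiguous 128-row slice of Q. `cta_q_rows` and `is_leader` are hypothetical helpers, not part of the module.

```python
# Hypothetical sketch of the pair-CTA row split (not the kernel's code).
# Assumptions: cluster_shape=(2,1,1), CTA rank 0 is the leader, and each
# cluster owns one contiguous MMA_M-row slice of Q.
BM = 64                  # Q rows processed by each CTA
CTA_GROUP = 2            # cta_group=2: two CTAs feed one MMA instruction
MMA_M = BM * CTA_GROUP   # combined rows seen by the pair-CTA MMA (128)


def cta_q_rows(cluster_id: int, cta_rank: int) -> range:
    """Return the Q rows owned by one CTA under the assumed layout."""
    if not 0 <= cta_rank < CTA_GROUP:
        raise ValueError("cta_rank must be 0 or 1 for cluster_shape=(2,1,1)")
    start = cluster_id * MMA_M + cta_rank * BM
    return range(start, start + BM)


def is_leader(cta_rank: int) -> bool:
    """Assumption: rank 0 of the pair issues the pair-CTA MMA."""
    return cta_rank == 0


if __name__ == "__main__":
    for rank in range(CTA_GROUP):
        rows = cta_q_rows(3, rank)
        print(f"rank {rank}: rows {rows.start}-{rows.stop - 1}, leader={is_leader(rank)}")
```

Under these assumptions, cluster 3 covers Q rows 384-511: the leader owns 384-447 and the peer owns 448-511, and a single MMA instruction spans all 128 rows.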
Warp assignment (384 threads = 12 warps, 3 warp groups of 128 threads):
- Warps 0-3: Softmax (warp group 0)
- Warps 4-7: Correction (warp group 1)
- Warp 8: MMA (leader CTA issues pair-CTA MMA; peer early-returns)
- Warp 9: Load (both CTAs issue TMA multicast; leader calls expect_bytes)
- Warps 10-11: Spare (no-op)
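The warp-role table can be read as a mapping from thread index to role. The following Python sketch mirrors that table for one 384-thread CTA. The function names (`warp_role`, `warp_group`, `mma_warp_issues`) are hypothetical and only illustrate the dispatch. The kernel itself is written in Mojo and may structure this differently.

```python
# Hypothetical sketch of the warp-role dispatch from the table above.
# Names are illustrative only; they are not the kernel's identifiers.
from collections import Counter

WARP_SIZE = 32
WARPGROUP_SIZE = 128   # 4 warps per warp group
NUM_THREADS = 384      # 12 warps, 3 warp groups


def warp_role(thread_idx: int) -> str:
    """Map a thread index within the CTA to its warp's role."""
    if not 0 <= thread_idx < NUM_THREADS:
        raise ValueError("thread_idx out of range for a 384-thread CTA")
    warp_id = thread_idx // WARP_SIZE
    if warp_id < 4:
        return "softmax"      # warps 0-3, warp group 0
    if warp_id < 8:
        return "correction"   # warps 4-7, warp group 1
    if warp_id == 8:
        return "mma"
    if warp_id == 9:
        return "load"
    return "spare"            # warps 10-11, no-op


def warp_group(thread_idx: int) -> int:
    """Warp group index (0, 1, or 2) of a thread."""
    return thread_idx // WARPGROUP_SIZE


def mma_warp_issues(is_leader_cta: bool) -> bool:
    """Only the leader CTA's MMA warp issues the pair-CTA MMA; the peer returns early."""
    return is_leader_cta


if __name__ == "__main__":
    # Threads per role: softmax 128, correction 128, mma 32, load 32, spare 64
    print(Counter(warp_role(t) for t in range(NUM_THREADS)))
```

Note that warps 8-11 (MMA, Load, and the two spare warps) together form the third warp group.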
Structs