Mojo module
kernel
Kernel entry point for multi-head attention (MHA) prefill on SM100 (Blackwell) GPUs using pair-CTA MMA, supporting depth=256 and depth=512.
Two neighboring SMs cooperate via pair-CTA MMA (cta_group=2, cluster_shape=(2,1,1)).
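A minimal Python sketch of how CTAs could pair up under cluster_shape=(2,1,1), assuming a 1-D grid along x. The helper names (`cluster_rank`, `pair_id`, `is_leader`) are hypothetical. Treating the CTA with rank 0 in its cluster as the leader is also an assumption, not something this page states.

```python
# Hypothetical sketch: CTA pairing under cluster_shape=(2, 1, 1), 1-D grid along x.
CLUSTER_SHAPE = (2, 1, 1)


def cluster_rank(block_idx_x: int) -> int:
    """Rank (0 or 1) of a CTA within its 2-CTA cluster."""
    return block_idx_x % CLUSTER_SHAPE[0]


def pair_id(block_idx_x: int) -> int:
    """Index of the cluster (CTA pair) that this CTA belongs to."""
    return block_idx_x // CLUSTER_SHAPE[0]


def is_leader(block_idx_x: int) -> bool:
    # Assumed convention: the rank-0 CTA issues the pair-CTA MMA.
    return cluster_rank(block_idx_x) == 0
```

For example, blocks 0 and 1 would form pair 0, with block 0 as the leader. Blocks 2 and 3 would form pair 1.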
Depth-dependent geometry:
- depth=512: MMA_M=128, BM=64, BN=256. O is split into two accumulators, O_lo and O_hi.
- depth=256: MMA_M=256, BM=128, BN=128. A single O accumulator is used.
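The table above can be written down as a lookup. The following Python sketch uses the listed values. It also checks one relation that holds for both rows: each CTA's BM is half of MMA_M, which is consistent with the two CTAs splitting one pair-CTA MMA. Reading BN as the key/value tile size is an assumption. So is the idea that O_lo and O_hi each cover depth/2 columns. `Geometry`, `geometry`, and `o_columns_per_part` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Geometry:
    mma_m: int    # rows covered by one pair-CTA MMA
    bm: int       # query rows owned by each CTA (assumed meaning)
    bn: int       # key/value tile size (assumed meaning)
    o_parts: int  # number of O accumulators (2 means O_lo/O_hi)


# Values transcribed from the module description.
_GEOMETRY = {
    512: Geometry(mma_m=128, bm=64, bn=256, o_parts=2),
    256: Geometry(mma_m=256, bm=128, bn=128, o_parts=1),
}


def geometry(depth: int) -> Geometry:
    if depth not in _GEOMETRY:
        raise ValueError(f"unsupported depth {depth}; expected 256 or 512")
    g = _GEOMETRY[depth]
    # In both configurations, the two CTAs of a pair split MMA_M evenly.
    assert 2 * g.bm == g.mma_m
    return g


def o_columns_per_part(depth: int) -> int:
    # Assumption: O_lo and O_hi each hold half of the depth dimension.
    return depth // geometry(depth).o_parts
```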
Warp assignment (384 threads = 12 warps, in 3 warp groups of 128 threads):
- Warps 0-3: Softmax (warp group 0)
- Warps 4-7: Correction (warp group 1)
- Warp 8: MMA (the leader CTA issues the pair-CTA MMA; the peer CTA returns early)
- Warp 9: Load (both CTAs issue TMA multicast; the leader calls expect_bytes)
- Warps 10-11: Spare (no-op)
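The warp assignment above maps a thread index to a role by integer division. The following Python sketch mirrors that mapping and is not the kernel's code. The `warp_role` function and its role strings are hypothetical. The "exit" label for warp 8 on the peer CTA stands in for the early return described above.

```python
WARP_SIZE = 32
NUM_THREADS = 384  # 12 warps, 3 warp groups of 128 threads


def warp_role(thread_idx: int, is_leader_cta: bool) -> str:
    if not 0 <= thread_idx < NUM_THREADS:
        raise ValueError("thread_idx out of range")
    warp = thread_idx // WARP_SIZE
    if warp < 4:
        return "softmax"      # warp group 0
    if warp < 8:
        return "correction"   # warp group 1
    if warp == 8:
        # Only the leader CTA issues the pair-CTA MMA; the peer returns early.
        return "mma" if is_leader_cta else "exit"
    if warp == 9:
        return "load"         # both CTAs issue TMA multicast
    return "spare"            # warps 10-11: no-op
```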
Structs