
Mojo module

kernel

Kernel entry point for depth=256/512 pair-CTA SM100 (Blackwell) MHA prefill.

Two neighboring SMs cooperate via pair-CTA MMA (cta_group=2, cluster_shape=(2,1,1)).
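A minimal Python sketch (not Mojo, and not this module's code) of how a CTA could locate itself and its partner in a (2,1,1) cluster. It assumes the CUDA/PTX convention that a CTA's rank inside its cluster is its block index modulo the cluster shape, linearized x-first. It also assumes the rank-0 CTA of each pair is the leader. The helpers `cluster_rank`, `is_leader`, and `peer_of` are hypothetical names.

```python
CLUSTER_SHAPE = (2, 1, 1)  # from the module description: two CTAs per cluster along x


def cluster_rank(block_idx: tuple[int, int, int]) -> int:
    """Linear rank of a CTA inside its cluster (x fastest, then y, then z)."""
    cx, cy, cz = CLUSTER_SHAPE
    x, y, z = block_idx
    return (x % cx) + (y % cy) * cx + (z % cz) * cx * cy


def is_leader(block_idx: tuple[int, int, int]) -> bool:
    # Assumption: the rank-0 CTA of each pair plays the leader role.
    return cluster_rank(block_idx) == 0


def peer_of(block_idx: tuple[int, int, int]) -> tuple[int, int, int]:
    """The other CTA in the same pair: flip the lowest bit of the x index."""
    x, y, z = block_idx
    return (x ^ 1, y, z)
```

With a cluster of two along x, CTAs `(2k, y, z)` and `(2k + 1, y, z)` form one pair, so each pair spans two neighboring SMs.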

Depth-dependent geometry:

- depth=512: MMA_M=128, BM=64, BN=256. O is split into O_lo and O_hi.
- depth=256: MMA_M=256, BM=128, BN=128. Single O accumulator.
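The two configurations can be captured as a small lookup table. This is a Python sketch under stated assumptions, not the module's API. It checks two relations that follow from the listed numbers: BM is half of MMA_M in both cases, which is consistent with each CTA of the pair holding half of the MMA's M rows (an interpretation), and BM x BN is 16384 in both cases. The `Geometry` class and `o_accumulators` helper are hypothetical. The assumption that O_lo and O_hi each cover half of the head depth is also hypothetical.

```python
from dataclasses import dataclass

CTA_GROUP = 2  # pair-CTA MMA (cta_group=2)


@dataclass(frozen=True)
class Geometry:
    depth: int     # head dimension
    mma_m: int     # MMA_M from the module description
    bm: int        # BM from the module description
    bn: int        # BN from the module description
    split_o: bool  # True when O is kept as O_lo / O_hi


# Values copied from the module description.
GEOMETRY = {
    512: Geometry(depth=512, mma_m=128, bm=64, bn=256, split_o=True),
    256: Geometry(depth=256, mma_m=256, bm=128, bn=128, split_o=False),
}


def geometry_for(depth: int) -> Geometry:
    """Return the tile geometry for a supported head depth."""
    try:
        return GEOMETRY[depth]
    except KeyError:
        raise ValueError(f"unsupported depth {depth}; expected 256 or 512") from None


def o_accumulators(g: Geometry) -> list[tuple[str, int]]:
    """Names and column widths of the O accumulator(s).

    Hypothetical: assumes O_lo and O_hi each cover half of the head depth.
    """
    if g.split_o:
        half = g.depth // 2
        return [("O_lo", half), ("O_hi", half)]
    return [("O", g.depth)]
```

For example, `o_accumulators(geometry_for(512))` gives two 256-wide halves under the stated assumption, while depth=256 keeps one 256-wide accumulator.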

Warp assignment (384 threads = 12 warps, 3 warp groups of 128 threads):

- Warps 0-3: Softmax (warp group 0)
- Warps 4-7: Correction (warp group 1)
- Warp 8: MMA (the leader CTA issues the pair-CTA MMA; the peer returns early)
- Warp 9: Load (both CTAs issue TMA multicast; the leader calls expect_bytes)
- Warps 10-11: Spare (no-op)
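The warp assignment amounts to a dispatch on `thread_idx // 32`. The Python sketch below mirrors the table and the leader/peer split for the MMA and load warps. The functions `warp_role` and `warp_action` and their string labels are hypothetical illustrations, not the kernel's actual code.

```python
WARP_SIZE = 32
NUM_THREADS = 384  # 12 warps = 3 warp groups of 128 threads


def warp_role(thread_idx: int) -> str:
    """Map a thread index to its role per the warp assignment table."""
    if not 0 <= thread_idx < NUM_THREADS:
        raise ValueError(f"thread_idx {thread_idx} out of range [0, {NUM_THREADS})")
    warp = thread_idx // WARP_SIZE
    if warp <= 3:
        return "softmax"     # warp group 0
    if warp <= 7:
        return "correction"  # warp group 1
    if warp == 8:
        return "mma"
    if warp == 9:
        return "load"
    return "spare"           # warps 10-11


def warp_action(role: str, is_leader: bool) -> str:
    """Leader vs. peer behavior, as described for the MMA and load warps."""
    if role == "mma":
        # Only the leader CTA issues the pair-CTA MMA; the peer exits early.
        return "issue pair-CTA MMA" if is_leader else "return early"
    if role == "load":
        # Both CTAs issue TMA multicast; only the leader calls expect_bytes.
        return "TMA multicast + expect_bytes" if is_leader else "TMA multicast"
    if role == "spare":
        return "no-op"
    # The description gives no leader/peer distinction for softmax or correction.
    return role
```

Keeping the MMA issue on a single warp of the leader CTA matches the pair-CTA model, where one CTA drives an MMA whose operands and accumulator span both SMs.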

Structs
