
Mojo module

load_warp

TMA load warp logic for depth=512 pair-CTA SM100 attention.

Each CTA in the pair loads its own half of K/V data into its local SMEM. The pair-CTA MMA instruction reads from both SMs' SMEM to combine the halves.

K is split along BN rows: the even CTA loads K[0:BN//2, :] and the odd CTA loads K[BN//2:BN, :].

V is split into V_lo and V_hi (separate pipeline slots, each [BK1, ov_depth//4]):

V_lo: even loads V[:, 0:ov_depth//4], odd loads V[:, ov_depth//4:ov_depth//2]
V_hi: even loads V[:, ov_depth//2:3*ov_depth//4], odd loads V[:, 3*ov_depth//4:ov_depth]

Q is per-CTA: even loads Q[0:64, :], odd loads Q[64:128, :].
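The slice bounds above can be checked with a small Python sketch. It is not Mojo and not the module's code: the helper names (`k_rows`, `v_cols`, `q_rows`) are hypothetical, rank 0 stands for the even CTA and rank 1 for the odd CTA, and the example sizes BN=128 and ov_depth=512 are assumptions.

```python
# Hypothetical helpers: per-CTA half-open [start, stop) bounds for the
# pair-CTA split described above. cta_rank 0 = even CTA, 1 = odd CTA.

def k_rows(cta_rank: int, bn: int) -> tuple[int, int]:
    """K is split along BN rows: each CTA loads BN//2 rows."""
    half = bn // 2
    return (cta_rank * half, (cta_rank + 1) * half)


def v_cols(cta_rank: int, ov_depth: int, part: str) -> tuple[int, int]:
    """V columns for the V_lo or V_hi slot; each CTA's slice is ov_depth//4 wide."""
    quarter = ov_depth // 4
    base = 0 if part == "lo" else 2 * quarter  # V_hi starts at ov_depth//2
    start = base + cta_rank * quarter
    return (start, start + quarter)


def q_rows(cta_rank: int, bm: int = 64) -> tuple[int, int]:
    """Q is per-CTA: each CTA owns 64 of the pair's 128 query rows."""
    return (cta_rank * bm, (cta_rank + 1) * bm)


if __name__ == "__main__":
    bn, ov_depth = 128, 512  # assumed example sizes
    for rank in (0, 1):
        print(rank, k_rows(rank, bn), v_cols(rank, ov_depth, "lo"),
              v_cols(rank, ov_depth, "hi"), q_rows(rank))
```

With these sizes the even CTA gets K rows (0, 64), V_lo columns (0, 128), V_hi columns (256, 384) and Q rows (0, 64). The odd CTA gets the complementary halves, so the four V slices together cover all 512 columns exactly once.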

All TMA loads use async_multicast_load_3d[cta_group=2] with a per-CTA mask. The cta_group=2 ensures the leader CTA's barrier tracks byte arrivals from both CTAs. Only the leader CTA calls expect_bytes and wait.
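The barrier accounting in the paragraph above can be modeled as a toy transaction counter. This is a hedged Python sketch, not the Mojo or PTX API. `PairTxBarrier` and `tile_bytes` are hypothetical names. The model assumes that the leader's single `expect_bytes` call covers both CTAs' halves and that each CTA's load reports its bytes to the leader's barrier. The tile shape (BN//2 × 512 rows and columns of 2-byte elements) is likewise an assumption.

```python
# Toy model (not the real Mojo API) of the leader CTA's transaction barrier
# with cta_group=2: only the leader arms it, but byte arrivals from both
# CTAs are counted against it.

class PairTxBarrier:
    def __init__(self) -> None:
        self.expected = 0
        self.arrived = 0

    def expect_bytes(self, nbytes: int) -> None:
        # Called by the leader CTA only, for the whole pair's tile.
        self.expected += nbytes

    def complete_tx(self, nbytes: int) -> None:
        # Stands in for a TMA load (issued by either CTA) landing its bytes.
        self.arrived += nbytes

    def ready(self) -> bool:
        # The condition the leader's wait() would block on in this model.
        return self.expected > 0 and self.arrived >= self.expected


def tile_bytes(rows: int, cols: int, elem_bytes: int) -> int:
    return rows * cols * elem_bytes


if __name__ == "__main__":
    bn, depth, elem = 128, 512, 2  # assumed: BN=128, depth=512, 2-byte elements
    half_k = tile_bytes(bn // 2, depth, elem)  # each CTA loads BN//2 rows of K
    bar = PairTxBarrier()
    bar.expect_bytes(2 * half_k)  # leader expects both halves
    bar.complete_tx(half_k)       # even CTA's load lands
    print(bar.ready())            # odd CTA's half still missing
    bar.complete_tx(half_k)       # odd CTA's load lands
    print(bar.ready())
```

The model shows why the odd CTA does not need its own `expect_bytes` and `wait`. In this model, the leader's barrier stays incomplete until both halves have arrived.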

Mask computations use PairBM (BM*2 = 128) so that both CTAs make identical skip decisions. If one CTA skips a tile that the other loads, the pair's shared barriers desynchronize.
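The following Python sketch shows how per-CTA masking could diverge. It assumes a causal mask and a hypothetical skip rule: a KV tile is skipped when its first key index exceeds the block's last query row. The function names are illustrative, not the module's. BN=64 is an assumed value, chosen so that the two 64-row halves straddle a tile boundary.

```python
# Hypothetical causal-mask skip rule: a KV tile starting at n0 is skipped for
# a query block [m0, m0 + bm) when every key index exceeds every query index.

def skip_kv_tile(m0: int, bm: int, n0: int) -> bool:
    last_query_row = m0 + bm - 1
    return n0 > last_query_row


BM, PAIR_BM, BN = 64, 128, 64  # assumed sizes; PairBM = BM * 2


def per_cta_decision(cta_rank: int, pair_m0: int, n0: int) -> bool:
    # Each CTA looks only at its own 64 rows, so the two can disagree.
    return skip_kv_tile(pair_m0 + cta_rank * BM, BM, n0)


def pair_decision(pair_m0: int, n0: int) -> bool:
    # Both CTAs evaluate the pair's 128 rows, so they always agree.
    return skip_kv_tile(pair_m0, PAIR_BM, n0)


if __name__ == "__main__":
    pair_m0 = 0
    for n0 in range(0, 256, BN):
        even = per_cta_decision(0, pair_m0, n0)
        odd = per_cta_decision(1, pair_m0, n0)
        print(n0, even, odd, pair_decision(pair_m0, n0))
```

At n0 = 64 the even CTA (rows 0 to 63) would skip the tile while the odd CTA (rows 64 to 127) would load it. Using the pair's 128-row extent makes both CTAs take the same path.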

Functions
