
Mojo module

load_warp

TMA load warp logic for depth=256/512 pair-CTA SM100 attention.

Each CTA in the pair loads its own half of K/V data into its local SMEM. The pair-CTA MMA instruction reads from both SMs' SMEM to combine the halves.

K is split along BN rows: even CTA loads K[0:BN//2, :], odd loads K[BN//2:BN, :].
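The K split above can be modeled on the host as plain index arithmetic. This is a minimal sketch, not the kernel code: `k_row_range` and `cta_rank` are hypothetical names, and it assumes BN is even and that "even"/"odd" refers to the parity of the CTA's rank within the pair.

```python
def k_row_range(cta_rank: int, BN: int) -> tuple[int, int]:
    """Half-open row range [start, stop) of K that one CTA of the pair loads.

    Hypothetical helper for illustration only. Assumes BN is even.
    """
    half = BN // 2
    is_odd = cta_rank % 2          # 0 for the even CTA, 1 for the odd CTA
    start = is_odd * half          # even: 0, odd: BN//2
    return start, start + half     # even: [0, BN//2), odd: [BN//2, BN)


print(k_row_range(0, 128))  # (0, 64)
print(k_row_range(1, 128))  # (64, 128)
```

The two ranges are disjoint and together cover all BN rows, which is what lets the pair-CTA MMA combine the halves.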

V loading depends on split_o (depth-dependent):

split_o=True (depth=512): V is split into V_lo and V_hi (separate pipeline slots):
    V_lo: even loads V[:, 0:ov_depth//4], odd loads V[:, ov_depth//4:ov_depth//2]
    V_hi: even loads V[:, ov_depth//2:3*ov_depth//4], odd loads V[:, 3*ov_depth//4:ov_depth]

split_o=False (depth=256): Single V (no V_hi):
    V: even loads V[:, 0:ov_depth//2], odd loads V[:, ov_depth//2:ov_depth]
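The column ranges for V can be checked with a small host-side model. This is a sketch under stated assumptions: `v_col_ranges` and `cta_rank` are hypothetical names, ov_depth is assumed to be divisible by 4, and the ov_depth values used below are illustrative rather than taken from the kernel.

```python
def v_col_ranges(cta_rank: int, ov_depth: int, split_o: bool) -> dict:
    """Half-open column ranges of V that one CTA of the pair loads.

    Hypothetical helper for illustration only. Assumes ov_depth % 4 == 0.
    """
    odd = cta_rank % 2  # 0 for the even CTA, 1 for the odd CTA
    if split_o:
        q = ov_depth // 4  # each CTA loads one quarter per pipeline slot
        return {
            "V_lo": (odd * q, (odd + 1) * q),
            "V_hi": (2 * q + odd * q, 2 * q + (odd + 1) * q),
        }
    h = ov_depth // 2  # single V: each CTA loads one half
    return {"V": (odd * h, (odd + 1) * h)}


print(v_col_ranges(0, 512, True))   # {'V_lo': (0, 128), 'V_hi': (256, 384)}
print(v_col_ranges(1, 512, True))   # {'V_lo': (128, 256), 'V_hi': (384, 512)}
print(v_col_ranges(1, 256, False))  # {'V': (128, 256)}
```

In the split case, V_lo covers the lower half of the columns and V_hi the upper half, with each half divided between the even and odd CTA.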

Q is per-CTA: even loads Q[0:BM, :], odd loads Q[BM:PairBM, :].

All TMA loads use async_multicast_load_3d[cta_group=2] with a per-CTA mask. The cta_group=2 ensures the leader CTA's barrier tracks byte arrivals from both CTAs. Only the leader CTA calls expect_bytes and wait.

Mask computations use PairBM (BM*2) so that both CTAs make identical skip decisions. If one CTA skipped a tile and the other did not, their barriers would fall out of sync.
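To see why the skip decision has to be made over PairBM rows, consider a toy model. The sketch below assumes a causal mask, which the text above does not specify; `skip_tile` and the tile sizes are hypothetical and chosen only to expose the divergence.

```python
def skip_tile(kv_start: int, q_start: int, q_rows: int) -> bool:
    """True if a KV tile starting at kv_start is fully masked for the query
    rows [q_start, q_start + q_rows), assuming a causal mask (an assumption
    made for this illustration)."""
    last_q_row = q_start + q_rows - 1
    return kv_start > last_q_row


BM = 64
PairBM = BM * 2
pair_q_start = 0  # first query row owned by the CTA pair
kv_start = 64     # first key row of the tile under consideration

# Per-CTA decisions over BM rows each: the two CTAs disagree.
even_skips = skip_tile(kv_start, pair_q_start, BM)       # True
odd_skips = skip_tile(kv_start, pair_q_start + BM, BM)   # False

# Pair-level decision over PairBM rows: one answer shared by both CTAs.
pair_skips = skip_tile(kv_start, pair_q_start, PairBM)   # False

print(even_skips, odd_skips, pair_skips)  # True False False
```

With per-CTA row ranges the even CTA would skip a tile that the odd CTA still loads. Evaluating the mask over the whole PairBM range gives both CTAs the same answer, so they issue the same sequence of loads.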

Functions
