Mojo module
load_warp
TMA load warp logic for depth=256/512 pair-CTA SM100 attention.
Each CTA in the pair loads its own half of K/V data into its local SMEM. The pair-CTA MMA instruction reads from both SMs' SMEM to combine the halves.
K is split along BN rows: even CTA (is_leader=True) loads K[0:BN//2, :],
odd CTA (is_leader=False) loads K[BN//2:BN, :].
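The K split above can be sketched as plain index arithmetic. This is a minimal Python illustration, not the Mojo kernel code: the helper name `k_row_range` and the value `BN = 128` are assumptions chosen for the example.

```python
def k_row_range(is_leader: bool, BN: int) -> tuple:
    """Half-open row range [start, stop) of the K tile loaded by one CTA.

    Hypothetical helper: it mirrors the K[0:BN//2, :] / K[BN//2:BN, :]
    split described above and is not part of the load_warp module.
    """
    half = BN // 2
    return (0, half) if is_leader else (half, BN)


# Illustrative tile height; the real BN is a kernel parameter.
BN = 128
leader_rows = k_row_range(True, BN)   # even CTA
peer_rows = k_row_range(False, BN)    # odd CTA
```

The two ranges are disjoint and together cover all BN rows, which is what lets the pair-CTA MMA treat the two SMEM halves as one K tile.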
V loading depends on split_o (depth-dependent):
  split_o=True (depth=512): V is split into V_lo and V_hi (separate pipeline slots):
    V_lo: even loads V[:, 0:ov_depth//4], odd loads V[:, ov_depth//4:ov_depth//2]
    V_hi: even loads V[:, ov_depth//2:3*ov_depth//4], odd loads V[:, 3*ov_depth//4:ov_depth]
  split_o=False (depth=256): Single V (no V_hi):
    V: even loads V[:, 0:ov_depth//2], odd loads V[:, ov_depth//2:ov_depth]
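The V column ranges for both split_o modes can be written out as a small Python sketch. The helper `v_col_ranges` is hypothetical, and the `ov_depth` values 512 and 256 are assumed to equal the head depth purely for illustration.

```python
def v_col_ranges(is_leader: bool, split_o: bool, ov_depth: int) -> dict:
    """Half-open column ranges of V loaded by one CTA, keyed by pipeline slot.

    Hypothetical helper that mirrors the ranges listed above.
    """
    if split_o:
        quarter = ov_depth // 4
        # The even CTA takes the first quarter of each half, the odd CTA the second.
        base = 0 if is_leader else quarter
        return {
            "V_lo": (base, base + quarter),
            "V_hi": (base + 2 * quarter, base + 3 * quarter),
        }
    half = ov_depth // 2
    return {"V": (0, half) if is_leader else (half, ov_depth)}


# Assumed values: ov_depth equal to the head depth in each mode.
even_512 = v_col_ranges(True, True, 512)
odd_512 = v_col_ranges(False, True, 512)
even_256 = v_col_ranges(True, False, 256)
odd_256 = v_col_ranges(False, False, 256)
```

In the split_o=True case each CTA loads two quarter-width slices rather than one half-width slice, so V_lo and V_hi can occupy separate pipeline slots.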
Q is per-CTA: even loads Q[0:BM, :], odd loads Q[BM:PairBM, :].
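For completeness, the Q row ranges follow the same pattern. A short Python sketch, assuming PairBM = 2*BM and an illustrative BM of 64; `q_row_range` is a hypothetical name.

```python
def q_row_range(is_leader: bool, BM: int) -> tuple:
    """Half-open row range of Q loaded by one CTA within the pair's tile."""
    pair_bm = 2 * BM  # PairBM: rows covered by both CTAs together
    return (0, BM) if is_leader else (BM, pair_bm)


BM = 64  # illustrative value only
even_q = q_row_range(True, BM)
odd_q = q_row_range(False, BM)
```

Unlike K and V, the two Q halves are not combined by the MMA: each CTA works on its own BM rows.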
All TMA loads are non-multicast. cta_group=2 on the K/V TMAs (inside
tma_copy_k / tma_copy_v) tells the shared cluster barrier to
accumulate bytes from both CTAs. Only the leader CTA calls
expect_bytes and waits.
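The byte accounting can be pictured with a toy model. The Python sketch below is not the hardware barrier or the Mojo API: the class `PairBarrier`, its method names, and the assumption that the leader expects the sum of both CTAs' bytes are all illustrative, as are the tile sizes and the 2-byte element type.

```python
class PairBarrier:
    """Toy model of one shared barrier that counts outstanding TMA bytes."""

    def __init__(self) -> None:
        self.pending = 0

    def expect_bytes(self, nbytes: int) -> None:
        # Called once, by the leader CTA only.
        self.pending += nbytes

    def complete_tx(self, nbytes: int) -> None:
        # Called (conceptually) when a CTA's TMA load lands in SMEM.
        self.pending -= nbytes

    def ready(self) -> bool:
        return self.pending == 0


# Assumed sizes: BN=128, depth=256, 2-byte elements.
BN, depth, elem_bytes = 128, 256, 2
bytes_per_cta = (BN // 2) * depth * elem_bytes  # each CTA loads half of K

barrier = PairBarrier()
barrier.expect_bytes(2 * bytes_per_cta)  # leader accounts for both halves

barrier.complete_tx(bytes_per_cta)       # even CTA's half arrives
ready_after_one = barrier.ready()

barrier.complete_tx(bytes_per_cta)       # odd CTA's half arrives
ready_after_both = barrier.ready()
```

In this model the leader's wait cannot complete until both halves have arrived, which is the behavior the shared barrier is described as providing.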
Mask computations use PairBM (BM*2), the row count covered by the whole pair, so both CTAs make identical skip decisions. If one CTA skipped a tile and the other did not, their barriers would desynchronize.
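A small Python sketch shows why the row count matters. It assumes a causal mask, which the text above does not specify, and uses the hypothetical helper `skip_tile` with illustrative tile sizes.

```python
def skip_tile(q_start: int, q_rows: int, kv_start: int) -> bool:
    """Causal-mask skip test (assumed mask type).

    A KV tile is skipped when its first key position lies after the last
    query row, i.e. when every entry of the tile would be masked out.
    """
    last_q_row = q_start + q_rows - 1
    return kv_start > last_q_row


BM = 64
PairBM = 2 * BM
pair_q_start = 0
kv_start = 64  # KV tile that begins exactly where the odd CTA's rows begin

# Per-CTA decision (each CTA looks only at its own BM rows): the CTAs disagree.
per_cta = [skip_tile(pair_q_start + i * BM, BM, kv_start) for i in (0, 1)]

# Pair-wide decision (both CTAs use PairBM rows from the pair's start): they agree.
per_pair = [skip_tile(pair_q_start, PairBM, kv_start) for _ in (0, 1)]
```

With the per-CTA test the even CTA would skip the tile while the odd CTA would load it; with the pair-wide test both load it, so the shared barrier sees matching traffic from both sides.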
Functions