IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo module

load_warp

TMA load warp logic for depth=256/512 pair-CTA SM100 attention.

Each CTA in the pair loads its own half of K/V data into its local SMEM. The pair-CTA MMA instruction reads from both SMs' SMEM to combine the halves.

K is split along BN rows: even CTA (is_leader=True) loads K[0:BN//2, :], odd CTA (is_leader=False) loads K[BN//2:BN, :].
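The K split above comes down to a small index calculation. The helper below is a hypothetical Python illustration, not part of the module. It assumes only that BN is even, which the BN//2 split implies.

```python
# Hypothetical helper (not part of the Mojo module): K rows loaded by each CTA.
# Assumes BN is even, matching the BN//2 split described above.


def k_row_range(bn: int, is_leader: bool) -> tuple[int, int]:
    """Return the half-open row range [start, stop) of K that this CTA loads."""
    half = bn // 2
    # Even CTA (is_leader=True) takes the first half, odd CTA the second half.
    return (0, half) if is_leader else (half, bn)


BN = 128  # example tile size, chosen only for illustration
leader_rows = k_row_range(BN, is_leader=True)   # (0, 64)
peer_rows = k_row_range(BN, is_leader=False)    # (64, 128)

# The two halves abut exactly, so together they cover all BN rows once.
assert leader_rows[1] == peer_rows[0]
assert peer_rows[1] - leader_rows[0] == BN
```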

V loading depends on split_o, which is set by depth:

- split_o=True (depth=512): V is split into V_lo and V_hi, which occupy separate pipeline slots.
  - V_lo: even loads V[:, 0:ov_depth//4], odd loads V[:, ov_depth//4:ov_depth//2].
  - V_hi: even loads V[:, ov_depth//2:3*ov_depth//4], odd loads V[:, 3*ov_depth//4:ov_depth].
- split_o=False (depth=256): a single V (no V_hi).
  - V: even loads V[:, 0:ov_depth//2], odd loads V[:, ov_depth//2:ov_depth].
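The V column ranges for both split_o cases can be written out as a hypothetical Python sketch. The function names are ours, and ov_depth is treated as the V/output head depth with illustrative values. The coverage check confirms that, in either mode, the pair's slices tile all ov_depth columns exactly once.

```python
# Hypothetical helper (not part of the Mojo module): V column ranges per CTA.
# ov_depth is taken to be the V/output head depth; example values are illustrative.


def v_col_ranges(ov_depth: int, split_o: bool, is_leader: bool) -> dict[str, tuple[int, int]]:
    """Return {slot_name: (start, stop)} half-open column ranges of V for this CTA."""
    quarter = ov_depth // 4
    half = ov_depth // 2
    three_quarters = 3 * ov_depth // 4
    if split_o:
        # depth=512 path: V_lo and V_hi occupy separate pipeline slots.
        if is_leader:
            return {"V_lo": (0, quarter), "V_hi": (half, three_quarters)}
        return {"V_lo": (quarter, half), "V_hi": (three_quarters, ov_depth)}
    # depth=256 path: a single V slot, split in half between the CTAs.
    if is_leader:
        return {"V": (0, half)}
    return {"V": (half, ov_depth)}


def covers_exactly_once(ov_depth: int, split_o: bool) -> bool:
    """Check that the pair's ranges tile [0, ov_depth) with no gaps or overlaps."""
    cols = []
    for leader in (True, False):
        for start, stop in v_col_ranges(ov_depth, split_o, leader).values():
            cols.extend(range(start, stop))
    return sorted(cols) == list(range(ov_depth))


print(v_col_ranges(512, split_o=True, is_leader=False))  # {'V_lo': (128, 256), 'V_hi': (384, 512)}
print(covers_exactly_once(512, True), covers_exactly_once(256, False))  # True True
```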

Q is per-CTA: even loads Q[0:BM, :], odd loads Q[BM:PairBM, :].
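Unlike K and V, Q is not shared across the pair. Each CTA loads its own BM rows out of the PairBM-row tile. A hypothetical Python sketch follows, with PairBM = 2*BM and BM = 64 chosen only for illustration.

```python
# Hypothetical helper (not part of the Mojo module): Q rows loaded by each CTA.
# PairBM = 2 * BM; BM = 64 below is an illustrative value only.


def q_row_range(bm: int, is_leader: bool) -> tuple[int, int]:
    """Return the half-open range of rows within the PairBM-row Q tile."""
    pair_bm = 2 * bm
    return (0, bm) if is_leader else (bm, pair_bm)


BM = 64
print(q_row_range(BM, True), q_row_range(BM, False))  # (0, 64) (64, 128)
```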

All TMA loads are non-multicast. cta_group=2 on the K/V TMAs (inside tma_copy_k / tma_copy_v) tells the shared cluster barrier to accumulate bytes from both CTAs. Only the leader CTA calls expect_bytes and waits.
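The byte accounting described above can be illustrated with a toy Python model. This is a sketch under stated assumptions: the class and its method names are hypothetical, it ignores real mbarrier details such as phases and arrival counts, and it assumes the leader's expect_bytes covers the combined bytes of both CTAs' halves.

```python
# Toy model (hypothetical, not the real mbarrier API) of a cluster barrier whose
# transaction count is fed by TMA loads from both CTAs of the pair.


class PairTxBarrier:
    def __init__(self) -> None:
        self.pending_bytes = 0
        self.armed = False

    def expect_bytes(self, n: int) -> None:
        # Only the leader CTA calls this, once per tile, for the whole pair.
        self.pending_bytes += n
        self.armed = True

    def complete_tx(self, n: int) -> None:
        # With cta_group=2, bytes landed by either CTA's TMA are counted here.
        self.pending_bytes -= n

    def ready(self) -> bool:
        # The leader's wait succeeds once every expected byte has arrived.
        return self.armed and self.pending_bytes == 0


# Illustrative K tile: BN=128 rows, depth=256, 2-byte elements (e.g. bf16).
BN, DEPTH, ELEM_BYTES = 128, 256, 2
half_tile_bytes = (BN // 2) * DEPTH * ELEM_BYTES  # each CTA loads BN//2 rows

bar = PairTxBarrier()
bar.expect_bytes(2 * half_tile_bytes)  # leader expects both halves
bar.complete_tx(half_tile_bytes)       # leader's own half lands
print(bar.ready())                     # False: peer's half still missing
bar.complete_tx(half_tile_bytes)       # peer's half lands
print(bar.ready())                     # True
```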

Mask computations use PairBM (BM*2) query rows so that both CTAs make identical tile-skip decisions. If one CTA skipped a tile that the other processed, their shared barriers would fall out of sync.
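The following hypothetical Python sketch shows why the skip decision must use PairBM. It uses a causal mask purely as an example; this section does not say which mask types the module supports, and the function name is ours. When each CTA checks only its own BM rows, the leader can decide to skip a KV tile that the peer still needs. Checking the full PairBM range gives both CTAs the same answer.

```python
# Hypothetical illustration: causal masking is assumed here only as an example
# of a mask; the module's actual mask types are not shown in this section.


def tile_fully_masked(q_start: int, q_rows: int, kv_start: int) -> bool:
    """Causal mask: key k is visible to query q only if k <= q.

    The KV tile starting at kv_start is entirely masked (skippable) when its
    first key lies beyond the last query row of the range being checked.
    """
    last_q = q_start + q_rows - 1
    return kv_start > last_q


BM = 64
PAIR_BM = 2 * BM
q0 = 0         # first query row of the pair's tile
kv_start = 64  # a KV tile (e.g. BN=64) starting at key 64

# Deciding per CTA with BM rows gives different answers:
leader_skip = tile_fully_masked(q0, BM, kv_start)     # rows 0..63   -> True
peer_skip = tile_fully_masked(q0 + BM, BM, kv_start)  # rows 64..127 -> False
print(leader_skip, peer_skip)  # True False: the CTAs would desync

# Deciding with PairBM rows gives both CTAs the same answer:
pair_skip = tile_fully_masked(q0, PAIR_BM, kv_start)  # rows 0..127 -> False
print(pair_skip)  # False
```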

Functions