Mojo struct
SubTileLoaderLDS_st_8x32
struct SubTileLoaderLDS_st_8x32[dtype: DType, BN: Int, depth: Int, BK: Int, num_threads: Int, v_full_v227: Bool = False]
DRAM→LDS DMA for the reference st_8x32_s SMEM layout (V operand).
Mirrors the reference's group-level cooperative load():
- Each thread (laneid 0..63 across all `num_threads / 64` warps) writes `bytes_per_thread = 16` bytes per iteration directly to LDS at the natural byte offset `lane_byte_offset = thread_id * 16 + iter * num_threads * 16`. `thread_id` is `warp_id * WARP_SIZE + lane_id`, so the LDS bytes cover successive subtiles in the reference's row-major-by-block-col ordering (`subtile_id = subtile_row * subtiles_per_row + subtile_col`, with subtile shape 8×BK BF16).
- Each lane reads its source position in DRAM via the swizzle ↔ position bijection: `subtile_lane_byte_offset → (row, col)` within the 8×BK subtile, which then unpacks back into a `(global_row, global_col)` DRAM byte address. For the reference `st_8x32` BF16 the swizzle is the identity, so the global position is just the natural subtile-local position.
- Writes go via `rocdl.raw.ptr.buffer.load.lds` with the same `_alias_scope_attr` that `SubTileLoaderLDS` uses, so consumer-side `ds_read_tr*` LDS reads tagged `noalias_scopes=_alias_scope_attr` can skip `s_waitcnt vmcnt(0)` (LLVM PR #74537's `SIInsertWaitcnts` vmcnt-relaxation handshake), provided the kernel maintains an explicit `s_waitcnt vmcnt(0) + s_barrier` fence at DMA/compute boundaries.
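The per-lane LDS offset formula can be checked with a small host-side model. The Python sketch below is illustrative only and is not the loader's implementation: the function names are hypothetical, and it assumes the tile size is a whole multiple of `num_threads * 16` bytes. It enumerates every offset written for one BN × depth BF16 tile.

```python
WARP_SIZE = 64          # lanes per wave, as stated in the description
BYTES_PER_THREAD = 16   # bytes each thread writes per iteration


def lane_byte_offset(warp_id: int, lane_id: int, it: int, num_threads: int) -> int:
    """LDS byte offset written by one lane on iteration `it`."""
    thread_id = warp_id * WARP_SIZE + lane_id
    return thread_id * BYTES_PER_THREAD + it * num_threads * BYTES_PER_THREAD


def all_offsets(bn: int, depth: int, num_threads: int, elem_bytes: int = 2) -> list[int]:
    """Every LDS offset written while loading one BN x depth tile."""
    tile_bytes = bn * depth * elem_bytes
    bytes_per_iter = num_threads * BYTES_PER_THREAD
    # Assumption (not stated in the source): the tile is a whole number of iterations.
    assert tile_bytes % bytes_per_iter == 0
    num_iters = tile_bytes // bytes_per_iter
    num_warps = num_threads // WARP_SIZE
    return [
        lane_byte_offset(warp, lane, it, num_threads)
        for it in range(num_iters)
        for warp in range(num_warps)
        for lane in range(WARP_SIZE)
    ]


# V2 attention configuration quoted in the parameter list: BN=64, 512 threads.
offsets = all_offsets(bn=64, depth=128, num_threads=512)
```

With these values the tile is 16384 bytes and each iteration covers 8192 bytes, so two iterations write every 16-byte run of the slot exactly once.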
The layout is hard-coded to the reference's BF16 st_8x32_s:
- `subtile_rows = 8`
- `subtile_cols = BK` (32 for the V2 attention kernels)
- No swizzle (`st_8x32` BF16 returns the offset unchanged)
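Under this no-swizzle layout, an LDS byte offset can be unpacked into a source position with plain integer division. The following Python sketch is a model under stated assumptions, not the struct's code: `lds_offset_to_global` is a hypothetical name, and bytes inside a subtile are assumed to be row-major (the "natural subtile-local position").

```python
SUBTILE_ROWS = 8
ELEM_BYTES = 2  # BF16


def lds_offset_to_global(offset: int, bk: int, depth: int) -> tuple[int, int]:
    """Map an LDS byte offset to (global_row, global_col) in elements."""
    subtile_bytes = SUBTILE_ROWS * bk * ELEM_BYTES
    subtiles_per_row = depth // bk
    # subtile_id = subtile_row * subtiles_per_row + subtile_col
    subtile_id, local = divmod(offset, subtile_bytes)
    subtile_row, subtile_col = divmod(subtile_id, subtiles_per_row)
    # Assumption: row-major bytes inside the 8 x BK subtile (identity swizzle).
    row, col_bytes = divmod(local, bk * ELEM_BYTES)
    col = col_bytes // ELEM_BYTES
    return subtile_row * SUBTILE_ROWS + row, subtile_col * bk + col


# Every element of a 64 x 128 BF16 tile with BK = 32.
positions = {lds_offset_to_global(off, 32, 128) for off in range(0, 64 * 128 * 2, 2)}
```

For example, with `BK = 32` and `depth = 128`, offset 512 starts the second subtile and maps to row 0, column 32, while offset 2048 starts the second subtile row and maps to row 8, column 0.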
For the K operand (the reference uses st_32x32_s with a two-XOR
swizzle), see SubTileLoaderLDS + swizzle/swizzle2 plumbing instead.
Parameters
- dtype (`DType`): Element data type. Must be BF16: the `st_8x32_s` specialization assumes 2-byte elements (FP32 would use a different shape).
- BN (`Int`): KV block height in elements (= 64 for the V2 attention kernels).
- depth (`Int`): V tile column span in elements (= D for the model's head_dim; 64, 128, or 256 for the V2 attention kernels).
- BK (`Int`): Subtile column span in elements (= 32 for the reference `st_8x32_s`).
- num_threads (`Int`): Total threads in the cooperative load (= 8 warps × 64 lanes = 512 for the V2 attention kernels). Used to compute `bytes_per_iter`.
- v_full_v227 (`Bool`): Reference `v227` V LDS layout. Default False → byte-identical, the production `st_8x32` contiguous fill.

  When True, this is the write side of the reference V adapter: each cooperative-DMA 16-byte run (one key's 16 depth cols) is written to the LDS byte the reference `v227` `ds_read_b64_tr_b8` read expects, instead of the natural `st_8x32` contiguous byte. The DRAM source (`global_byte_in_tile`) is unchanged; only the LDS destination is remapped.

  The closed form (FP8 32×32×64, DEPTH=128, KV=128, with `key = global_row 0..127` and `depth = global_col 0..127`) is `lds_byte = c*0x410 + Lp*16 + depth`, where `c = (((key>>1)&1)<<1 | ((key>>2)&1)) + ((key>>3)&1)*4 + ((key>>6)&1)*8` and `Lp = (key&1)*8 + ((key>>4)&1)*16 + ((key>>5)&1)*32`.

  This is the `W` of the adapter `W∘R` pair. `R` is `MhaMmaOp.precompute_v_lane_base[v_full_v227=True]` (the `v227` per-lane base) + `load_V_frag[v_full_v227=True]` (the faithful readout cell `i_strip*0x2080 + j_depth*0x20 + r*0x100`). The two compose to the identity fragment: `W` is derived as the byte permutation `pi: ours_read_addr -> ref_read_addr` over the slot, proven to be a bijection that leaves the `ds_read_tr8_b64` transpose invariant (both reads issue the identical tr8 op + 4-subread join, differing only in the LDS address, so the transpose cancels).

  The consumer must set `v_full_v227=True` too, or V scrambles. The slot must hold ≥ 16624 B (max `lds_byte` 16623), which is covered by the `_V_SLOT_PAD_ROWS` (256 B / 4 rows) padding that `MlaPrefillV2` already allocates. Used only by `MlaPrefillV2` (the reference research kernel), where it is the default-on reference V LDS adapter. The production V2 MHA / MLA loaders build this type at `v_full_v227=False` (byte-identical).
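The `v_full_v227` closed form can be evaluated directly on the host to sanity-check the stated slot size. The Python sketch below transcribes the formula as written above; the function name `v227_lds_byte` is hypothetical, and this is a model of the arithmetic only, not the loader's code path.

```python
def v227_lds_byte(key: int, depth: int) -> int:
    """Closed-form LDS destination byte for (key, depth), both in 0..127."""

    def bit(n: int) -> int:
        return (key >> n) & 1

    c = ((bit(1) << 1) | bit(2)) + bit(3) * 4 + bit(6) * 8
    lp = bit(0) * 8 + bit(4) * 16 + bit(5) * 32
    return c * 0x410 + lp * 16 + depth


# Evaluate the mapping over the full 128 x 128 FP8 tile.
all_bytes = [v227_lds_byte(k, d) for k in range(128) for d in range(128)]
max_byte = max(all_bytes)
is_injective = len(set(all_bytes)) == len(all_bytes)
```

Evaluated over all 128 keys and 128 depth columns, the mapping is injective and its largest value is 16623, which agrees with the ≥ 16624 B slot requirement stated above.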
Fields
- bc (`AMDBufferResource`): The 128-bit buffer resource descriptor for DRAM access.
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable,
RegisterPassable,
TrivialRegisterPassable
Methods
__init__
def __init__(gmem_tile: TileTensor[dtype, address_space=gmem_tile.address_space, linear_idx_type=gmem_tile.linear_idx_type, element_size=gmem_tile.element_size]) -> Self
Create a loader from a DRAM tile.
Args:
- gmem_tile (`TileTensor[dtype, address_space=gmem_tile.address_space, linear_idx_type=gmem_tile.linear_idx_type, element_size=gmem_tile.element_size]`): The full DRAM tile from KVCacheIterator (carries a Scalar valid_rows for clamping bounds).
load
def load(self, v_smem_slot: TileTensor[dtype, address_space=AddressSpace.SHARED], v_gmem_tile: TileTensor[dtype, address_space=v_gmem_tile.address_space, linear_idx_type=v_gmem_tile.linear_idx_type, element_size=v_gmem_tile.element_size], warp_id_uniform: Int, lane_id_local: Int, scalar_offset: Int)
Cooperatively DMA one V tile from DRAM into LDS.
scalar_offset is the runtime-uniform byte offset of
v_gmem_tile relative to the buffer-resource's base. Caller
computes this once (typically
Int(v_gmem_tile.ptr) - Int(self.bc.get_base_ptr())) and passes
the value here. The method uses it as the
rocdl.raw.ptr.buffer.load.lds scalar-offset argument.
For callers that construct the loader from the same
v_gmem_tile they pass here, scalar_offset is 0 — both
common production paths (MhaPrefillV2._dma_v and
MlaPrefillV2Core._dma_v) hit this case.
Args:
- v_smem_slot (`TileTensor[dtype, address_space=AddressSpace.SHARED]`): Destination V SMEM tile (must hold at least `BN * depth * size_of[dtype]` bytes; the loader uses `.ptr` as the LDS base).
- v_gmem_tile (`TileTensor[dtype, address_space=v_gmem_tile.address_space, linear_idx_type=v_gmem_tile.linear_idx_type, element_size=v_gmem_tile.element_size]`): TileTensor view of the BN × depth source tile in DRAM.
- warp_id_uniform (`Int`): Wave-uniform warp index (0..num_warps-1). Caller must pass an SGPR-class value (e.g., `readfirstlane(warp_id())`).
- lane_id_local (`Int`): Per-lane index (0..WARP_SIZE-1) from `lane_id()`.
- scalar_offset (`Int`): Wave-uniform byte offset of `v_gmem_tile` relative to the buffer-resource base.