
Mojo struct

SubTileLoaderLDS_st_8x32

struct SubTileLoaderLDS_st_8x32[dtype: DType, BN: Int, depth: Int, BK: Int, num_threads: Int, v_full_v227: Bool = False]

DRAM→LDS DMA for the reference st_8x32_s SMEM layout (V operand).

Mirrors the reference's group-level cooperative load():

  • Each thread (lane_id 0..63 in each of the num_threads / 64 warps) writes bytes_per_thread = 16 bytes per iteration directly to LDS at the natural byte offset lane_byte_offset = thread_id * 16 + iter * num_threads * 16.
  • thread_id is warp_id * WARP_SIZE + lane_id, so the LDS bytes cover successive subtiles in the reference's row-major-by-block-col ordering (subtile_id = subtile_row * subtiles_per_row + subtile_col, with subtile shape 8×BK BF16).
  • Each lane reads its source position in DRAM via the swizzle ↔ position bijection: subtile_lane_byte_offset → (row, col) within the 8×BK subtile, which then unpacks back into a (global_row, global_col) DRAM byte address. For the reference st_8x32 BF16 the swizzle is the identity, so the global position is just the natural subtile-local position.
  • Writes go via rocdl.raw.ptr.buffer.load.lds with the same _alias_scope_attr SubTileLoaderLDS uses, so consumer-side ds_read_tr* LDS reads tagged noalias_scopes=_alias_scope_attr can skip s_waitcnt vmcnt(0) (LLVM PR #74537's SIInsertWaitcnts vmcnt-relaxation handshake), provided the kernel maintains an explicit s_waitcnt vmcnt(0) + s_barrier fence at DMA/compute boundaries.
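The thread-to-position mapping above can be modelled in a few lines of Python. This is a sketch, not the Mojo implementation. The helper dma_positions and its constants are hypothetical, and the sketch assumes BF16 elements, WARP_SIZE = 64, the identity st_8x32 swizzle, and a tile size that is a whole multiple of bytes_per_iter.

```python
# Hypothetical Python model of the cooperative DMA mapping (not the Mojo code).
# Assumes BF16 elements, 64-lane warps, the identity st_8x32 swizzle, and
# subtiles ordered row-major by block column.
ELEM_BYTES = 2          # BF16
BYTES_PER_THREAD = 16   # one 16-byte run per thread per iteration
SUBTILE_ROWS = 8
WARP_SIZE = 64


def dma_positions(BN, depth, BK, num_threads):
    """Yield (thread_id, it, global_row, global_col) for every 16-byte run."""
    bytes_per_iter = num_threads * BYTES_PER_THREAD
    tile_bytes = BN * depth * ELEM_BYTES
    assert tile_bytes % bytes_per_iter == 0, "sketch assumes whole iterations"
    subtile_row_bytes = BK * ELEM_BYTES
    subtile_bytes = SUBTILE_ROWS * subtile_row_bytes
    subtiles_per_row = depth // BK
    for it in range(tile_bytes // bytes_per_iter):
        for warp_id in range(num_threads // WARP_SIZE):
            for lane_id in range(WARP_SIZE):
                thread_id = warp_id * WARP_SIZE + lane_id
                lane_byte_offset = thread_id * BYTES_PER_THREAD + it * bytes_per_iter
                # Which 8 x BK subtile the LDS byte falls in, and where inside it.
                subtile_id, in_subtile = divmod(lane_byte_offset, subtile_bytes)
                subtile_row, subtile_col = divmod(subtile_id, subtiles_per_row)
                # Identity swizzle: the in-subtile byte is already in (row, col) order.
                row, col_byte = divmod(in_subtile, subtile_row_bytes)
                global_row = subtile_row * SUBTILE_ROWS + row
                global_col = subtile_col * BK + col_byte // ELEM_BYTES
                yield thread_id, it, global_row, global_col


runs = list(dma_positions(BN=64, depth=128, BK=32, num_threads=512))
print(runs[0], runs[1], runs[32])  # (0, 0, 0, 0) (1, 0, 0, 8) (32, 0, 0, 32)
```

With the V2 configuration (BN = 64, depth = 128, BK = 32, num_threads = 512), the 16 KiB BF16 tile takes two iterations of 8192 bytes. Thread 32 is the first to land in the second subtile of row block 0, and every element is written exactly once.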

The layout is hard-coded to the reference's BF16 st_8x32_s:

  • subtile_rows = 8
  • subtile_cols = BK (32 for the V2 attention kernels)
  • No swizzle (st_8x32 BF16 returns the offset unchanged)
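For a single subtile, the position ↔ byte bijection with the identity swizzle reduces to row-major BF16 addressing. The helper names below are hypothetical and only illustrate the arithmetic for BK = 32.

```python
# Hypothetical helpers for one 8 x BK BF16 st_8x32 subtile.
ELEM_BYTES = 2     # BF16
SUBTILE_ROWS = 8


def swizzle_st_8x32(byte_offset: int) -> int:
    # The BF16 st_8x32 swizzle is the identity, so it is also its own inverse.
    return byte_offset


def position_to_byte(row: int, col: int, BK: int = 32) -> int:
    # Row-major BF16 position inside the subtile, then the (identity) swizzle.
    return swizzle_st_8x32((row * BK + col) * ELEM_BYTES)


def byte_to_position(byte_offset: int, BK: int = 32) -> tuple:
    # Un-swizzle (identity), convert bytes to elements, split into (row, col).
    elem = swizzle_st_8x32(byte_offset) // ELEM_BYTES
    return divmod(elem, BK)


print(position_to_byte(1, 0), byte_to_position(80))  # 64 (1, 8)
```

Each subtile therefore spans 8 × 32 × 2 = 512 bytes, so 32 consecutive 16-byte lane writes fill one subtile.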

For the K operand (the reference uses st_32x32_s with a two-XOR swizzle), see SubTileLoaderLDS + swizzle/swizzle2 plumbing instead.

Parameters

  • dtype (DType): Element data type (must be BF16 — the st_8x32_s specialization assumes 2-byte elements; FP32 would use a different shape).
  • BN (Int): KV block height in elements (= 64 for the V2 attention kernels).
  • depth (Int): V tile column span in elements (= D for the model's head_dim; 64, 128, or 256 for the V2 attention kernels).
  • BK (Int): Subtile column span in elements (= 32 for the reference st_8x32_s).
  • num_threads (Int): Total threads in the cooperative load (= 8 warps × 64 lanes = 512 for the V2 attention kernels). Used to compute bytes_per_iter (= num_threads × 16 bytes).
  • v_full_v227 (Bool): Selects the reference v227 V LDS layout. Defaults to False, which is byte-identical to the production st_8x32 contiguous fill.
    When True, this loader is the write side of the reference V adapter: each cooperative-DMA 16-byte run (one key's 16 depth columns) is written to the LDS byte that the reference v227 ds_read_b64_tr_b8 read expects, instead of the natural st_8x32 contiguous byte. The DRAM source (global_byte_in_tile) is unchanged; only the LDS destination is remapped.
    The closed form (FP8 32×32×64, DEPTH=128, KV=128, with key = global_row in 0..127 and depth = global_col in 0..127) is:
      lds_byte = c*0x410 + Lp*16 + depth
      c = (((key>>1)&1)<<1 | ((key>>2)&1)) + ((key>>3)&1)*4 + ((key>>6)&1)*8
      Lp = (key&1)*8 + ((key>>4)&1)*16 + ((key>>5)&1)*32
    This is the W of the adapter W∘R pair. R is MhaMmaOp.precompute_v_lane_base[v_full_v227=True] (the v227 per-lane base) plus load_V_frag[v_full_v227=True] (the faithful readout cell i_strip*0x2080 + j_depth*0x20 + r*0x100). The two compose to the identity fragment. W is derived as the byte permutation pi: ours_read_addr -> ref_read_addr over the slot, which is proven to be a bijection that leaves the ds_read_tr8_b64 transpose invariant: both reads issue the identical tr8 op + 4-subread join and differ only in the LDS address, so the transpose cancels.
    The consumer must also set v_full_v227=True; otherwise V is scrambled. The slot must hold at least 16624 B (the maximum lds_byte is 16623), which the _V_SLOT_PAD_ROWS padding (256 B / 4 rows) already allocated by MlaPrefillV2 provides.
    Used only by MlaPrefillV2 (the reference research kernel), where it is the default-on reference V LDS adapter. The production V2 MHA / MLA loaders instantiate this type with v_full_v227=False (byte-identical).
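The v227 closed form can be transcribed directly to check two properties stated above: the mapping is injective over the 128 × 128 key/depth tile, and its largest destination byte is 16623. The function name v227_lds_byte is hypothetical; the arithmetic is copied from the formula.

```python
# Direct transcription of the v227 closed form (FP8, DEPTH=128, KV=128).
# The function name is hypothetical.


def v227_lds_byte(key: int, depth: int) -> int:
    b = [(key >> n) & 1 for n in range(7)]        # key bits 0..6
    c = ((b[1] << 1) | b[2]) + b[3] * 4 + b[6] * 8  # group index, 0..15
    lp = b[0] * 8 + b[4] * 16 + b[5] * 32           # 0, 8, ..., 56
    return c * 0x410 + lp * 16 + depth


dest = [v227_lds_byte(k, d) for k in range(128) for d in range(128)]
assert len(set(dest)) == 128 * 128  # injective: no two bytes collide
print(max(dest) + 1)  # 16624, the minimum slot size in bytes
```

The groups cannot overlap: c uses key bits 1, 2, 3 and 6, Lp uses bits 0, 4 and 5, and Lp*16 + depth spans only 1024 bytes while the group stride is 0x410 = 1040. The 16 unused bytes after each of the first 15 groups add 240 bytes, which is why the slot needs 16624 B rather than 16384 B.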

Fields

  • bc (AMDBufferResource): The 128-bit buffer resource descriptor for DRAM access.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

Methods

__init__

def __init__(gmem_tile: TileTensor[dtype, address_space=gmem_tile.address_space, linear_idx_type=gmem_tile.linear_idx_type, element_size=gmem_tile.element_size]) -> Self

Create a loader from a DRAM tile.

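Args:

  • gmem_tile (TileTensor): The DRAM tile from which the loader builds its buffer resource descriptor (bc).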

load

def load(self, v_smem_slot: TileTensor[dtype, address_space=AddressSpace.SHARED], v_gmem_tile: TileTensor[dtype, address_space=v_gmem_tile.address_space, linear_idx_type=v_gmem_tile.linear_idx_type, element_size=v_gmem_tile.element_size], warp_id_uniform: Int, lane_id_local: Int, scalar_offset: Int)

Cooperatively DMA one V tile from DRAM into LDS.

scalar_offset is the runtime-uniform byte offset of v_gmem_tile relative to the buffer resource's base. The caller computes it once (typically Int(v_gmem_tile.ptr) - Int(self.bc.get_base_ptr())) and passes it here; the method uses it as the scalar-offset argument of rocdl.raw.ptr.buffer.load.lds.

For callers that construct the loader from the same v_gmem_tile they pass here, scalar_offset is 0. Both common production paths (MhaPrefillV2._dma_v and MlaPrefillV2Core._dma_v) take this path.
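The scalar_offset contract can be sketched with plain integers standing in for pointers. The helpers scalar_offset and dram_byte_address are hypothetical. The addressing model is a simplification: it assumes raw buffer addressing of base + scalar offset + per-lane offset and ignores stride, swizzle, and bounds checking.

```python
# Hypothetical integer model of the scalar_offset contract. In Mojo the caller
# computes Int(v_gmem_tile.ptr) - Int(self.bc.get_base_ptr()).


def scalar_offset(tile_ptr: int, resource_base_ptr: int) -> int:
    offset = tile_ptr - resource_base_ptr
    # Assumption: the V tile lies at or after the resource base.
    assert offset >= 0
    return offset


def dram_byte_address(resource_base_ptr: int, soffset: int, voffset: int) -> int:
    # Simplified raw buffer addressing (no stride, swizzle, or range check):
    # base + scalar (wave-uniform) offset + per-lane vector offset.
    return resource_base_ptr + soffset + voffset


base = 0x10000
print(scalar_offset(base, base))           # 0: loader built from the same tile
print(scalar_offset(base + 0x4000, base))  # 16384: tile starts 16 KiB past the base
```

Passing the tile offset as a single wave-uniform scalar leaves only the per-lane part in the vector offset.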

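Args:

  • v_smem_slot (TileTensor): Destination LDS slot (AddressSpace.SHARED) for the V tile.
  • v_gmem_tile (TileTensor): Source V tile in DRAM.
  • warp_id_uniform (Int): Warp index of the calling thread, as a wave-uniform value.
  • lane_id_local (Int): Lane index of the calling thread within its warp (0..63).
  • scalar_offset (Int): Byte offset of v_gmem_tile from the buffer resource's base, as described above.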