Mojo struct
TiledMmaLoader
struct TiledMmaLoader[in_type: DType, mma_shape: IndexList[3], swizzle: Optional[Swizzle] = Optional(), swizzle2: Optional[Swizzle] = Optional()]
SMEM-to-register loader expert for MFMA operand fragments.
Sibling to TiledMmaOp (static MFMA compute). Stateless: all
methods are @staticmethod. Parameterized by operand dtype, MMA
instruction shape, and optional vector-space swizzle. Reusable
wherever a kernel issues MMA-tile-shaped SMEM reads (attention's
QK / PV matmuls, potential future matmul variants).
Static methods:
- load_b: full B-operand load from a warp-sized SMEM tile. M-outer iteration; handles BF16 single-load and FP8 two-half-load (num_packs branch) with optional vector-space swizzle.
- load_b_tr: transposed single-MMA-tile load via ds_read_tr16_b64_warp halves + join (BF16 double-rate shapes: 32x32x16 and 16x16x32).
- load_v_fp8_strip: FP8 V-operand per-strip load via ds_read_tr8_b64 with paired-lane addressing (16x16x128 FP8 PV matmul).
_load_b_tile is a private helper used by load_b.
Parameters
- in_type (DType): Operand element type.
- mma_shape (IndexList[3]): MMA instruction shape [M, N, K].
- swizzle (Optional[Swizzle]): Optional vector-space swizzle for load_b.
- swizzle2 (Optional[Swizzle]): Optional second vector-space swizzle, applied AFTER swizzle. Use to compose two-XOR swizzles (e.g., the reference st_32x32 bit5^=bit9 + bit4^=bit10 byte-level pair, which are not expressible as a single Swizzle).
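The two-XOR composition that swizzle2 enables can be modeled on plain integers. The following Python sketch is not the Mojo Swizzle API: the helper names xor_bit and two_xor_swizzle are hypothetical, and the bit pairs are simply the bit5^=bit9 + bit4^=bit10 example quoted above. It only illustrates applying one XOR step after the other to a byte offset.

```python
def xor_bit(offset: int, dst_bit: int, src_bit: int) -> int:
    """XOR bit `src_bit` of `offset` into bit `dst_bit` (hypothetical helper)."""
    return offset ^ (((offset >> src_bit) & 1) << dst_bit)


def two_xor_swizzle(offset: int) -> int:
    """Apply bit5 ^= bit9 first, then bit4 ^= bit10 (swizzle, then swizzle2)."""
    first = xor_bit(offset, dst_bit=5, src_bit=9)
    return xor_bit(first, dst_bit=4, src_bit=10)


# Offsets with bit 9 and/or bit 10 set get bit 5 and/or bit 4 flipped.
for off in (0x000, 0x200, 0x400, 0x600):
    print(f"{off:#06x} -> {two_xor_swizzle(off):#06x}")
```

The two steps use different distances between source and destination bit (4 and 6). Because the source bits (9 and 10) are never modified, applying the mapping twice returns the original offset.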
Implemented traits
Methods
load_b
static def load_b[num_mmas: Int, simd_width: Int, imm_offset_bytes: Int = 0](src: TileTensor[in_type, address_space=AddressSpace.SHARED, linear_idx_type=src.linear_idx_type, element_size=src.element_size]) -> InlineArray[SIMD[in_type, simd_width], num_mmas]
Full B operand load from a SMEM warp tile.
Loads all MMA tiles from a WN x BK SMEM warp tile and returns them as an InlineArray of SIMD fragments (one per MMA tile).
Parameters:
- num_mmas (Int): Number of MMA tiles to load.
- simd_width (Int): SIMD vector width for the element type.
- imm_offset_bytes (Int): Comptime byte offset added to each ds_read via the n (numeric immediate) constraint, bypassing the AMDGPU instruction selector for the address-fold step. See _load_from_lds[imm_offset_bytes]. Cost: per-read s_waitcnt lgkmcnt(0) serializes LDS reads.
Args:
- src (TileTensor[in_type, address_space=AddressSpace.SHARED, linear_idx_type=src.linear_idx_type, element_size=src.element_size]): A WN x BK TileTensor in shared memory.
Returns:
InlineArray[SIMD[in_type, simd_width], num_mmas]: An InlineArray of SIMD fragments, one per MMA tile.
load_b_tr
static def load_b_tr(tile: TileTensor[in_type, address_space=AddressSpace.SHARED, linear_idx_type=tile.linear_idx_type, element_size=tile.element_size]) -> SIMD[in_type, 8]
Transposed B operand load for double-rate MFMA shapes.
Splits the tile along the K dimension into two halves and concatenates the results.
Args:
- tile (TileTensor[in_type, address_space=AddressSpace.SHARED, linear_idx_type=tile.linear_idx_type, element_size=tile.element_size]): An MMA_K x MMA_N TileTensor in shared memory.
Returns:
SIMD[in_type, 8]: A SIMD[in_type, 8] vector with both halves concatenated.
load_v_fp8_strip
static def load_v_fp8_strip[BN: Int, BK: Int, bk_tile: Int, dt: Int](v_base: UnsafePointer[Scalar[in_type], MutAnyOrigin, address_space=AddressSpace.SHARED], rel_key: Int, hw_key_shift: Int, depth_base: Int) -> SIMD[in_type, 32]
FP8 V per-strip ds_read_tr8_b64 load for one (bk_tile, dt).
Paired-lane addressing: issues 4 ds_read_tr8_b64 calls (at
key_base = 0, 16, 32, 48) and joins the results into one
32-element SIMD matching the MFMA C-output column pattern for
the 16x16x128 FP8 V operand in the PV matmul.
Caller is responsible for precomputing the per-lane coords
(rel_key, hw_key_shift, depth_base) ONCE before the outer
(bk_tile, dt) loop. They're lane-only, not
(bk_tile, dt)-dependent, so hoisting saves redundant address
math per iteration.
Parameters:
- BN (Int): V block height in elements.
- BK (Int): V block width in elements.
- bk_tile (Int): Which BK-tall row strip (0..num_k_tiles - 1).
- dt (Int): Which depth-tile within the strip (0..depth/MMA_M - 1).
Args:
- v_base (UnsafePointer[Scalar[in_type], MutAnyOrigin, address_space=AddressSpace.SHARED]): Pointer to V SMEM stage base (block 0 of num_repeats).
- rel_key (Int): Per-lane relative key index within the 16-lane row.
- hw_key_shift (Int): +4 for lanes in hw1, +0 for hw0.
- depth_base (Int): Per-lane depth sub-range offset (0 or 8 or 16 or 24).
Returns:
SIMD[in_type, 32]: SIMD[in_type, 32] for this lane's (bk_tile, dt) strip.
load_v_fp8_strip_16
static def load_v_fp8_strip_16[BN: Int, block_width: Int, bk_tile: Int, dt: Int](v_base: UnsafePointer[Scalar[in_type], MutAnyOrigin, address_space=AddressSpace.SHARED], key_group: Int, pair_idx: Int, is_odd: Int) -> SIMD[in_type, 32]
FP8 V per-strip ds_read_tr8_b64 load for one (bk_tile, dt), sized for the 16x16x128 MFMA A-operand fragment layout.
Lane partition for a 64-lane wave (lane id l):
- key_group g = l // 16 (0..3; 16-lane "rows")
- pair_idx p = (l % 16) // 2 (0..7; pair within the row)
- is_odd o = l % 2 (0 or 1)
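The lane partition above can be written out as a small Python model. This is only an illustration of the index arithmetic; the names LaneCoords and lane_coords are hypothetical and not part of the Mojo API, and a 64-lane wave is assumed as stated above.

```python
from typing import NamedTuple


class LaneCoords(NamedTuple):
    key_group: int  # l // 16, in 0..3
    pair_idx: int   # (l % 16) // 2, in 0..7
    is_odd: int     # l % 2, 0 or 1


def lane_coords(lane_id: int) -> LaneCoords:
    """Decompose a lane id (0..63) into the three per-lane arguments."""
    return LaneCoords(lane_id // 16, (lane_id % 16) // 2, lane_id % 2)


WAVE_SIZE = 64
coords = [lane_coords(l) for l in range(WAVE_SIZE)]
print(coords[0], coords[17], coords[63])
```

Each lane maps to a distinct (key_group, pair_idx, is_odd) triple, and the lane id can be recovered as key_group * 16 + pair_idx * 2 + is_odd.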
Per (bk_tile, dt), one MFMA tile of V is 16 depths * 128 keys
= 2048 FP8 = 64 lanes * 32 FP8/lane. Four ds_read_tr8_b64
calls at key_base in {0, 8, 16, 24} deliver 8 keys per lane
each, totaling 32 contiguous keys per lane at one depth.
Within one ds_read_tr8_b64 call, each 16-lane row performs
two interleaved 8x8 byte transposes (one over the 8 even
lanes, one over the 8 odd lanes). Per the AMD ISA, paired
even/odd lanes share a key and read 8 depths each:
- Even lane (p, o=0) reads V[g*32 + key_base + p, depth 0..7]
- Odd lane (p, o=1) reads V[g*32 + key_base + p, depth 8..15]
After the transpose, even lane at pair_idx p in the row gets 8 keys (key_base..key_base+7 within the group) at depth p; odd lane at pair_idx p gets the same 8 keys at depth p+8.
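The read-then-transpose pattern described here can be checked with a small Python model that tracks (key, depth) coordinates instead of FP8 values. It is a sketch under stated assumptions, not the hardware behavior: the function names read_tr8_row and lane_fragment are hypothetical, a single depth tile (dt = 0) is modeled, and the transpose is taken to be a plain 8x8 transpose across the 8 lanes of equal parity.

```python
KEYS_PER_GROUP = 32  # keys owned by one 16-lane row


def read_tr8_row(g: int, key_base: int) -> dict:
    """Model one call for the 16 lanes of row g.

    Returns {(pair_idx, is_odd): [(key, depth), ...]} after the transpose.
    """
    out = {}
    for o in (0, 1):
        # Before the transpose: lane (p, o) reads one key at 8 depths.
        rows = [
            [(g * KEYS_PER_GROUP + key_base + p, 8 * o + d) for d in range(8)]
            for p in range(8)
        ]
        # 8x8 transpose across the 8 lanes that share parity o.
        for p in range(8):
            out[(p, o)] = [rows[q][p] for q in range(8)]
    return out


def lane_fragment(g: int, p: int, o: int) -> list:
    """Join the four calls (key_base 0, 8, 16, 24) for one lane."""
    frag = []
    for key_base in (0, 8, 16, 24):
        frag += read_tr8_row(g, key_base)[(p, o)]
    return frag


frag = lane_fragment(g=1, p=3, o=1)
print(len(frag), frag[0], frag[-1])
```

In this model every lane ends up with 32 contiguous keys of its group, all at the single depth p + 8 * o.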
The per-lane output matches the scalar gather it replaces:
lane l holds V[key=g*32..g*32+31, depth=butterfly(l%16) + dt*16],
where butterfly(p) = (p // 2) + (p % 2)*8. The depth axis is a
butterfly permutation of the linear ordering; the MFMA's
A-operand lane->m_h mapping for the 16x16x128 shape consumes
this permuted layout directly (no post-load permute needed).
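The butterfly mapping is easy to tabulate. The Python sketch below only evaluates the formula given above; the argument is named x to make clear it is the lane position l % 16 (0..15) rather than the pair index, and lane_depth is a hypothetical helper name.

```python
def butterfly(x: int) -> int:
    """Depth held by the lane at position x = lane_id % 16 within its row."""
    return (x // 2) + (x % 2) * 8


def lane_depth(lane_id: int, dt: int) -> int:
    """Depth index for a lane in depth tile dt (16 depths per tile)."""
    return butterfly(lane_id % 16) + dt * 16


depth_order = [butterfly(x) for x in range(16)]
print(depth_order)
```

Even positions walk depths 0..7 and odd positions walk depths 8..15, so the 16 lanes of a row together cover each depth of the tile exactly once.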
Parameters:
- BN (Int): V block height in elements (keys per block).
- block_width (Int): SMEM block width in depth elements (caller's bk_smem; not BK if the K-split path is active and bk_smem < BK).
- bk_tile (Int): Which BK-tall row strip (always 0 here; kept for API symmetry with the 32x32x64 variant).
- dt (Int): Which depth-tile within the strip (0 .. depth/MMA_M - 1).
Args:
- v_base (UnsafePointer[Scalar[in_type], MutAnyOrigin, address_space=AddressSpace.SHARED]): Pointer to V SMEM stage base (block 0).
- key_group (Int): Per-lane key-group index (lane_id // 16).
- pair_idx (Int): Per-lane pair index ((lane_id % 16) // 2).
- is_odd (Int): Per-lane parity (lane_id % 2).
Returns:
SIMD[in_type, 32]: SIMD[in_type, 32] holding 32 contiguous keys at one depth for
this lane's (bk_tile, dt).