Mojo struct
TiledMmaLoader
struct TiledMmaLoader[in_type: DType, mma_shape: IndexList[3], swizzle: Optional[Swizzle] = Optional(), swizzle2: Optional[Swizzle] = Optional()]
SMEM→register loader expert for MFMA operand fragments.
Sibling to TiledMmaOp (static MFMA compute). Stateless — all
methods are @staticmethod. Parameterized by operand dtype, MMA
instruction shape, and optional vector-space swizzle. Reusable
wherever a kernel issues MMA-tile-shaped SMEM reads (attention's
QK / PV matmuls, potential future matmul variants).
Static methods:
- load_b: full B-operand load from a warp-sized SMEM tile. M-outer iteration; handles BF16 single-load and FP8 two-half-load (num_packs branch) with optional vector-space swizzle.
- load_b_tr: transposed single-MMA-tile load via ds_read_tr16_b64_warp halves + join (BF16 double-rate shapes: 32x32x16 and 16x16x32).
- load_v_fp8_strip: FP8 V-operand per-strip load via ds_read_tr8_b64 with paired-lane addressing (16x16x128 FP8 PV matmul).
_load_b_tile is a private helper used by load_b.
Parameters
- in_type (DType): Operand element type.
- mma_shape (IndexList[3]): MMA instruction shape [M, N, K].
- swizzle (Optional[Swizzle]): Optional vector-space swizzle for load_b.
- swizzle2 (Optional[Swizzle]): Optional second vector-space swizzle, applied AFTER swizzle. Use to compose two-XOR swizzles (e.g., HK's st_32x32 bit5^=bit9 + bit4^=bit10 byte-level pair, which are not expressible as a single Swizzle).
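The two-XOR composition that swizzle2 enables can be illustrated with a small plain-Python model. This is a sketch, not the Mojo Swizzle type: the helper names (xor_bit, first_swizzle, second_swizzle, composed) are hypothetical, the offset is treated as a plain integer, and how TiledMmaLoader maps byte-level bits into vector space is not modeled. It also assumes that a single Swizzle XORs one bit group into another at one fixed shift, which would explain why a pair with distances 4 (bit9 to bit5) and 6 (bit10 to bit4) needs two of them.

```python
def xor_bit(offset: int, dst_bit: int, src_bit: int) -> int:
    """Return offset with bit dst_bit XORed by bit src_bit (dst ^= src)."""
    return offset ^ (((offset >> src_bit) & 1) << dst_bit)


def first_swizzle(offset: int) -> int:
    # Models the first XOR of the pair: bit5 ^= bit9 (distance 4).
    return xor_bit(offset, 5, 9)


def second_swizzle(offset: int) -> int:
    # Models the second XOR of the pair: bit4 ^= bit10 (distance 6).
    return xor_bit(offset, 4, 10)


def composed(offset: int) -> int:
    # The second swizzle is applied AFTER the first, mirroring swizzle2.
    return second_swizzle(first_swizzle(offset))


if __name__ == "__main__":
    for off in (0, 1 << 9, 1 << 10, (1 << 9) | (1 << 10)):
        print(f"{off:#06x} -> {composed(off):#06x}")
```

Because neither XOR modifies its source bit, the composed mapping is its own inverse, so it permutes offsets without collisions.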
Implemented traits
AnyType, ImplicitlyDestructible
Methods
load_b
static load_b[num_mmas: Int, simd_width: Int, imm_offset_bytes: Int = 0](src: TileTensor[in_type, address_space=AddressSpace.SHARED, linear_idx_type=src.linear_idx_type, element_size=src.element_size]) -> InlineArray[SIMD[in_type, simd_width], num_mmas]
Full B operand load from a SMEM warp tile.
Loads all MMA tiles from a WN x BK SMEM warp tile and returns them as an InlineArray of SIMD fragments (one per MMA tile).
Parameters:
- num_mmas (Int): Number of MMA tiles to load.
- simd_width (Int): SIMD vector width for the element type.
- imm_offset_bytes (Int): Comptime byte offset added to each ds_read via the n (numeric immediate) constraint, bypassing the AMDGPU instruction selector for the address-fold step. See _load_from_lds[imm_offset_bytes]. Cost: per-read s_waitcnt lgkmcnt(0) serializes LDS reads.
Args:
- src (TileTensor[in_type, address_space=AddressSpace.SHARED, linear_idx_type=src.linear_idx_type, element_size=src.element_size]): A WN x BK TileTensor in shared memory.
Returns:
InlineArray[SIMD[in_type, simd_width], num_mmas]: An InlineArray of SIMD fragments, one per MMA tile.
load_b_tr
static load_b_tr(tile: TileTensor[in_type, address_space=AddressSpace.SHARED, linear_idx_type=tile.linear_idx_type, element_size=tile.element_size]) -> SIMD[in_type, 8]
Transposed B operand load for double-rate MFMA shapes.
Splits the tile along the K dimension into two halves and concatenates the results.
Args:
- tile (TileTensor[in_type, address_space=AddressSpace.SHARED, linear_idx_type=tile.linear_idx_type, element_size=tile.element_size]): An MMA_K x MMA_N TileTensor in shared memory.
Returns:
SIMD[in_type, 8]: A SIMD[in_type, 8] vector with both halves concatenated.
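The 8-element return width follows from simple arithmetic, sketched below in plain Python. The sketch assumes a 64-lane wavefront and reads the instruction name ds_read_tr16_b64 as "64 bits of 16-bit elements per lane"; the function name fragment_layout is hypothetical, and the per-lane transpose pattern itself is not modeled.

```python
WAVE_LANES = 64  # assumption: 64-lane AMD wavefront


def fragment_layout(mma_k: int, mma_n: int,
                    elem_bits: int = 16, read_bits: int = 64,
                    halves: int = 2) -> tuple[int, bool]:
    """Return (elements per lane, whether the wave covers each K-half exactly)."""
    per_read = read_bits // elem_bits        # elements one lane gets per read
    half_elems = (mma_k // halves) * mma_n   # elements in one K-half of the tile
    covers = half_elems == WAVE_LANES * per_read
    return halves * per_read, covers


if __name__ == "__main__":
    # mma_shape is [M, N, K]; the tile passed to load_b_tr is MMA_K x MMA_N.
    print(fragment_layout(mma_k=16, mma_n=32))  # shape 32x32x16
    print(fragment_layout(mma_k=32, mma_n=16))  # shape 16x16x32
```

For both double-rate shapes, each K-half holds 256 elements, which is exactly 64 lanes times 4 elements, so two reads per lane account for the whole tile.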
load_v_fp8_strip
static load_v_fp8_strip[BN: Int, BK: Int, bk_tile: Int, dt: Int](v_base: UnsafePointer[Scalar[in_type], MutAnyOrigin, address_space=AddressSpace.SHARED], rel_key: Int, hw_key_shift: Int, depth_base: Int) -> SIMD[in_type, 32]
FP8 V per-strip ds_read_tr8_b64 load for one (bk_tile, dt).
Paired-lane addressing: issues 4 ds_read_tr8_b64 calls (at
key_base = 0, 16, 32, 48) and joins the results into one
32-element SIMD matching the MFMA C-output column pattern for
the 16x16x128 FP8 V operand in the PV matmul.
Caller is responsible for precomputing the per-lane coords
(rel_key, hw_key_shift, depth_base) ONCE before the outer
(bk_tile, dt) loop — they're lane-only, not
(bk_tile, dt)-dependent, so hoisting saves redundant address
math per iteration.
Parameters:
- BN (Int): V block height in elements.
- BK (Int): V block width in elements.
- bk_tile (Int): Which BK-tall row strip (0..num_k_tiles - 1).
- dt (Int): Which depth-tile within the strip (0..depth/MMA_M - 1).
Args:
- v_base (UnsafePointer[Scalar[in_type], MutAnyOrigin, address_space=AddressSpace.SHARED]): Pointer to V SMEM stage base (block 0 of num_repeats).
- rel_key (Int): Per-lane relative key index within the 16-lane row.
- hw_key_shift (Int): +4 for lanes in hw1, +0 for hw0.
- depth_base (Int): Per-lane depth sub-range offset (0 or 8 or 16 or 24).
Returns:
SIMD[in_type, 32]: SIMD[in_type, 32] for this lane's (bk_tile, dt) strip.
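The call pattern described for load_v_fp8_strip (hoist the per-lane coordinates, then issue four reads per strip and join them) can be sketched at shape level in plain Python. Everything here is a stand-in: read_model, load_strip_model, and run_lane are hypothetical names, read_model returns labelled placeholders rather than data from shared memory, and the real address math and lane mapping are not modeled. The only facts taken from the description are the four key_base values and the 32-element result; 8 elements per read is inferred from a 64-bit read of 8-bit elements.

```python
KEY_BASES = (0, 16, 32, 48)   # the four key_base values named in the description
ELEMS_PER_READ = 64 // 8      # one 64-bit read of 8-bit (FP8) elements


def read_model(key_base, bk_tile, dt, coords):
    """Stand-in for one ds_read_tr8_b64: returns 8 labelled placeholders."""
    return [(bk_tile, dt, key_base, i) for i in range(ELEMS_PER_READ)]


def load_strip_model(bk_tile, dt, coords):
    """Join the four reads into one 32-element fragment, in key_base order."""
    fragment = []
    for key_base in KEY_BASES:
        fragment.extend(read_model(key_base, bk_tile, dt, coords))
    return fragment


def run_lane(coords, num_k_tiles, num_depth_tiles):
    """coords = (rel_key, hw_key_shift, depth_base), computed once by the caller
    and reused unchanged for every (bk_tile, dt) iteration."""
    return {
        (bk_tile, dt): load_strip_model(bk_tile, dt, coords)
        for bk_tile in range(num_k_tiles)
        for dt in range(num_depth_tiles)
    }


if __name__ == "__main__":
    lane_coords = (3, 4, 8)  # example values within the documented ranges
    fragments = run_lane(lane_coords, num_k_tiles=2, num_depth_tiles=3)
    print(len(fragments), len(fragments[(0, 0)]))
```

Passing coords in as an argument, rather than recomputing it inside the loop, mirrors the hoisting that the description asks callers to do.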