
Mojo struct

TiledMmaLoader

struct TiledMmaLoader[in_type: DType, mma_shape: IndexList[3], swizzle: Optional[Swizzle] = Optional(), swizzle2: Optional[Swizzle] = Optional()]

SMEM→register loader expert for MFMA operand fragments.

Sibling to TiledMmaOp (static MFMA compute). Stateless: all methods are @staticmethod. Parameterized by operand dtype, MMA instruction shape, and up to two optional vector-space swizzles. Reusable wherever a kernel issues MMA-tile-shaped SMEM reads, such as the QK and PV matmuls in attention and potential future matmul variants.

Static methods:

  • load_b: full B-operand load from a warp-sized SMEM tile. M-outer iteration; handles BF16 single-load and FP8 two-half-load (num_packs branch) with optional vector-space swizzle.
  • load_b_tr: transposed load of a single MMA tile, issuing ds_read_tr16_b64_warp on each K half and joining the two results (BF16 double-rate shapes: 32x32x16 and 16x16x32).
  • load_v_fp8_strip: FP8 V-operand per-strip load via ds_read_tr8_b64 with paired-lane addressing (16x16x128 FP8 PV matmul).

_load_b_tile is a private helper used by load_b.

Parameters

  • in_type (DType): Operand element type.
  • mma_shape (IndexList[3]): MMA instruction shape [M, N, K].
  • swizzle (Optional[Swizzle]): Optional vector-space swizzle for load_b.
  • swizzle2 (Optional[Swizzle]): Optional second vector-space swizzle, applied AFTER swizzle. Use it to compose two XOR swizzles, e.g., HK's st_32x32 byte-level pair bit5 ^= bit9 and bit4 ^= bit10, which cannot be expressed as a single Swizzle.
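To see why a second swizzle parameter is needed, here is a small Python model of the HK pair applied to byte offsets. It assumes CuTe-style semantics, where a swizzle with parameters (bits, base, shift) XORs the bits-wide field at base + shift into the field at base; the helper names are hypothetical and this is not the Mojo Swizzle implementation. The two XORs use different shifts (4 and 6), so a brute-force search over small (bits, base, shift) triples finds no single swizzle that reproduces their composition.

```python
def swizzle(offset: int, bits: int, base: int, shift: int) -> int:
    """CuTe-style XOR swizzle (assumed semantics): XOR the `bits`-wide
    field starting at base + shift into the field starting at base."""
    mask = ((1 << bits) - 1) << (base + shift)
    return offset ^ ((offset & mask) >> shift)


def hk_pair(offset: int) -> int:
    """Compose two single-bit XORs: bit5 ^= bit9, then bit4 ^= bit10."""
    first = swizzle(offset, bits=1, base=5, shift=4)   # bit5 ^= bit9
    return swizzle(first, bits=1, base=4, shift=6)     # bit4 ^= bit10


def single_swizzle_matches(fn, limit: int = 2048) -> bool:
    """Search small (bits, base, shift) triples for one equal to fn."""
    for bits in range(1, 4):
        for base in range(0, 11):
            for shift in range(bits, 12):  # shift >= bits: fields don't overlap
                if all(swizzle(x, bits, base, shift) == fn(x) for x in range(limit)):
                    return True
    return False


print(hk_pair(512), hk_pair(1024), hk_pair(1536))  # 544 1040 1584
print(single_swizzle_matches(hk_pair))             # False
```

Neither XOR reads a bit that the other one writes, so the composed mapping is still its own inverse, which is what makes it a valid bank-conflict swizzle.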

Implemented traits

AnyType, ImplicitlyDestructible

Methods

load_b

static load_b[num_mmas: Int, simd_width: Int, imm_offset_bytes: Int = 0](src: TileTensor[in_type, address_space=AddressSpace.SHARED, linear_idx_type=src.linear_idx_type, element_size=src.element_size]) -> InlineArray[SIMD[in_type, simd_width], num_mmas]

Full B operand load from a SMEM warp tile.

Loads all MMA tiles from a WN x BK SMEM warp tile and returns them as an InlineArray of SIMD fragments (one per MMA tile).
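The Python sketch below models the shape of load_b's result: one fragment of simd_width elements per MMA tile. It models the BF16 path as one vector read per fragment and the FP8 path (the num_packs branch) as two half-width reads from separate offsets that are then joined. The read_smem helper, tile_stride, and pack_stride are hypothetical placeholders, not the real SMEM layout or the Mojo API.

```python
from typing import Callable, List

# read(offset, width) stands in for one SMEM vector read (hypothetical).
Read = Callable[[int, int], List[int]]


def load_b_model(read: Read, num_mmas: int, simd_width: int,
                 tile_stride: int, num_packs: int, pack_stride: int) -> List[List[int]]:
    """Return num_mmas fragments of simd_width elements each.

    num_packs == 1 models the BF16 single-load path; num_packs == 2 models
    the FP8 two-half-load path, where each half is read from its own offset.
    tile_stride and pack_stride are hypothetical layout strides.
    """
    assert simd_width % num_packs == 0
    width = simd_width // num_packs
    fragments = []
    for mma in range(num_mmas):                  # one fragment per MMA tile
        frag: List[int] = []
        for pack in range(num_packs):            # 1 read (BF16) or 2 reads (FP8)
            frag.extend(read(mma * tile_stride + pack * pack_stride, width))
        fragments.append(frag)
    return fragments


smem = list(range(1024))                         # fake SMEM contents


def read_smem(offset: int, width: int) -> List[int]:
    """Contiguous vector read from the fake SMEM."""
    return smem[offset:offset + width]


bf16 = load_b_model(read_smem, num_mmas=4, simd_width=8,
                    tile_stride=64, num_packs=1, pack_stride=0)
fp8 = load_b_model(read_smem, num_mmas=4, simd_width=16,
                   tile_stride=128, num_packs=2, pack_stride=32)
```

In this model, each FP8 fragment is the concatenation of two 8-element reads, e.g. fp8[0] holds elements 0..7 followed by 32..39.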

Parameters:

  • num_mmas (Int): Number of MMA tiles to load.
  • simd_width (Int): SIMD vector width for the element type.
  • imm_offset_bytes (Int): Comptime byte offset added to each ds_read through the n (numeric immediate) inline-asm constraint, bypassing the AMDGPU instruction selector for the address-fold step. See _load_from_lds[imm_offset_bytes]. Cost: each read carries its own s_waitcnt lgkmcnt(0), which serializes LDS reads.

Args:

  • src (TileTensor): The WN x BK SMEM warp tile to load from.

Returns:

InlineArray[SIMD[in_type, simd_width], num_mmas]: An InlineArray of SIMD fragments, one per MMA tile.

load_b_tr

static load_b_tr(tile: TileTensor[in_type, address_space=AddressSpace.SHARED, linear_idx_type=tile.linear_idx_type, element_size=tile.element_size]) -> SIMD[in_type, 8]

Transposed B operand load for double-rate MFMA shapes.

Splits the tile along the K dimension into two halves, loads each half with ds_read_tr16_b64_warp, and concatenates the two results.
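As a rough illustration of the split-and-concatenate scheme, the Python model below treats the B tile as a K x N grid of 16-bit elements. A 64-bit transposed read of 16-bit elements yields 4 values per lane, so two halves give the 8 elements of SIMD[in_type, 8]. The lane-to-(column, K group) mapping is an assumption made for illustration, not the hardware's exact ds_read_tr16_b64 pattern; the model only checks that, for both double-rate shapes, 64 lanes cover every element of the tile exactly once.

```python
from typing import List


def load_b_tr_model(tile: List[List[int]], lane: int) -> List[int]:
    """Hypothetical per-lane model of load_b_tr for a K x N B tile.

    K is split into two halves. In each half the lane reads 4 values down
    one column (one 64-bit read of 16-bit elements); the two 4-element
    results are concatenated into 8. The lane mapping is assumed.
    """
    k, n = len(tile), len(tile[0])
    half = k // 2
    col, group = lane % n, lane // n
    result: List[int] = []
    for h in range(2):                          # the two K halves
        k0 = h * half + group * 4
        result.extend(tile[k0 + i][col] for i in range(4))
    return result


def make_tile(k: int, n: int) -> List[List[int]]:
    """Tile whose entries are unique ids r * n + c."""
    return [[r * n + c for c in range(n)] for r in range(k)]


def covers_tile(k: int, n: int, lanes: int = 64) -> bool:
    """True if all lanes together read every tile element exactly once."""
    tile = make_tile(k, n)
    seen = [v for lane in range(lanes) for v in load_b_tr_model(tile, lane)]
    return sorted(seen) == list(range(k * n))


# 16x16x32 -> B is K=32 x N=16; 32x32x16 -> B is K=16 x N=32.
print(covers_tile(32, 16), covers_tile(16, 32))
```

Both shapes hold 512 B elements, which is exactly 64 lanes x 8 elements, so the two 4-element halves per lane fill the operand.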

Args:

  • tile (TileTensor): The SMEM tile holding the B operand of a single MMA tile.

Returns:

SIMD[in_type, 8]: A SIMD[in_type, 8] vector with both halves concatenated.

load_v_fp8_strip

static load_v_fp8_strip[BN: Int, BK: Int, bk_tile: Int, dt: Int](v_base: UnsafePointer[Scalar[in_type], MutAnyOrigin, address_space=AddressSpace.SHARED], rel_key: Int, hw_key_shift: Int, depth_base: Int) -> SIMD[in_type, 32]

FP8 V per-strip ds_read_tr8_b64 load for one (bk_tile, dt).

Paired-lane addressing: issues 4 ds_read_tr8_b64 calls (at key_base = 0, 16, 32, 48) and joins the results into one 32-element SIMD matching the MFMA C-output column pattern for the 16x16x128 FP8 V operand in the PV matmul.
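The join step is simple arithmetic: each ds_read_tr8_b64 returns 64 bits, i.e. 8 FP8 values per lane, so four calls at key_base 0, 16, 32, and 48 produce 32 elements. The Python sketch below models only that join; fake_read_tr8 is a hypothetical stand-in for the per-lane read and does not reproduce the real paired-lane addressing.

```python
from typing import Callable, List


def load_v_fp8_strip_model(read_tr8: Callable[[int], List[int]]) -> List[int]:
    """Join four 8-element transposed reads into one 32-element fragment.

    read_tr8(key_base) is a hypothetical stand-in for one ds_read_tr8_b64
    (64 bits = 8 FP8 values per lane) issued at the given key offset.
    """
    out: List[int] = []
    for key_base in (0, 16, 32, 48):
        chunk = read_tr8(key_base)
        assert len(chunk) == 8                 # 64 bits of 8-bit elements
        out.extend(chunk)
    return out


def fake_read_tr8(key_base: int) -> List[int]:
    """Placeholder read: returns the 8 key indices starting at key_base."""
    return [key_base + i for i in range(8)]


fragment = load_v_fp8_strip_model(fake_read_tr8)
print(len(fragment))
```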

The caller is responsible for precomputing the per-lane coordinates (rel_key, hw_key_shift, depth_base) once, before the outer (bk_tile, dt) loop. They depend only on the lane, not on (bk_tile, dt), so hoisting them avoids redundant address math on every iteration.
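A minimal Python sketch of the intended loop structure, assuming hypothetical lane_coords and load_strip helpers: the lane-only coordinates are computed once, outside the (bk_tile, dt) loop nest, and then passed into every strip load. The coordinate formulas and loop bounds are placeholders, not the kernel's real values; the call counter just confirms the coordinate math runs once per lane.

```python
from typing import Callable, Dict, List, Tuple

Coords = Tuple[int, int, int]   # (rel_key, hw_key_shift, depth_base)


def run_strips(lane: int, num_k_tiles: int, num_depth_tiles: int,
               lane_coords: Callable[[int], Coords],
               load_strip: Callable[[int, int, Coords], List[int]]
               ) -> Dict[Tuple[int, int], List[int]]:
    coords = lane_coords(lane)                 # lane-only: hoisted out of the loop
    results = {}
    for bk_tile in range(num_k_tiles):         # outer (bk_tile, dt) loop nest
        for dt in range(num_depth_tiles):
            results[(bk_tile, dt)] = load_strip(bk_tile, dt, coords)
    return results


calls = 0


def lane_coords(lane: int) -> Coords:
    """Hypothetical per-lane coordinate math (placeholder formulas)."""
    global calls
    calls += 1
    return (lane % 16, (lane // 16) % 2, (lane // 32) * 8)


def load_strip(bk_tile: int, dt: int, coords: Coords) -> List[int]:
    """Placeholder for load_v_fp8_strip: returns 32 dummy values."""
    rel_key, hw_key_shift, depth_base = coords
    return [bk_tile * 1000 + dt * 100 + rel_key + hw_key_shift + depth_base] * 32


res = run_strips(lane=5, num_k_tiles=4, num_depth_tiles=8,
                 lane_coords=lane_coords, load_strip=load_strip)
print(calls, len(res))
```

Because the coordinates are loop-invariant, computing them inside the loop would repeat the same work num_k_tiles x num_depth_tiles times per lane.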

Parameters:

  • BN (Int): V block height in elements.
  • BK (Int): V block width in elements.
  • bk_tile (Int): Index of the BK-tall row strip, in 0..num_k_tiles - 1.
  • dt (Int): Index of the depth tile within the strip, in 0..depth/MMA_M - 1.

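Args:

  • v_base (UnsafePointer[Scalar[in_type], MutAnyOrigin, address_space=AddressSpace.SHARED]): Base pointer to the V block in shared memory.
  • rel_key (Int): Per-lane coordinate, precomputed once by the caller.
  • hw_key_shift (Int): Per-lane coordinate, precomputed once by the caller.
  • depth_base (Int): Per-lane coordinate, precomputed once by the caller.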

Returns:

SIMD[in_type, 32]: SIMD[in_type, 32] for this lane's (bk_tile, dt) strip.