Mojo struct

MmaOp

struct MmaOp[in_type: DType, accum_type: DType, WM: Int, WN: Int, BK: Int, MMA_M: Int, MMA_N: Int, MMA_K: Int, alignment: Int, enable_swizzle: Bool, swizzle_elem_base: Int, swizzle_shift: Int]

Encapsulates MMA register tiles and operations for matrix multiplication.

This struct manages register tiles and MMA operations for a single warp. It processes warp-sized tiles (WM × BK for A, WN × BK for B) without knowledge of the broader kernel architecture.

MmaOp accepts a generic SMemTileType and validates layout compatibility at compile time via load_lds_fragment constraints.

Note: Several values are derived from other parameters:

  • num_m_mmas = WM // MMA_M
  • num_n_mmas = WN // MMA_N
  • num_k_mmas = BK // MMA_K
  • load_width = simd_width_of[in_type]() (SIMD width for the input type)
  • accum_width = (MMA_M * MMA_N) // WARP_SIZE (elements per thread)
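
For illustration (values assumed, not taken from this page): with WM = WN = 64, BK = 32, MMA_M = MMA_N = MMA_K = 16, and WARP_SIZE = 64, this gives num_m_mmas = num_n_mmas = 64 // 16 = 4, num_k_mmas = 32 // 16 = 2, and accum_width = (16 * 16) // 64 = 4 accumulator elements per thread.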

Quadrant Processing: The warp tile is divided into 4 quadrants for MMA scheduling:

  • quadrant_m_mmas = num_m_mmas // 2 (M-dimension quadrant size)
  • quadrant_n_mmas = num_n_mmas // 2 (N-dimension quadrant size)

Splitting the warp tile into quadrants enables efficient interleaving of loads and computes.
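
Continuing the hypothetical configuration above, quadrant_m_mmas = 4 // 2 = 2 and quadrant_n_mmas = 2, so each of the four output quadrants covers 2 × 2 MMA tiles; while the MMAs for one quadrant execute, the register loads for another quadrant's fragments can be issued (see the usage sketch under the mma method below).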

Thread Layout for MMA: AMD's expected pattern maps the 64 lanes of a wavefront to 4 rows × 16 cols (row-major). The lane offset is computed on the fly via lane_id().
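A minimal sketch of that decomposition (assuming lane_id() is importable from the gpu package, as in the Mojo standard library):

```mojo
from gpu import lane_id

fn mma_lane_coords() -> Tuple[Int, Int]:
    # Illustrative only: 64 lanes laid out row-major as 4 rows x 16 cols.
    var lane = Int(lane_id())  # 0..63 on an AMD wavefront (WARP_SIZE = 64)
    return (lane // 16, lane % 16)  # (row, col)
```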

Swizzle Configuration (enable_swizzle=True): MmaOp receives swizzle parameters from the kernel/TileBuffers, since they are determined by how data is loaded into LDS. MmaOp must read using the same swizzle pattern that was used for writing.

  • swizzle_elem_base: bit position for XOR (from loading subtile width)
  • swizzle_shift: XOR source distance (from loading subtile rows)
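
The general XOR pattern these two parameters drive looks roughly like the following sketch (illustrative only; the exact bit math lives inside Swizzle, and the helper name is hypothetical):

```mojo
fn swizzled_offset(offset: Int, elem_base: Int, shift: Int) -> Int:
    # Sketch of a 1-bit XOR swizzle, mirroring Swizzle(1, elem_base, shift):
    # take one bit `shift` positions above `elem_base` and XOR it into bit
    # `elem_base`, permuting elements so consecutive rows hit different banks.
    var src_bit = (offset >> (elem_base + shift)) & 1
    return offset ^ (src_bit << elem_base)
```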

Fields

  • a_reg_tile (Self.ARegTileType): Register tile holding the warp's A fragments.
  • b_reg_tile (Self.BRegTileType): Register tile holding the warp's B fragments.
  • out_quadrants (StaticTuple[StaticTuple[Self.OutQuadrantType, 2], 2]): 2 × 2 grid of output accumulator quadrant tiles.

Implemented traits

AnyType, UnknownDestructibility

comptime members

__del__is_trivial

comptime __del__is_trivial = True

accum_width

comptime accum_width = ((MMA_M * MMA_N) // WARP_SIZE)

ARegTileType

comptime ARegTileType = LayoutTensor[in_type, Layout.row_major(Self.num_m_mmas, (Self.num_k_mmas * Self.load_width)), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]

BRegTileType

comptime BRegTileType = LayoutTensor[in_type, Layout.row_major(Self.num_n_mmas, (Self.num_k_mmas * Self.load_width)), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]

elem_swizzle

comptime elem_swizzle = OptionalReg[Swizzle](Swizzle(1, swizzle_elem_base, swizzle_shift)) if enable_swizzle else OptionalReg[Swizzle]()

lgkm_per_load_a

comptime lgkm_per_load_a = (Self.quadrant_m_mmas * Self.num_k_mmas)

lgkm_per_load_ab

comptime lgkm_per_load_ab = (Self.lgkm_per_load_a + Self.lgkm_per_load_b)

lgkm_per_load_b

comptime lgkm_per_load_b = (Self.quadrant_n_mmas * Self.num_k_mmas)
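
With the hypothetical configuration used earlier (quadrant_m_mmas = quadrant_n_mmas = 2, num_k_mmas = 2), lgkm_per_load_a = lgkm_per_load_b = 2 * 2 = 4 and lgkm_per_load_ab = 8. These counts presumably track LGKM (LDS) instruction counters, letting the kernel place precise waits between interleaved fragment loads and MMAs.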

load_width

comptime load_width = simd_width_of[in_type]()

mma_access_layout

comptime mma_access_layout = Layout(IntTuple(16, 4), IntTuple((4 * Self.load_width), Self.load_width))

num_k_mmas

comptime num_k_mmas = (BK // MMA_K)

num_m_mmas

comptime num_m_mmas = (WM // MMA_M)

num_n_mmas

comptime num_n_mmas = (WN // MMA_N)

OutQuadrantType

comptime OutQuadrantType = LayoutTensor[accum_type, Layout.row_major(Self.quadrant_m_mmas, (Self.quadrant_n_mmas * Self.accum_width)), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]

quadrant_m_mmas

comptime quadrant_m_mmas = (Self.num_m_mmas // 2)

quadrant_n_mmas

comptime quadrant_n_mmas = (Self.num_n_mmas // 2)

RegTileType

comptime RegTileType[num_mmas: Int] = LayoutTensor[in_type, Layout.row_major(num_mmas, (Self.num_k_mmas * Self.load_width)), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]

Parameters

  • num_mmas (Int): Number of MMA fragments along the tile's first dimension (num_m_mmas for A, num_n_mmas for B).

Methods

__init__

__init__(out self)

Initialize MMA operation with register tiles.

reset_accumulator

reset_accumulator(self)

Reset all output quadrants to zero.

load_a

load_a[which: Int](self, smem_tile: LayoutTensor[in_type, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])

Load A[which] from LDS → registers.

Accepts any SMemTileType with a matching dtype; layout compatibility is validated at compile time via load_lds_fragment constraints.

load_b

load_b[which: Int](self, smem_tile: LayoutTensor[in_type, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])

Load B[which] from LDS → registers.

Accepts any SMemTileType with a matching dtype; layout compatibility is validated at compile time via load_lds_fragment constraints.

load_b_with_transpose

load_b_with_transpose[which: Int](self, smem_tile: LayoutTensor[in_type, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])

Load B[which] from LDS → registers using hardware transpose.

Uses the ds_read_tr16_b64 instruction for an efficient transposed LDS read. This function expects B tiles in (N, K) storage order and produces data in the format expected by AMD MFMA instructions.

Supports swizzle: When enable_swizzle is True, applies the byte swizzle pattern to LDS read offsets for bank-conflict-free access.

Requires: MMA shape must be 16x16x32 or 32x32x16 (double-rate MFMA).

Args:

  • smem_tile (LayoutTensor): B tile in LDS with shape (mma_tile_n, BK) = (N, K) order.
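
A hypothetical call (the instance and tile names are assumptions, not from this page):

```mojo
# Hypothetical: B stored in (N, K) order in LDS; hardware-transposed read into
# registers. Requires a double-rate MFMA shape (16x16x32 or 32x32x16).
op.load_b_with_transpose[0](b_smem_nk_tile)  # b_smem_nk_tile: assumed (WN, BK) LDS tile
```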

mma

mma[which_a: Int, which_b: Int](self)

Execute MMA operations for a quadrant of the output tile.

Each quadrant is stored in a separate contiguous register tile.

Parameters:

  • which_a (Int): A quadrant index (0 or 1).
  • which_b (Int): B quadrant index (0 or 1).
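
Putting the pieces together, a hypothetical warp-level mainloop might look like the following (a sketch under assumed parameter values, tile names, and loop bounds, not code from this page):

```mojo
# Hypothetical configuration: bf16 inputs, f32 accumulation, 64x64 warp tile.
var op = MmaOp[
    DType.bfloat16, DType.float32,  # in_type, accum_type
    64, 64, 32,                     # WM, WN, BK
    16, 16, 16,                     # MMA_M, MMA_N, MMA_K
    16,                             # alignment
    False, 0, 0,                    # swizzle disabled
]()
op.reset_accumulator()              # zero all four output quadrants
for _ in range(num_k_tiles):        # num_k_tiles: assumed kernel parameter
    # ... stage the next A/B tiles into LDS, then synchronize ...
    op.load_a[0](a_smem_tile)       # a_smem_tile / b_smem_tile: assumed LDS tiles
    op.load_b[0](b_smem_tile)
    op.mma[0, 0]()                  # compute quadrant (0, 0)...
    op.load_a[1](a_smem_tile)       # ...while the remaining fragment loads issue
    op.load_b[1](b_smem_tile)
    op.mma[0, 1]()
    op.mma[1, 0]()
    op.mma[1, 1]()
# op.out_quadrants now holds this warp's WM x WN accumulator fragments
```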
