Mojo struct
MmaOp
struct MmaOp[in_type: DType, accum_type: DType, WM: Int, WN: Int, BK: Int, MMA_M: Int, MMA_N: Int, MMA_K: Int, alignment: Int, enable_swizzle: Bool, swizzle_elem_base: Int, swizzle_shift: Int]
Encapsulates MMA register tiles and operations for matrix multiplication.
This struct manages register tiles and MMA operations for a single warp. It processes warp-sized tiles (WM × BK for A, WN × BK for B) without knowledge of the broader kernel architecture.
MmaOp accepts a generic SMemTileType and validates layout compatibility at compile time via load_lds_fragment constraints.
Note: Several values are derived from other parameters (a worked example follows this list):
- num_m_mmas = WM // MMA_M
- num_n_mmas = WN // MMA_N
- num_k_mmas = BK // MMA_K
- load_width = simd_width_of[in_type]() (SIMD width for the input type)
- accum_width = (MMA_M * MMA_N) // WARP_SIZE (elements per thread)
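A quick worked example of this arithmetic. The concrete values below are illustrative assumptions, not library defaults; WARP_SIZE is 64 on AMD wavefronts:

```mojo
fn main():
    # Illustrative configuration (assumed, not a default):
    # a 64x64 warp tile with BK=32 and a 16x16x32 MFMA shape.
    comptime WM = 64
    comptime WN = 64
    comptime BK = 32
    comptime MMA_M = 16
    comptime MMA_N = 16
    comptime MMA_K = 32
    comptime WARP_SIZE = 64  # AMD wavefront size

    print("num_m_mmas =", WM // MMA_M)                    # 4
    print("num_n_mmas =", WN // MMA_N)                    # 4
    print("num_k_mmas =", BK // MMA_K)                    # 1
    print("accum_width =", (MMA_M * MMA_N) // WARP_SIZE)  # 4 per-lane elements
```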
Quadrant Processing: The warp tile is divided into 4 quadrants for MMA scheduling:
- quadrant_m_mmas = num_m_mmas // 2 (M-dimension quadrant size)
- quadrant_n_mmas = num_n_mmas // 2 (N-dimension quadrant size)

This split enables efficient interleaving of loads and computes.
Thread Layout for MMA: AMD's expected pattern is 64 threads arranged as 4 rows × 16 cols (row-major). The lane offset is computed on the fly via lane_id().
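A minimal sketch of that mapping. The helper name is hypothetical, and the offset arithmetic assumes the strides of mma_access_layout below, with the 16-wide coordinate as dim 0:

```mojo
fn print_lane_mapping(lane: Int, load_width: Int):
    # 64 lanes -> 4 rows x 16 cols, row-major (per the text above).
    var row = lane // 16
    var col = lane % 16
    # Element offset implied by mma_access_layout's strides
    # (4 * load_width, load_width): the column coordinate carries
    # the larger stride.
    var offset = col * (4 * load_width) + row * load_width
    print("lane", lane, "-> row", row, "col", col, "offset", offset)
```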
Swizzle Configuration (enable_swizzle=True): MmaOp receives its swizzle parameters from the kernel/TileBuffers, since they are determined by how data is loaded into LDS. MmaOp must read using the same swizzle pattern that was used for writing (a sketch of the index transform follows this list):
- swizzle_elem_base: bit position for XOR (from loading subtile width)
- swizzle_shift: XOR source distance (from loading subtile rows)
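For intuition, a single-bit XOR swizzle of this shape can be sketched as follows. This is an assumption based on the Swizzle(1, swizzle_elem_base, swizzle_shift) construction in elem_swizzle below, not the library's actual implementation:

```mojo
fn swizzle_elem_index(idx: Int, elem_base: Int, shift: Int) -> Int:
    # One swizzle bit (the leading `1` in Swizzle(1, ...)): bit `elem_base`
    # of the element index is XORed with the bit `shift` positions above it.
    var src_bit = (idx >> (elem_base + shift)) & 1
    return idx ^ (src_bit << elem_base)
```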
Fields
- a_reg_tile (Self.ARegTileType):
- b_reg_tile (Self.BRegTileType):
- out_quadrants (StaticTuple[StaticTuple[Self.OutQuadrantType, 2], 2]):
Implemented traits
AnyType,
UnknownDestructibility
comptime members
__del__is_trivial
comptime __del__is_trivial = True
accum_width
comptime accum_width = ((MMA_M * MMA_N) // WARP_SIZE)
ARegTileType
comptime ARegTileType = LayoutTensor[in_type, Layout.row_major(Self.num_m_mmas, Self.num_k_mmas * Self.load_width), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]
BRegTileType
comptime BRegTileType = LayoutTensor[in_type, Layout.row_major(Self.num_n_mmas, Self.num_k_mmas * Self.load_width), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]
elem_swizzle
comptime elem_swizzle = OptionalReg[Swizzle](Swizzle(1, swizzle_elem_base, swizzle_shift)) if enable_swizzle else OptionalReg[Swizzle]()
lgkm_per_load_a
comptime lgkm_per_load_a = Self.quadrant_m_mmas * Self.num_k_mmas
lgkm_per_load_ab
comptime lgkm_per_load_ab = Self.lgkm_per_load_a + Self.lgkm_per_load_b
lgkm_per_load_b
comptime lgkm_per_load_b = Self.quadrant_n_mmas * Self.num_k_mmas
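The lgkm_per_load_* counters record how many LDS reads a quadrant load issues (quadrant size × num_k_mmas per operand, summed for A+B); the name suggests they feed AMD's s_waitcnt lgkmcnt bookkeeping, though this reference does not spell that out. With the illustrative values from the example above (quadrant_m_mmas = quadrant_n_mmas = 2, num_k_mmas = 1), lgkm_per_load_a = lgkm_per_load_b = 2 and lgkm_per_load_ab = 4.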
load_width
comptime load_width = simd_width_of[in_type]()
mma_access_layout
comptime mma_access_layout = Layout(IntTuple(16, 4), IntTuple(4 * Self.load_width, Self.load_width))
num_k_mmas
comptime num_k_mmas = (BK // MMA_K)
num_m_mmas
comptime num_m_mmas = (WM // MMA_M)
num_n_mmas
comptime num_n_mmas = (WN // MMA_N)
OutQuadrantType
comptime OutQuadrantType = LayoutTensor[accum_type, Layout.row_major(Self.quadrant_m_mmas, Self.quadrant_n_mmas * Self.accum_width), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]
quadrant_m_mmas
comptime quadrant_m_mmas = Self.num_m_mmas // 2
quadrant_n_mmas
comptime quadrant_n_mmas = Self.num_n_mmas // 2
RegTileType
comptime RegTileType[num_mmas: Int] = LayoutTensor[in_type, Layout.row_major(num_mmas, Self.num_k_mmas * Self.load_width), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]
Parameters

- num_mmas (Int): Number of MMA tiles along the row dimension of the register tile.
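Note that ARegTileType and BRegTileType above are, by definition, RegTileType[Self.num_m_mmas] and RegTileType[Self.num_n_mmas] respectively (compare the alias bodies).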
Methods
__init__
__init__(out self)
Initialize MMA operation with register tiles.
reset_accumulator
reset_accumulator(self)
Reset all output quadrants to zero.
load_a
load_a[which: Int](self, smem_tile: LayoutTensor[in_type, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])
Load A[which] from LDS → registers.
Accepts any SMemTileType with a matching dtype; layout compatibility is validated at compile time via load_lds_fragment constraints.
load_b
load_b[which: Int](self, smem_tile: LayoutTensor[in_type, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])
Load B[which] from LDS → registers.
Accepts any SMemTileType with a matching dtype; layout compatibility is validated at compile time via load_lds_fragment constraints.
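A hedged sketch of the per-warp call sequence these loads fit into. Parameter values, tile names, and loop structure are assumptions; the real kernel interleaves loads and MFMAs per quadrant:

```mojo
# Hypothetical scaffolding: a_smem / b_smem are LDS tiles written by the
# kernel's TileBuffers with the same swizzle parameters passed here.
var op = MmaOp[
    DType.bfloat16, DType.float32,  # in_type, accum_type (assumed)
    64, 64, 32,                     # WM, WN, BK (assumed)
    16, 16, 32,                     # MMA_M, MMA_N, MMA_K
    16, False, 0, 0                 # alignment; swizzle disabled
]()
op.reset_accumulator()
op.load_a[0](a_smem)  # A[0]: LDS -> registers
op.load_b[0](b_smem)  # B[0]: LDS -> registers
# ... then issue MFMAs per quadrant via op.mma(...), overlapping the
#     remaining load_a[1] / load_b[1] with compute.
```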
load_b_with_transpose
load_b_with_transpose[which: Int](self, smem_tile: LayoutTensor[in_type, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])
Load B[which] from LDS → registers using hardware transpose.
Uses the ds_read_tr16_b64 instruction for efficient transposed LDS reads. This function expects B tiles in (N, K) storage order and produces data in the format expected by AMD MFMA instructions.
Supports swizzle: When enable_swizzle is True, applies the byte swizzle pattern to LDS read offsets for bank-conflict-free access.
Requires: MMA shape must be 16x16x32 or 32x32x16 (double-rate MFMA).
Args:

- smem_tile (LayoutTensor): B tile in LDS with shape (mma_tile_n, BK), i.e., (N, K) order.
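With B kept in (N, K) order in LDS, the transposed load replaces the plain load_b call. A sketch; the tile name is hypothetical:

```mojo
# Requires a 16x16x32 or 32x32x16 MMA shape (double-rate MFMA).
op.load_b_with_transpose[0](b_smem_nk)  # ds_read_tr16_b64 under the hood
```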
mma