Mojo struct
MmaOp
struct MmaOp[in_type: DType, accum_type: DType, WM: Int, WN: Int, BK: Int, MMA_M: Int, MMA_N: Int, MMA_K: Int, alignment: Int, swizzle: OptionalReg[Swizzle]]
Encapsulates MMA register tiles and operations for matrix multiplication.
This struct manages register tiles and MMA operations for a single warp. It processes warp-sized tiles (WM × BK for A, WN × BK for B) without knowledge of the broader kernel architecture.
MmaOp accepts a generic SMemTileType and validates layout compatibility at compile time via load_lds_fragment constraints.
Note: Several values are derived from other parameters:
- num_m_mmas = WM // MMA_M
- num_n_mmas = WN // MMA_N
- num_k_mmas = BK // MMA_K
- load_width = simd_width_of[in_type]() (SIMD width for the input type)
- accum_width = (MMA_M * MMA_N) // WARP_SIZE (elements per thread)
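The derived parameters above can be sketched with plain arithmetic (a Python sketch; the concrete tile sizes below are assumed example values, not values taken from this page):

```python
WARP_SIZE = 64  # AMD wavefront size

def derived(WM, WN, BK, MMA_M, MMA_N, MMA_K, simd_width):
    """Compute MmaOp's derived parameters from its input parameters."""
    return {
        "num_m_mmas": WM // MMA_M,
        "num_n_mmas": WN // MMA_N,
        "num_k_mmas": BK // MMA_K,
        "load_width": simd_width,                     # simd_width_of[in_type]()
        "accum_width": (MMA_M * MMA_N) // WARP_SIZE,  # output elements per thread
    }

# Hypothetical BF16 warp tile: 64x64 output, BK=32, 16x16x32 MMA shape.
print(derived(WM=64, WN=64, BK=32, MMA_M=16, MMA_N=16, MMA_K=32, simd_width=8))
```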
Quadrant Processing: The warp tile is divided into 4 quadrants for MMA scheduling:
- quadrant_m_mmas = num_m_mmas // 2 (M-dimension quadrant size)
- quadrant_n_mmas = num_n_mmas // 2 (N-dimension quadrant size)
This enables efficient interleaving of loads and computes.
Thread Layout for MMA: AMD's expected pattern maps 64 threads to 4 rows × 16 cols (row-major). The lane offset is computed on the fly via lane_id().
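The 4 × 16 row-major lane mapping amounts to a simple division/modulus of the lane ID (a sketch of the stated layout, not the actual Mojo implementation):

```python
ROWS, COLS = 4, 16  # 64 lanes arranged row-major, per the layout above

def lane_coords(lane_id):
    """Map a lane ID (0..63) to its (row, col) position in the 4x16 grid."""
    return lane_id // COLS, lane_id % COLS

# Lane 0 sits at the origin; lane 63 at the bottom-right corner.
print(lane_coords(0), lane_coords(63))
```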
Swizzle Configuration: MmaOp receives the swizzle pattern from the kernel/TileBuffers, since it's determined by how data is loaded into LDS. MmaOp must read using the same swizzle pattern that was used for writing.
- BF16: Swizzle(1, 5, 4) - 1 bit XOR
- FP8 16×128: Swizzle(3, 4, 4) - 3 bit XOR (HipKittens st_16x128)
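The XOR swizzles above can be sketched assuming CUTLASS/CuTe-style Swizzle&lt;bits, base, shift&gt; semantics, where `bits` bits taken at position `base + shift` are XORed into the bits at position `base` (an assumed convention; check the Mojo Swizzle definition before relying on it):

```python
def apply_swizzle(idx, bits, base, shift):
    """XOR `bits` bits from position base+shift into position base."""
    mask = ((1 << bits) - 1) << (base + shift)
    return idx ^ ((idx & mask) >> shift)

# Swizzle(1, 5, 4): 1-bit XOR -- bit 9 is folded into bit 5 (BF16 pattern).
# Swizzle(3, 4, 4): 3-bit XOR -- bits 8-10 are folded into bits 4-6 (FP8 16x128).
print(apply_swizzle(512, 1, 5, 4))  # bit 9 set -> bit 5 flips
```

Because the XORed-in bits and the source bits don't overlap (shift ≥ bits), the swizzle is an involution: reading with the same pattern that was used for writing recovers the original offset, which is why MmaOp must use the writer's swizzle.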
Fields
- a_reg_tile (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].ARegTileType):
- b_reg_tile (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].BRegTileType):
- out_reg_tile (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].OutRegTileType):
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
__del__is_trivial
comptime __del__is_trivial = True
accum_width
comptime accum_width = ((MMA_M * MMA_N) // WARP_SIZE)
ARegTileType
comptime ARegTileType = LayoutTensor[in_type, Layout.row_major(MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_m_mmas, (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_k_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].mma_frag_width)), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]
BRegTileType
comptime BRegTileType = LayoutTensor[in_type, Layout.row_major(MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_n_mmas, (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_k_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].mma_frag_width)), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]
bytes_per_frag
comptime bytes_per_frag = (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].lds_frag_width * size_of[in_type]())
col_groups
comptime col_groups = (WARP_SIZE // MMA_M)
ds_reads_per_frag
comptime ds_reads_per_frag = ceildiv(MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].bytes_per_frag, 16)
elem_swizzle
comptime elem_swizzle = swizzle
k_loads_per_mma
comptime k_loads_per_mma = (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].mma_frag_width // MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].lds_frag_width)
lds_frag_width
comptime lds_frag_width = 16 if MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].use_fp8_16x16x128_mma else MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].mma_frag_width
lgkm_per_load_a
comptime lgkm_per_load_a = (((MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].quadrant_m_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_k_mmas) * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].k_loads_per_mma) * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].ds_reads_per_frag)
lgkm_per_load_ab
comptime lgkm_per_load_ab = (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].lgkm_per_load_a + MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].lgkm_per_load_b)
lgkm_per_load_b
comptime lgkm_per_load_b = (((MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].quadrant_n_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_k_mmas) * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].k_loads_per_mma) * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].ds_reads_per_frag)
load_width
comptime load_width = simd_width_of[in_type]()
mma_access_layout
comptime mma_access_layout = Layout(IntTuple(MMA_M, MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].col_groups), IntTuple(MMA_K, MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].lds_frag_width))
mma_frag_width
comptime mma_frag_width = ((MMA_M * MMA_K) // WARP_SIZE)
num_k_mmas
comptime num_k_mmas = (BK // MMA_K)
num_m_mmas
comptime num_m_mmas = (WM // MMA_M)
num_n_mmas
comptime num_n_mmas = (WN // MMA_N)
out_reg_layout
comptime out_reg_layout = Layout.row_major(MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_m_mmas, (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_n_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].accum_width))
OutRegTileType
comptime OutRegTileType = LayoutTensor[accum_type, MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].out_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]
quadrant_m_mmas
comptime quadrant_m_mmas = (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_m_mmas // 2)
quadrant_m_size
comptime quadrant_m_size = MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].quadrant_m_mmas
quadrant_n_mmas
comptime quadrant_n_mmas = (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_n_mmas // 2)
quadrant_n_size
comptime quadrant_n_size = (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].quadrant_n_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].accum_width)
RegTileType
comptime RegTileType[num_mmas: Int] = LayoutTensor[in_type, Layout.row_major(num_mmas, (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_k_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].mma_frag_width)), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]
Parameters
- num_mmas (Int):
use_fp8_16x16x128_mma
comptime use_fp8_16x16x128_mma = ((in_type == DType.float8_e4m3fn) and (MMA_M == 16) and (MMA_K == 128))
use_fp8_32x32x64_mma
comptime use_fp8_32x32x64_mma = ((in_type == DType.float8_e4m3fn) and (MMA_M == 32) and (MMA_K == 64))
Methods
__init__
__init__(out self)
Initialize MMA operation with register tiles.
reset_accumulator
reset_accumulator(self)
Reset output register tile to zero.
load_a
load_a[which: Int](self, smem_tile: LayoutTensor[in_type, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])
Load A[which] from LDS → registers.
Accepts an SMemTileType with matching dtype; layout compatibility is validated at compile time via load_lds_fragment constraints.
For FP8 16×16×128: Uses lds_frag_width=16 with 2 K-iterations per MMA. For FP8 32×32×64: Uses lds_frag_width=32 with single load. For BF16 16×16×32: Uses lds_frag_width=8.
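The three configurations above follow from the fragment-width arithmetic in the comptime members (a sketch; k_loads_per_mma = mma_frag_width // lds_frag_width, with mma_frag_width = (MMA_M * MMA_K) // WARP_SIZE):

```python
WARP_SIZE = 64  # AMD wavefront size

def k_loads_per_mma(MMA_M, MMA_K, lds_frag_width):
    """How many LDS loads feed one MMA along K."""
    mma_frag_width = (MMA_M * MMA_K) // WARP_SIZE  # elements per lane per MMA
    return mma_frag_width // lds_frag_width

print(k_loads_per_mma(16, 128, 16))  # FP8 16x16x128 -> 2 K-iterations per MMA
print(k_loads_per_mma(32, 64, 32))   # FP8 32x32x64  -> 1 (single load)
print(k_loads_per_mma(16, 32, 8))    # BF16 16x16x32 -> 1
```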
load_b
load_b[which: Int](self, smem_tile: LayoutTensor[in_type, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])
Load B[which] from LDS → registers.
Accepts an SMemTileType with matching dtype; layout compatibility is validated at compile time via load_lds_fragment constraints.
For FP8 16×16×128: Uses lds_frag_width=16 with 2 K-iterations per MMA. For FP8 32×32×64: Uses lds_frag_width=32 with single load. For BF16 16×16×32: Uses lds_frag_width=8.
mma
mma[which_a: Int, which_b: Int](self)
Execute MMA operations for a quadrant of the output tile.
Accesses quadrant via .tile[] view into the contiguous out_reg_tile. Uses mma_frag_width for fragment sizing (4 for BF16, 8 for FP8).
Works for both BF16 and FP8 via stdlib mma() dispatch.
Parameters:
- which_a (Int):
- which_b (Int):
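How the four mma[which_a, which_b] quadrant calls tile the full output can be sketched in Python (the 4 × 4 MMA grid is an assumed example size, not a value from this page):

```python
num_m_mmas, num_n_mmas = 4, 4            # assumed example MMA grid
quadrant_m_mmas = num_m_mmas // 2
quadrant_n_mmas = num_n_mmas // 2

def quadrant_tiles(which_a, which_b):
    """(m, n) MMA indices covered by one mma[which_a, which_b] call."""
    return {(m, n)
            for m in range(which_a * quadrant_m_mmas, (which_a + 1) * quadrant_m_mmas)
            for n in range(which_b * quadrant_n_mmas, (which_b + 1) * quadrant_n_mmas)}

# The four quadrant calls together cover the whole out_reg_tile exactly once,
# which is what lets the kernel interleave LDS loads between them.
covered = set().union(*(quadrant_tiles(a, b) for a in range(2) for b in range(2)))
print(len(covered))
```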