
Mojo struct

MmaOp

struct MmaOp[in_type: DType, accum_type: DType, WM: Int, WN: Int, BK: Int, MMA_M: Int, MMA_N: Int, MMA_K: Int, alignment: Int, swizzle: OptionalReg[Swizzle]]

Encapsulates MMA register tiles and operations for matrix multiplication.

This struct manages register tiles and MMA operations for a single warp. It processes warp-sized tiles (WM × BK for A, WN × BK for B) without knowledge of the broader kernel architecture.

MmaOp accepts a generic SMemTileType and validates compatibility at compile time via load_lds_fragment constraints.

Note: Several values are derived from other parameters:

  • num_m_mmas = WM // MMA_M
  • num_n_mmas = WN // MMA_N
  • num_k_mmas = BK // MMA_K
  • load_width = simd_width_of[in_type]() (SIMD width for the input type)
  • accum_width = (MMA_M * MMA_N) // WARP_SIZE (elements per thread)
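The derived values above can be checked with a short Python sketch (not the Mojo source). The concrete configuration is an assumption for illustration: a BF16 16×16×32 MMA with WM = WN = 64, BK = 64, and the 64-lane AMD warp described below; the SIMD width of 8 for bfloat16 is likewise assumed.

```python
# Hypothetical parameters: BF16 16x16x32 MMA, 64x64 warp tile, BK = 64.
WARP_SIZE = 64
WM, WN, BK = 64, 64, 64
MMA_M, MMA_N, MMA_K = 16, 16, 32
simd_width_bf16 = 8  # assumed SIMD width for bfloat16 on this target

num_m_mmas = WM // MMA_M                     # MMA tiles along M
num_n_mmas = WN // MMA_N                     # MMA tiles along N
num_k_mmas = BK // MMA_K                     # MMA tiles along K
load_width = simd_width_bf16                 # SIMD width for the input type
accum_width = (MMA_M * MMA_N) // WARP_SIZE   # accumulator elements per thread
```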

Quadrant Processing: The warp tile is divided into 4 quadrants for MMA scheduling:

  • quadrant_m_mmas = num_m_mmas // 2 (M-dimension quadrant size)
  • quadrant_n_mmas = num_n_mmas // 2 (N-dimension quadrant size)

This quadrant split enables efficient interleaving of loads and computes.
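The quadrant split can be sketched in Python (the values num_m_mmas = num_n_mmas = 4 are an assumed example, as in a 64×64 warp tile of 16×16 MMAs):

```python
# Hypothetical example: a 4x4 grid of MMA tiles split into 2x2 quadrants.
num_m_mmas, num_n_mmas = 4, 4
quadrant_m_mmas = num_m_mmas // 2  # M-dimension quadrant size
quadrant_n_mmas = num_n_mmas // 2  # N-dimension quadrant size

# Four (which_a, which_b) quadrant pairs cover the whole warp tile, so
# the next quadrant's LDS loads can overlap the current quadrant's MMAs.
quadrants = [(a, b) for a in range(2) for b in range(2)]
```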

Thread Layout for MMA: AMD's expected pattern maps the 64 threads of a warp to 4 rows × 16 columns (row-major). The lane offset is computed on the fly via lane_id().
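That row-major mapping can be sketched in Python (the helper name lane_row_col is illustrative, not part of the API):

```python
# 64 lanes viewed as 4 rows x 16 cols, row-major: lane = row * 16 + col.
def lane_row_col(lane: int) -> tuple[int, int]:
    return lane // 16, lane % 16
```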

Swizzle Configuration: MmaOp receives the swizzle pattern from the kernel/TileBuffers, since it's determined by how data is loaded into LDS. MmaOp must read using the same swizzle pattern that was used for writing.

  • BF16: Swizzle(1, 5, 4) - 1-bit XOR
  • FP8 16×128: Swizzle(3, 4, 4) - 3-bit XOR (HipKittens st_16x128)
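A Python sketch of a CUTLASS-style XOR swizzle gives the flavor of these patterns; the exact bit placement used by the Mojo Swizzle type is an assumption here. `bits` XOR bits are taken `shift` positions above `base` and folded onto the bits starting at `base`. Whatever the convention, an XOR swizzle is a permutation, which is why reads must simply reuse the writers' pattern.

```python
# Assumed CUTLASS-style convention; Mojo's Swizzle may place bits differently.
def xor_swizzle(bits: int, base: int, shift: int, idx: int) -> int:
    mask = (1 << bits) - 1
    return idx ^ (((idx >> (base + shift)) & mask) << base)

# Swizzle(1, 5, 4) and Swizzle(3, 4, 4) both permute the index space, so
# every LDS element is still addressed exactly once.
perm_bf16 = [xor_swizzle(1, 5, 4, i) for i in range(1024)]
perm_fp8 = [xor_swizzle(3, 4, 4, i) for i in range(2048)]
```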

Fields

  • a_reg_tile (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].ARegTileType): Register tile holding A fragments loaded from LDS.
  • b_reg_tile (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].BRegTileType): Register tile holding B fragments loaded from LDS.
  • out_reg_tile (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].OutRegTileType): Accumulator register tile for the output.

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

__del__is_trivial

comptime __del__is_trivial = True

accum_width

comptime accum_width = ((MMA_M * MMA_N) // WARP_SIZE)

ARegTileType

comptime ARegTileType = LayoutTensor[in_type, Layout.row_major(MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_m_mmas, (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_k_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].mma_frag_width)), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]

BRegTileType

comptime BRegTileType = LayoutTensor[in_type, Layout.row_major(MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_n_mmas, (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_k_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].mma_frag_width)), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]

bytes_per_frag

comptime bytes_per_frag = (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].lds_frag_width * size_of[in_type]())

col_groups

comptime col_groups = (WARP_SIZE // MMA_M)

ds_reads_per_frag

comptime ds_reads_per_frag = ceildiv(MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].bytes_per_frag, 16)
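The bytes_per_frag and ds_reads_per_frag arithmetic can be sketched in Python. The fragment widths and element sizes below come from the configurations named in this page; the 16-byte divisor reflects that a single ds_read moves at most 16 bytes (an assumption about the widest ds_read_b128 instruction).

```python
# ds_reads_per_frag = ceildiv(lds_frag_width * size_of[in_type](), 16)
def ds_reads_per_frag(lds_frag_width: int, bytes_per_elem: int) -> int:
    bytes_per_frag = lds_frag_width * bytes_per_elem
    return -(-bytes_per_frag // 16)  # ceiling division

bf16_reads = ds_reads_per_frag(8, 2)      # BF16: 8 elems x 2 B = 16 B
fp8_16x128_reads = ds_reads_per_frag(16, 1)  # FP8 16x128: 16 elems x 1 B
fp8_32x64_reads = ds_reads_per_frag(32, 1)   # FP8 32x64: 32 elems x 1 B
```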

elem_swizzle

comptime elem_swizzle = swizzle

k_loads_per_mma

comptime k_loads_per_mma = (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].mma_frag_width // MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].lds_frag_width)

lds_frag_width

comptime lds_frag_width = 16 if ((in_type == DType.float8_e4m3fn) and (MMA_M == 16) and (MMA_K == 128)) else MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].mma_frag_width

lgkm_per_load_a

comptime lgkm_per_load_a = (((MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].quadrant_m_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_k_mmas) * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].k_loads_per_mma) * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].ds_reads_per_frag)

lgkm_per_load_ab

comptime lgkm_per_load_ab = (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].lgkm_per_load_a + MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].lgkm_per_load_b)

lgkm_per_load_b

comptime lgkm_per_load_b = (((MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].quadrant_n_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_k_mmas) * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].k_loads_per_mma) * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].ds_reads_per_frag)
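The lgkm counters above combine into lgkm_per_load_ab, the number of outstanding LDS reads a quadrant's A and B loads issue together. A Python sketch with assumed values (quadrant_m_mmas = quadrant_n_mmas = 2, num_k_mmas = 2, k_loads_per_mma = 1, ds_reads_per_frag = 1, matching a BF16 example configuration):

```python
# Assumed BF16 example values; the formulas mirror the comptime definitions.
quadrant_m_mmas = quadrant_n_mmas = 2
num_k_mmas, k_loads_per_mma, reads_per_frag = 2, 1, 1

lgkm_per_load_a = quadrant_m_mmas * num_k_mmas * k_loads_per_mma * reads_per_frag
lgkm_per_load_b = quadrant_n_mmas * num_k_mmas * k_loads_per_mma * reads_per_frag
lgkm_per_load_ab = lgkm_per_load_a + lgkm_per_load_b  # outstanding LDS reads
```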

load_width

comptime load_width = simd_width_of[in_type]()

mma_access_layout

comptime mma_access_layout = Layout(IntTuple(MMA_M, MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].col_groups), IntTuple(MMA_K, MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].lds_frag_width))

mma_frag_width

comptime mma_frag_width = ((MMA_M * MMA_K) // WARP_SIZE)

num_k_mmas

comptime num_k_mmas = (BK // MMA_K)

num_m_mmas

comptime num_m_mmas = (WM // MMA_M)

num_n_mmas

comptime num_n_mmas = (WN // MMA_N)

out_reg_layout

comptime out_reg_layout = Layout.row_major(MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_m_mmas, (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_n_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].accum_width))

OutRegTileType

comptime OutRegTileType = LayoutTensor[accum_type, MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].out_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]

quadrant_m_mmas

comptime quadrant_m_mmas = (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_m_mmas // 2)

quadrant_m_size

comptime quadrant_m_size = MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].quadrant_m_mmas

quadrant_n_mmas

comptime quadrant_n_mmas = (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_n_mmas // 2)

quadrant_n_size

comptime quadrant_n_size = (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].quadrant_n_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].accum_width)

RegTileType

comptime RegTileType[num_mmas: Int] = LayoutTensor[in_type, Layout.row_major(num_mmas, (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].num_k_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, swizzle].mma_frag_width)), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]

Parameters

  • num_mmas (Int):

use_fp8_16x16x128_mma

comptime use_fp8_16x16x128_mma = ((in_type == DType.float8_e4m3fn) and (MMA_M == 16) and (MMA_K == 128))

use_fp8_32x32x64_mma

comptime use_fp8_32x32x64_mma = ((in_type == DType.float8_e4m3fn) and (MMA_M == 32) and (MMA_K == 64))

Methods

__init__

__init__(out self)

Initialize MMA operation with register tiles.

reset_accumulator

reset_accumulator(self)

Reset output register tile to zero.

load_a

load_a[which: Int](self, smem_tile: LayoutTensor[in_type, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])

Load A[which] from LDS → registers.

Accepts any SMemTileType with matching dtype; layout compatibility is validated at compile time via load_lds_fragment constraints.

  • FP8 16×16×128: uses lds_frag_width=16 with 2 K-iterations per MMA.
  • FP8 32×32×64: uses lds_frag_width=32 with a single load.
  • BF16 16×16×32: uses lds_frag_width=8.
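The K-iteration counts above follow from k_loads_per_mma = mma_frag_width // lds_frag_width, which a Python sketch can verify for the three listed configurations (WARP_SIZE = 64 per the thread layout above):

```python
# k_loads_per_mma = ((MMA_M * MMA_K) // WARP_SIZE) // lds_frag_width
WARP_SIZE = 64

def k_loads(MMA_M: int, MMA_K: int, lds_frag_width: int) -> int:
    mma_frag_width = (MMA_M * MMA_K) // WARP_SIZE
    return mma_frag_width // lds_frag_width

fp8_16x128 = k_loads(16, 128, 16)  # two K-iterations per MMA
fp8_32x64 = k_loads(32, 64, 32)    # single load
bf16_16x32 = k_loads(16, 32, 8)    # single load
```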

load_b

load_b[which: Int](self, smem_tile: LayoutTensor[in_type, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])

Load B[which] from LDS → registers.

Accepts any SMemTileType with matching dtype; layout compatibility is validated at compile time via load_lds_fragment constraints.

  • FP8 16×16×128: uses lds_frag_width=16 with 2 K-iterations per MMA.
  • FP8 32×32×64: uses lds_frag_width=32 with a single load.
  • BF16 16×16×32: uses lds_frag_width=8.

mma

mma[which_a: Int, which_b: Int](self)

Execute MMA operations for a quadrant of the output tile.

Accesses quadrant via .tile[] view into the contiguous out_reg_tile. Uses mma_frag_width for fragment sizing (4 for BF16, 8 for FP8).

Works for both BF16 and FP8 via stdlib mma() dispatch.

Parameters:

  • which_a (Int): A quadrant index (0 or 1).
  • which_b (Int): B quadrant index (0 or 1).
