
Mojo struct

MmaOp

struct MmaOp[out_type: DType, in_type: DType, shape: IndexList[3], k_group_size: Int, num_k_tiles: Int, num_m_mmas: Int, num_n_mmas: Int, swizzle: Optional[Swizzle] = None]

Register ownership + SMEM loading + schedule API for AMD matmul.

Owns the A/B/C register tiles in the LOCAL address space and provides the schedule-facing API: load_frag[k] loads fragments from SMEM into registers, and mma[k] delegates the computation to TiledMma.
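The intended call pattern can be sketched as follows. This is illustrative only: the parameter values (a 16x16x16 MFMA shape, bfloat16 inputs, two k-tiles) and the a_smem_warp/b_smem_warp tiles are assumed to be supplied by the surrounding pipeline, not part of this struct's API.

```mojo
# Sketch only: a_smem_warp and b_smem_warp are the block-local SMEM warp
# tiles staged by the enclosing kernel; parameter values are example choices.
comptime num_k_tiles = 2

var op = MmaOp[
    DType.float32,              # out_type: accumulator
    DType.bfloat16,             # in_type
    IndexList[3](16, 16, 16),   # MMA instruction shape [M, N, K]
    k_group_size=2,
    num_k_tiles=num_k_tiles,
    num_m_mmas=2,
    num_n_mmas=2,
]()

@parameter
for k in range(num_k_tiles):
    op.load_frag[k](a_smem_warp, b_smem_warp)  # SMEM -> registers
    op.mma[k]()                                # delegate to TiledMma

var c = op.accum_tile()                        # accumulator register tile
```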

Parameters

  • out_type (DType): Accumulator data type (typically float32).
  • in_type (DType): Input element data type (bfloat16 or float8).
  • shape (IndexList[3]): MMA instruction shape [M, N, K].
  • k_group_size (Int): Number of MMA k-steps per fragment load.
  • num_k_tiles (Int): Number of k-tiles across the warp K dimension.
  • num_m_mmas (Int): Number of MMA tiles along M within the warp tile.
  • num_n_mmas (Int): Number of MMA tiles along N within the warp tile.
  • swizzle (Optional[Swizzle]): Optional SMEM swizzle for fragment loading.

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

c_frag_size

comptime c_frag_size = (Self.MMA_M * Self.MMA_N) // WARP_SIZE

k_tile_size

comptime k_tile_size = Self.MMA_K * k_group_size

MMA_K

comptime MMA_K = shape[2]

MMA_M

comptime MMA_M = shape[0]

MMA_N

comptime MMA_N = shape[1]

simd_width

comptime simd_width = k_group_size * ((Self.MMA_M * Self.MMA_K) // WARP_SIZE)

WM

comptime WM = num_m_mmas * Self.MMA_M

WN

comptime WN = num_n_mmas * Self.MMA_N
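As a concrete illustration of the derived members above, here is a minimal Python sketch that evaluates the same formulas. The specific values (AMD WARP_SIZE = 64, a 16x16x16 MFMA shape, k_group_size = 2, two MMA tiles per warp dimension) are assumed example parameters, not values mandated by the struct.

```python
# Evaluate MmaOp's derived comptime members for example parameters.
# Assumed: AMD WARP_SIZE = 64, MFMA shape 16x16x16,
# k_group_size = 2, num_m_mmas = num_n_mmas = 2.
WARP_SIZE = 64
MMA_M, MMA_N, MMA_K = 16, 16, 16   # shape[0], shape[1], shape[2]
k_group_size, num_m_mmas, num_n_mmas = 2, 2, 2

c_frag_size = (MMA_M * MMA_N) // WARP_SIZE                  # accumulator elems per lane
k_tile_size = MMA_K * k_group_size                          # K extent of one fragment load
simd_width = k_group_size * ((MMA_M * MMA_K) // WARP_SIZE)  # elems per lane per load
WM = num_m_mmas * MMA_M                                     # warp tile M extent
WN = num_n_mmas * MMA_N                                     # warp tile N extent

print(c_frag_size, k_tile_size, simd_width, WM, WN)  # 4 32 8 32 32
```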

Methods

__init__

__init__(out self)

accum_tile

accum_tile(self) -> ref[self._c_reg] TileTensor[out_type, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

Returns a mutable reference to the C (accumulator) register tile in the LOCAL address space.

Returns:

ref[self._c_reg] TileTensor[out_type, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

load_frag

load_frag[k_tile_idx: Int](self, a_smem_warp: TileTensor[in_type, address_space=AddressSpace.SHARED], b_smem_warp: TileTensor[in_type, address_space=AddressSpace.SHARED])

Load A and B MMA fragments for k-tile k_tile_idx from SMEM.

Expects block-local warp tiles of shape WM x k_tile_size (or WN x k_tile_size for B), where each k-tile block is contiguous in SMEM (blocked_product layout). Uses direct distribute with swizzle; this is correct because each k-tile block starts at a swizzle-aligned offset.
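The alignment claim can be checked with a tiny Python sketch. The 32x32 block extents and the 512-element swizzle repeat length are assumed example numbers, not values taken from the kernel; the point is only that contiguous k-tile blocks start at multiples of the block size, so any swizzle whose period divides the block size sees every block as aligned.

```python
# Each k-tile block is contiguous in SMEM, so block i starts at
# i * WM * k_tile_size. If the swizzle pattern's repeat length divides that
# block size, every block start is swizzle-aligned and the same swizzled
# distribute can be applied per block.
WM, k_tile_size = 32, 32          # assumed example warp-tile extents
block_elems = WM * k_tile_size    # 1024 elements per contiguous k-tile block
swizzle_period = 512              # hypothetical swizzle repeat length

for i in range(8):                # k-tile block index
    start = i * block_elems       # block start offset in SMEM
    assert start % swizzle_period == 0, "block start not swizzle-aligned"
print("all k-tile block starts are swizzle-aligned")
```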

mma

mma[k_tile_idx: Int](self)

Execute MMA for k-tile k_tile_idx via TiledMma.

Slices A/B registers for this k-tile and delegates to TiledMma.mma for stateless computation.
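The per-k-tile slicing can be pictured with a small Python sketch. The flat register layout assumed here (num_k_tiles fragments stored back to back, each of num_m_mmas * simd_width elements) is a hypothetical illustration, not the struct's documented storage order.

```python
# Sketch of per-k-tile fragment slicing under an ASSUMED flat layout:
# the A register tile holds num_k_tiles fragments back to back, and
# mma[k] slices out fragment k before delegating to the tiled-MMA step.
num_k_tiles, num_m_mmas, simd_width = 2, 2, 8
a_reg = list(range(num_k_tiles * num_m_mmas * simd_width))  # stand-in for registers

k = 1                               # k-tile index being computed
frag_len = num_m_mmas * simd_width  # elements per k-tile fragment
a_frag = a_reg[k * frag_len : (k + 1) * frag_len]

print(len(a_frag), a_frag[0])  # 16 16
```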