IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

MmaOp

struct MmaOp[out_type: DType, in_type: DType, shape: IndexList[3], k_group_size: Int, num_k_tiles: Int, num_m_mmas: Int, num_n_mmas: Int, swizzle: Optional[Swizzle] = None]

Register ownership, SMEM fragment loading, and the schedule API for AMD matmul.

Owns the A, B, and C register tiles in the LOCAL address space and provides the schedule-facing API: load_frag[k] loads the A and B fragments for k-tile k from SMEM into registers, and mma[k] delegates the computation for k-tile k to TiledMma.
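The load/compute split can be pictured with a small Python model. This is a sketch under stated assumptions, not the kernel's schedule: `FakeMmaOp` and `run_schedule` are hypothetical stand-ins that only record calls, and the prefetch-one-ahead ordering is just one plausible interleaving of `load_frag` and `mma`.

```python
# Hypothetical Python model of the schedule-facing API. FakeMmaOp stands in
# for MmaOp and only records calls; it does not touch SMEM or registers.


class FakeMmaOp:
    def __init__(self, num_k_tiles):
        self.num_k_tiles = num_k_tiles
        self.log = []
        self.loaded = set()

    def load_frag(self, k_tile_idx):
        # Models copying the A/B fragments of one k-tile from SMEM to registers.
        self.loaded.add(k_tile_idx)
        self.log.append(("load_frag", k_tile_idx))

    def mma(self, k_tile_idx):
        # Models the MMA for one k-tile; its fragments must already be loaded.
        assert k_tile_idx in self.loaded, "mma issued before load_frag"
        self.log.append(("mma", k_tile_idx))


def run_schedule(op):
    # Assumed ordering: prefetch k-tile k + 1 before computing k-tile k.
    op.load_frag(0)
    for k in range(op.num_k_tiles):
        if k + 1 < op.num_k_tiles:
            op.load_frag(k + 1)
        op.mma(k)
    return op.log


log = run_schedule(FakeMmaOp(num_k_tiles=3))
print(log)
```

In Mojo, `k_tile_idx` is a compile-time parameter (`load_frag[k]`, `mma[k]`), so a real schedule would presumably unroll this loop at compile time rather than iterate at run time.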

Parameters

  • out_type (DType): Accumulator data type (typically float32).
  • in_type (DType): Input element data type (bfloat16 or float8).
  • shape (IndexList[3]): MMA instruction shape [M, N, K].
  • k_group_size (Int): Number of MMA k-steps per fragment load.
  • num_k_tiles (Int): Number of k-tiles across the warp K dimension.
  • num_m_mmas (Int): MMA tiles along M within the warp tile.
  • num_n_mmas (Int): MMA tiles along N within the warp tile.
  • swizzle (Optional[Swizzle]): Optional SMEM swizzle for fragment loading.

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

c_frag_size

comptime c_frag_size = ((MmaOp[out_type, in_type, shape, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, swizzle].MMA_M * MmaOp[out_type, in_type, shape, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, swizzle].MMA_N) // WARP_SIZE)
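As a worked example of `c_frag_size`, the snippet below evaluates the formula in Python. It assumes `WARP_SIZE = 64` (the wavefront size on AMD CDNA GPUs) and a 16x16x16 MMA shape; both are illustrative values, not defaults of `MmaOp`.

```python
# Worked example of c_frag_size = (MMA_M * MMA_N) // WARP_SIZE.
# Assumptions: WARP_SIZE = 64 (CDNA wavefront) and a 16x16x16 MMA shape.
WARP_SIZE = 64
MMA_M, MMA_N, MMA_K = 16, 16, 16

c_frag_size = (MMA_M * MMA_N) // WARP_SIZE
print(c_frag_size)  # 4 accumulator elements per lane for one MMA tile

# Assuming the output tile is split evenly across lanes, the per-lane
# fragments together cover the whole MMA_M x MMA_N tile.
assert WARP_SIZE * c_frag_size == MMA_M * MMA_N
```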

k_tile_size

comptime k_tile_size = (MmaOp[out_type, in_type, shape, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, swizzle].MMA_K * k_group_size)
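The following Python sketch shows how `k_tile_size` partitions the warp's K extent. It assumes, based on the `num_k_tiles` description, that the warp K dimension equals `num_k_tiles * k_tile_size`; the parameter values and the `k_range` helper are hypothetical.

```python
# Sketch: how k_tile_size partitions the warp's K extent.
# Assumption: warp K = num_k_tiles * k_tile_size. Values are illustrative.
MMA_K = 16
k_group_size = 2
num_k_tiles = 4

k_tile_size = MMA_K * k_group_size  # 32 K elements per k-tile
warp_k = num_k_tiles * k_tile_size  # 128 K elements per warp tile


def k_range(k_tile_idx):
    """Hypothetical helper: half-open [start, end) K range of one k-tile."""
    start = k_tile_idx * k_tile_size
    return start, start + k_tile_size


ranges = [k_range(k) for k in range(num_k_tiles)]
print(ranges)  # [(0, 32), (32, 64), (64, 96), (96, 128)]
```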

MMA_K

comptime MMA_K = shape[2]

MMA_M

comptime MMA_M = shape[0]

MMA_N

comptime MMA_N = shape[1]

simd_width

comptime simd_width = (k_group_size * ((MmaOp[out_type, in_type, shape, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, swizzle].MMA_M * MmaOp[out_type, in_type, shape, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, swizzle].MMA_K) // WARP_SIZE))
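A worked example of `simd_width`, assuming `WARP_SIZE = 64`, a 16x16x16 MMA shape, `k_group_size = 2`, and a 2-byte input type such as bfloat16 (all illustrative values):

```python
# Worked example of simd_width = k_group_size * ((MMA_M * MMA_K) // WARP_SIZE).
# Assumptions: WARP_SIZE = 64, 16x16x16 MMA, k_group_size = 2, bfloat16 inputs.
WARP_SIZE = 64
MMA_M, MMA_K = 16, 16
k_group_size = 2
bytes_per_elem = 2  # bfloat16

per_mma = (MMA_M * MMA_K) // WARP_SIZE  # A elements per lane for one MMA k-step
simd_width = k_group_size * per_mma     # elements per lane per fragment load
load_bytes = simd_width * bytes_per_elem

print(per_mma, simd_width, load_bytes)  # 4 8 16
```

With these numbers each lane moves 16 bytes per fragment load, the width of one 128-bit vector load; a single k-step alone would need only 8 bytes, so grouping k-steps is what widens the load.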

WM

comptime WM = (num_m_mmas * MmaOp[out_type, in_type, shape, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, swizzle].MMA_M)

WN

comptime WN = (num_n_mmas * MmaOp[out_type, in_type, shape, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, swizzle].MMA_N)
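The warp-tile extents `WM` and `WN` follow directly from the MMA shape and the tile counts. This Python sketch computes them for illustrative values and locates each MMA tile inside the warp tile; `mma_tile_origin` is a hypothetical helper, not part of `MmaOp`.

```python
# Sketch: warp-tile extent and the offset of each MMA tile inside it.
# Values are illustrative; mma_tile_origin is a hypothetical helper.
MMA_M, MMA_N = 16, 16
num_m_mmas, num_n_mmas = 4, 2

WM = num_m_mmas * MMA_M  # 64 rows of C per warp
WN = num_n_mmas * MMA_N  # 32 columns of C per warp


def mma_tile_origin(m_idx, n_idx):
    """Top-left (row, col) of MMA tile (m_idx, n_idx) in the WM x WN warp tile."""
    return m_idx * MMA_M, n_idx * MMA_N


origins = [mma_tile_origin(m, n) for m in range(num_m_mmas) for n in range(num_n_mmas)]
print(WM, WN, origins[-1])  # 64 32 (48, 16)
```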

Methods

__init__

def __init__(out self)

accum_tile

def accum_tile(self) -> ref[self._c_reg] TileTensor[out_type, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

Returns the C accumulator register tile by reference.

Returns:

ref[self._c_reg] TileTensor[out_type, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

load_frag

def load_frag[k_tile_idx: Int](self, a_smem_warp: TileTensor[in_type, address_space=AddressSpace.SHARED], b_smem_warp: TileTensor[in_type, address_space=AddressSpace.SHARED])

Load A and B MMA fragments for k-tile k_tile_idx from SMEM.

Expects block-local warp tiles of shape WM x k_tile_size for A (WN x k_tile_size for B), where each k-tile block is contiguous in SMEM (a blocked_product layout). Fragments are distributed directly with the swizzle applied; this is correct because each block starts at a swizzle-aligned offset.
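The alignment argument can be checked with a toy model. This Python sketch assumes a layout in which the A warp tile is stored as `num_k_tiles` back-to-back blocks of `WM x k_tile_size`, each row-major, and uses `xor_swizzle`, a hypothetical XOR swizzle rather than the kernel's `Swizzle`. It shows that swizzling block-local offsets and then rebasing them gives the same addresses as swizzling global offsets, provided each block starts at an aligned offset.

```python
# Toy model of a blocked, block-contiguous A warp tile (assumed layout) and a
# hypothetical XOR swizzle, used only to illustrate the alignment argument.
WM, k_tile_size, num_k_tiles = 4, 8, 3
BLOCK = WM * k_tile_size  # 32 elements per k-tile block


def offset(row, k):
    """Unswizzled offset of A element (row, k): blocks back to back, row-major inside."""
    block, kk = divmod(k, k_tile_size)
    return block * BLOCK + row * k_tile_size + kk


def xor_swizzle(off):
    """Toy swizzle: XOR bits 1-2 (2-element chunk in a row) with bits 3-4 (row)."""
    return off ^ (((off >> 3) & 0b11) << 1)


# Every block starts at a multiple of BLOCK, so swizzling block-local offsets
# and adding the block start yields the same addresses as a global swizzle.
block_starts = [offset(0, t * k_tile_size) for t in range(num_k_tiles)]
aligned_ok = all(
    xor_swizzle(start + local) == start + xor_swizzle(local)
    for start in block_starts
    for local in range(BLOCK)
)

# At a misaligned base the same shortcut produces a different address.
misaligned_ok = xor_swizzle(4 + 4) == 4 + xor_swizzle(4)

print(block_starts, aligned_ok, misaligned_ok)  # [0, 32, 64] True False
```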

mma

def mma[k_tile_idx: Int](self)

Execute MMA for k-tile k_tile_idx via TiledMma.

Slices A/B registers for this k-tile and delegates to TiledMma.mma for stateless computation.
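To see why per-k-tile MMAs compose into the full product, the Python sketch below accumulates one K slice at a time into C and compares the result with a direct matmul. Plain lists stand in for register fragments, and `mma_k_tile` is a hypothetical helper rather than `TiledMma.mma`; it receives the accumulator explicitly, mirroring the stateless delegation described above. Sizes are illustrative.

```python
# Sketch: accumulating one MMA per k-tile reproduces the full-K product.
# mma_k_tile is a hypothetical helper; lists stand in for register fragments.
import random

M, N, k_tile_size, num_k_tiles = 2, 3, 4, 3
K = k_tile_size * num_k_tiles

random.seed(0)
A = [[random.randint(-3, 3) for _ in range(K)] for _ in range(M)]  # M x K
B = [[random.randint(-3, 3) for _ in range(N)] for _ in range(K)]  # K x N


def mma_k_tile(C, k_tile_idx):
    """C += A[:, ks] @ B[ks, :] for the K slice owned by one k-tile."""
    ks = range(k_tile_idx * k_tile_size, (k_tile_idx + 1) * k_tile_size)
    for i in range(M):
        for j in range(N):
            C[i][j] += sum(A[i][k] * B[k][j] for k in ks)


C = [[0] * N for _ in range(M)]
for t in range(num_k_tiles):
    mma_k_tile(C, t)

full = [[sum(A[i][k] * B[k][j] for k in range(K)) for j in range(N)] for i in range(M)]
print(C == full)  # True
```

Integer inputs keep the comparison exact; with floating-point accumulators the split-K order can change rounding slightly.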