Mojo struct

MmaOpApple

struct MmaOpApple[out_type: DType, in_type: DType, num_m_mmas: Int, num_n_mmas: Int, *, b_type: DType = in_type, transpose_a: Bool = False, transpose_b: Bool = False]

Fields

  • rb (UInt16):
  • cb (UInt16):

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

AccumType

comptime AccumType = InlineArray[SIMD[out_type, 8], MmaOpApple[out_type, in_type, num_m_mmas, num_n_mmas, b_type=b_type, transpose_a=transpose_a, transpose_b=transpose_b].num_accum]

FRAG_SIZE

comptime FRAG_SIZE = 8

MMA_K

comptime MMA_K = 16

MMA_M

comptime MMA_M = 16

MMA_N

comptime MMA_N = 16

num_accum

comptime num_accum = (num_m_mmas * num_n_mmas)
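
As an illustrative sketch (the parameter values and dtypes here are arbitrary assumptions, not requirements), a 2x2 configuration yields four accumulator fragments:

```mojo
# Hypothetical instantiation for illustration only.
alias Mma = MmaOpApple[DType.float32, DType.bfloat16, 2, 2]
# num_accum = 2 * 2 = 4: one SIMD[float32, 8] fragment per 16x16
# MMA tile position (MMA_M = MMA_N = 16), so this simdgroup covers
# a 32x32 output tile.
```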

Methods

__init__

__init__(out self)

zero_accum

static zero_accum() -> MmaOpApple[out_type, in_type, num_m_mmas, num_n_mmas, b_type=b_type, transpose_a=transpose_a, transpose_b=transpose_b].AccumType

Returns:

A zero-initialized `AccumType` accumulator array, with every fragment set to zero.

mma

mma[bounded: Bool = False](self, mut accum: InlineArray[SIMD[out_type, 8], MmaOpApple[out_type, in_type, num_m_mmas, num_n_mmas, b_type=b_type, transpose_a=transpose_a, transpose_b=transpose_b].num_accum], a_tile: TileTensor[in_type, a_tile.LayoutType, a_tile.origin, address_space=a_tile.address_space, linear_idx_type=a_tile.linear_idx_type, element_size=a_tile.element_size], b_tile: TileTensor[b_type, b_tile.LayoutType, b_tile.origin, address_space=b_tile.address_space, linear_idx_type=b_tile.linear_idx_type, element_size=b_tile.element_size], a_valid_rows: Int = (num_m_mmas * 16), b_valid_cols: Int = (num_n_mmas * 16), k_valid: Int = LayoutType.static_shape[1])

Process K elements across all M/N tile positions.

The K depth is inferred from the A tile's column dimension and must be a multiple of 16. For K=16 this is one MMA step; for K=32 this is two steps, etc. The struct iterates K internally.

Tiles may be row-major or col-major. The stride layout is detected from static_stride and the hardware transpose flag is derived via XOR with the transpose parameter: hw_flag = is_col_major XOR transpose_param.

Use `mma()` (the default `bounded=False`) for interior tiles where all memory accesses are in-bounds. Use `mma[bounded=True]` for edge tiles; it zero-fills out-of-bounds elements. The kernel should check once per simdgroup, not per load.

Args:

  • accum (InlineArray): Caller-owned InlineArray of SIMD[out_type, 8] accumulators, one per MMA tile position ((num_m_mmas * num_n_mmas) total).
  • a_tile (TileTensor): A operand, shape (num_m_mmas * 16, K).
  • b_tile (TileTensor): B operand, shape (K, num_n_mmas * 16) or (num_n_mmas * 16, K) if transpose_b.
  • a_valid_rows (Int): Valid rows from tile origin (bounded path only).
  • b_valid_cols (Int): Valid cols from tile origin (bounded path only).
  • k_valid (Int): Valid K elements across all steps (bounded path only). Defaults to the tile's full K dimension.
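
A minimal usage sketch, assuming the caller has already prepared `a_tile` and `b_tile` as `TileTensor`s and tracks an `is_interior_tile` flag and remaining-element counts (`rows_left`, `cols_left`); those names are hypothetical, while the `MmaOpApple` calls follow the signatures on this page:

```mojo
# Hypothetical kernel fragment; dtypes and tile counts are assumptions.
alias Mma = MmaOpApple[DType.float32, DType.bfloat16, 2, 2]

var op = Mma()
var accum = Mma.zero_accum()  # all fragments start at zero

if is_interior_tile:
    # Interior tile: every load is in-bounds, take the fast path.
    op.mma(accum, a_tile, b_tile)
else:
    # Edge tile: zero-fill out-of-bounds elements.
    op.mma[bounded=True](
        accum, a_tile, b_tile,
        a_valid_rows=rows_left, b_valid_cols=cols_left,
    )
```

Note that `mma` iterates the K dimension internally, so a single call processes the full K depth of the tiles.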

store

store(self, accum: InlineArray[SIMD[out_type, 8], MmaOpApple[out_type, in_type, num_m_mmas, num_n_mmas, b_type=b_type, transpose_a=transpose_a, transpose_b=transpose_b].num_accum], d_tile: TileTensor[out_type, d_tile.LayoutType, d_tile.origin, address_space=d_tile.address_space, linear_idx_type=d_tile.linear_idx_type, element_size=d_tile.element_size])

Store all accumulators to output tile (unconditional).

Caller guarantees all elements are in-bounds.

store_bounded

store_bounded(self, accum: InlineArray[SIMD[out_type, 8], MmaOpApple[out_type, in_type, num_m_mmas, num_n_mmas, b_type=b_type, transpose_a=transpose_a, transpose_b=transpose_b].num_accum], d_tile: TileTensor[out_type, d_tile.LayoutType, d_tile.origin, address_space=d_tile.address_space, linear_idx_type=d_tile.linear_idx_type, element_size=d_tile.element_size], valid_rows: Int, valid_cols: Int)

Store accumulators with bounds checking.

Only writes elements where (row < valid_rows) and (col < valid_cols).
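
Continuing the sketch above (again assuming a hypothetical `is_interior_tile` flag and `rows_left`/`cols_left` counts tracked by the caller), the two store paths mirror the bounded/unbounded split in `mma`:

```mojo
# Hypothetical epilogue; d_tile is a TileTensor prepared by the caller.
if is_interior_tile:
    op.store(accum, d_tile)  # unconditional: all elements in-bounds
else:
    # Writes only elements with row < rows_left and col < cols_left.
    op.store_bounded(accum, d_tile, valid_rows=rows_left, valid_cols=cols_left)
```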
