Mojo struct

MmaOpApple

struct MmaOpApple[out_type: DType, in_type: DType, num_m_mmas: Int, num_n_mmas: Int, *, b_type: DType = in_type, transpose_a: Bool = False, transpose_b: Bool = False]

Fields

  • rb (UInt16):
  • cb (UInt16):

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

AccumType

comptime AccumType = InlineArray[SIMD[out_type, 8], MmaOpApple[out_type, in_type, num_m_mmas, num_n_mmas, b_type=b_type, transpose_a=transpose_a, transpose_b=transpose_b].num_accum]

FRAG_SIZE

comptime FRAG_SIZE = 8

MMA_K

comptime MMA_K = 16

MMA_M

comptime MMA_M = 16

MMA_N

comptime MMA_N = 16

num_accum

comptime num_accum = (num_m_mmas * num_n_mmas)
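
For example, with num_m_mmas = 2 and num_n_mmas = 2, num_accum = 4, so AccumType is InlineArray[SIMD[out_type, 8], 4]: one FRAG_SIZE-element accumulator fragment per 16x16 M/N tile position.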

Methods

__init__

__init__(out self)

zero_accum

static zero_accum() -> MmaOpApple[out_type, in_type, num_m_mmas, num_n_mmas, b_type=b_type, transpose_a=transpose_a, transpose_b=transpose_b].AccumType

Returns:

A zero-initialized accumulator array of type MmaOpApple[out_type, in_type, num_m_mmas, num_n_mmas, b_type=b_type, transpose_a=transpose_a, transpose_b=transpose_b].AccumType.

mma

mma[bounded: Bool = False](self, mut accum: InlineArray[SIMD[out_type, 8], MmaOpApple[out_type, in_type, num_m_mmas, num_n_mmas, b_type=b_type, transpose_a=transpose_a, transpose_b=transpose_b].num_accum], a_tile: TileTensor[in_type, address_space=a_tile.address_space, linear_idx_type=a_tile.linear_idx_type, element_size=a_tile.element_size], b_tile: TileTensor[b_type, address_space=b_tile.address_space, linear_idx_type=b_tile.linear_idx_type, element_size=b_tile.element_size], a_valid_rows: Int = (num_m_mmas * 16), b_valid_cols: Int = (num_n_mmas * 16), k_valid: Int = LayoutType.static_shape[1])

Process K elements across all M/N tile positions.

The K depth is inferred from the A tile's column dimension and must be a multiple of 16. For K=16 this is one MMA step; for K=32 this is two steps, etc. The struct iterates K internally.

Tiles may be row-major or col-major. The stride layout is detected from static_stride and the hardware transpose flag is derived via XOR with the transpose parameter: hw_flag = is_col_major XOR transpose_param.
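
For example, a col-major tile (is_col_major = True) combined with transpose = True gives hw_flag = True XOR True = False, so the hardware reads the tile as-is, with no transpose applied.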

Use mma() (bounded=False) for interior tiles where all memory accesses are in-bounds. Use mma[bounded=True]() for edge tiles; the bounded path zero-fills out-of-bounds elements. The kernel should make this check once per simdgroup, not per load.

Parameters:

  • bounded: If True, apply bounds checking using a_valid_rows, b_valid_cols, and k_valid, zero-filling out-of-bounds elements. Defaults to False.

Args:

  • accum: The accumulator fragments, updated in place.
  • a_tile: The A input tile.
  • b_tile: The B input tile.
  • a_valid_rows: Number of valid rows in a_tile. Defaults to num_m_mmas * 16.
  • b_valid_cols: Number of valid columns in b_tile. Defaults to num_n_mmas * 16.
  • k_valid: Number of valid K elements. Defaults to the static K extent (the A tile's column dimension).
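
A minimal usage sketch, assuming float32 accumulation over bfloat16 inputs and a 2x2 grid of 16x16 MMA tiles; the helpers load_a_tile, load_b_tile, output_tile, and the count num_k_tiles are hypothetical stand-ins for kernel-specific TileTensor setup:

alias Mma = MmaOpApple[DType.float32, DType.bfloat16, 2, 2]

var op = Mma()
var accum = Mma.zero_accum()  # zero-initialized accumulator fragments

for k_tile in range(num_k_tiles):  # num_k_tiles: hypothetical K-tile count
    var a_tile = load_a_tile(k_tile)  # hypothetical tile loaders
    var b_tile = load_b_tile(k_tile)
    op.mma(accum, a_tile, b_tile)  # interior path: no bounds checks

op.store(accum, output_tile())  # caller guarantees all elements in-bounds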

store

store(self, accum: InlineArray[SIMD[out_type, 8], MmaOpApple[out_type, in_type, num_m_mmas, num_n_mmas, b_type=b_type, transpose_a=transpose_a, transpose_b=transpose_b].num_accum], d_tile: TileTensor[out_type, address_space=d_tile.address_space, linear_idx_type=d_tile.linear_idx_type, element_size=d_tile.element_size])

Store all accumulators to output tile (unconditional).

Caller guarantees all elements are in-bounds.

store_bounded

store_bounded(self, accum: InlineArray[SIMD[out_type, 8], MmaOpApple[out_type, in_type, num_m_mmas, num_n_mmas, b_type=b_type, transpose_a=transpose_a, transpose_b=transpose_b].num_accum], d_tile: TileTensor[out_type, address_space=d_tile.address_space, linear_idx_type=d_tile.linear_idx_type, element_size=d_tile.element_size], valid_rows: Int, valid_cols: Int)

Store accumulators with bounds checking.

Only writes elements where (row < valid_rows) and (col < valid_cols).
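
A hedged sketch of the edge-tile path for a single output tile, reusing the names from the sketch above plus a hypothetical output tile d_tile and remainder counts rows_left and cols_left computed by the kernel:

if is_edge_tile:  # hypothetical predicate: tile overlaps the matrix edge
    # inside the K loop: bounded loads zero-fill out-of-bounds elements
    op.mma[bounded=True](
        accum, a_tile, b_tile,
        a_valid_rows=rows_left,
        b_valid_cols=cols_left,
    )
    # after the K loop: bounded store skips out-of-bounds writes
    op.store_bounded(accum, d_tile, rows_left, cols_left)
else:
    op.mma(accum, a_tile, b_tile)
    op.store(accum, d_tile)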