
Mojo struct

MmaOpApple

struct MmaOpApple[out_type: DType, in_type: DType, num_m_mmas: Int, num_n_mmas: Int, *, b_type: DType = in_type, transpose_a: Bool = False, transpose_b: Bool = False]

Fields​

  • ​rb (Int):
  • ​cb (Int):

Implemented traits​

AnyType, ImplicitlyDeletable

comptime members​

AccumType​

comptime AccumType = InlineArray[SIMD[out_type, SIMDSize(8)], MmaOpApple[out_type, in_type, num_m_mmas, num_n_mmas, b_type=b_type, transpose_a=transpose_a, transpose_b=transpose_b].num_accum]

FRAG_SIZE​

comptime FRAG_SIZE = 8

MMA_K​

comptime MMA_K = 16

MMA_M​

comptime MMA_M = 16

MMA_N​

comptime MMA_N = 16

num_accum​

comptime num_accum = (num_m_mmas * num_n_mmas)
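The constants above fix the arithmetic of the accumulator array. The sketch below is a plain Python model of that arithmetic, not the Mojo API: the helper name `accum_shape` is hypothetical, and the 32-lane simdgroup width is an assumption used only to illustrate why a 16x16 tile could split into per-lane fragments of 8.

```python
# Values copied from the comptime members documented above.
MMA_M = 16
MMA_N = 16
FRAG_SIZE = 8

# Assumption (not stated on this page): one simdgroup has 32 lanes.
SIMDGROUP_WIDTH = 32


def accum_shape(num_m_mmas, num_n_mmas):
    """Hypothetical helper: size of the accumulator array and the tile it covers."""
    num_accum = num_m_mmas * num_n_mmas
    return {
        "num_accum": num_accum,                    # entries in AccumType
        "tile_rows": num_m_mmas * MMA_M,           # output rows covered
        "tile_cols": num_n_mmas * MMA_N,           # output columns covered
        "elems_per_lane": num_accum * FRAG_SIZE,   # scalars held by each lane
    }


# Under the 32-lane assumption, one 16x16 tile spreads 256 values over 32 lanes.
assert (MMA_M * MMA_N) // SIMDGROUP_WIDTH == FRAG_SIZE

shape = accum_shape(num_m_mmas=2, num_n_mmas=4)
print(shape["num_accum"], shape["tile_rows"], shape["tile_cols"])
```

With `num_m_mmas=2` and `num_n_mmas=4`, the model gives 8 accumulators covering a 32x64 output tile, which matches the default `a_valid_rows` and `b_valid_cols` expressions in the `mma` signature.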

Methods​

__init__​

def __init__(out self)

zero_accum​

static def zero_accum() -> Self.AccumType

Returns:

Self.AccumType

mma​

def mma[bounded: Bool = False](self, mut accum: InlineArray[SIMD[out_type, SIMDSize(8)], Self.num_accum], a_tile: TileTensor[in_type, Storage=a_tile.Storage, address_space=a_tile.address_space, linear_idx_type=a_tile.linear_idx_type, element_size=a_tile.element_size], b_tile: TileTensor[b_type, Storage=b_tile.Storage, address_space=b_tile.address_space, linear_idx_type=b_tile.linear_idx_type, element_size=b_tile.element_size], a_valid_rows: Int = (num_m_mmas * Int(16)), b_valid_cols: Int = (num_n_mmas * Int(16)), k_valid: Int = a_tile.LayoutType.static_shape[Int(1)])

Process K elements across all M/N tile positions.

The K depth is inferred from the A tile's column dimension and must be a multiple of 16. For K=16 this is one MMA step; for K=32 this is two steps, etc. The struct iterates K internally.
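The K-step rule above can be modeled in a few lines of Python. This is only a sketch of the arithmetic; `k_steps` and `k_offsets` are hypothetical names, and raising an error on a non-multiple of 16 is an assumption about how a violation might be reported.

```python
MMA_K = 16  # K depth consumed by one MMA step (from the comptime members)


def k_steps(k_depth):
    """Number of MMA steps needed to consume k_depth columns of the A tile."""
    if k_depth <= 0 or k_depth % MMA_K != 0:
        raise ValueError("K depth must be a positive multiple of 16")
    return k_depth // MMA_K


def k_offsets(k_depth):
    """Starting K index of each step, in the order the loop would visit them."""
    return [step * MMA_K for step in range(k_steps(k_depth))]


print(k_steps(16))    # one step
print(k_steps(32))    # two steps
print(k_offsets(48))  # three steps starting at K = 0, 16, 32
```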

Tiles may be row-major or col-major. The stride layout is detected from static_stride and the hardware transpose flag is derived via XOR with the transpose parameter: hw_flag = is_col_major XOR transpose_param.
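The XOR rule is easiest to see as a truth table. The Python sketch below models it; the function names are hypothetical, and the stride test (unit stride along the row dimension means col-major) is an assumption about how the layout might be detected from `static_stride`.

```python
def is_col_major(static_stride):
    """Assumed detection rule for a 2D tile: rows are contiguous in col-major."""
    row_stride, col_stride = static_stride
    return row_stride == 1 and col_stride != 1


def hw_transpose_flag(static_stride, transpose_param):
    """hw_flag = is_col_major XOR transpose_param (for booleans, != is XOR)."""
    return is_col_major(static_stride) != transpose_param


ROW_MAJOR = (16, 1)  # 16 elements between rows, 1 between columns
COL_MAJOR = (1, 16)  # 1 element between rows, 16 between columns

for stride in (ROW_MAJOR, COL_MAJOR):
    for transpose in (False, True):
        print(stride, transpose, hw_transpose_flag(stride, transpose))
```

The flag is set only when exactly one of the two conditions holds: a col-major tile that is also logically transposed cancels out, so the hardware reads it without a transpose.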

Use mma() (bounded=False) for interior tiles where all memory is in-bounds. Use mma[bounded=True]() for edge tiles; it zero-fills out-of-bounds (OOB) elements. The kernel should make the interior-versus-edge check once per simdgroup, not once per load.
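The split between the bounded and unbounded paths can be illustrated with a small Python model. It is a sketch of the idea only: `tile_is_interior` and `load_bounded` are hypothetical helpers operating on nested lists, not the loads the struct performs.

```python
def tile_is_interior(row0, col0, tile_rows, tile_cols, total_rows, total_cols):
    """One check per tile: True when the whole tile lies inside the matrix."""
    return row0 + tile_rows <= total_rows and col0 + tile_cols <= total_cols


def load_bounded(matrix, row0, col0, tile_rows, tile_cols):
    """Edge-tile load: copy in-bounds elements, zero-fill everything else."""
    total_rows = len(matrix)
    total_cols = len(matrix[0]) if matrix else 0
    return [
        [
            matrix[r][c] if r < total_rows and c < total_cols else 0
            for c in range(col0, col0 + tile_cols)
        ]
        for r in range(row0, row0 + tile_rows)
    ]


matrix = [[1, 2, 3],
          [4, 5, 6],
          [7, 8, 9]]

# A 2x2 tile at (0, 0) is interior; the one at (2, 2) hangs off both edges.
print(tile_is_interior(0, 0, 2, 2, 3, 3))
print(tile_is_interior(2, 2, 2, 2, 3, 3))
print(load_bounded(matrix, 2, 2, 2, 2))
```

Because zero-filled elements contribute nothing to a product sum, the edge tile can run through the same multiply-accumulate steps as an interior tile. Doing the interior test once per tile keeps the per-element comparisons out of the common path.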

Args:

store​

def store(self, accum: InlineArray[SIMD[out_type, SIMDSize(8)], Self.num_accum], d_tile: TileTensor[out_type, Storage=d_tile.Storage, address_space=d_tile.address_space, linear_idx_type=d_tile.linear_idx_type, element_size=d_tile.element_size])

Store all accumulators to output tile (unconditional).

Caller guarantees all elements are in-bounds.

store_bounded​

def store_bounded(self, accum: InlineArray[SIMD[out_type, SIMDSize(8)], Self.num_accum], d_tile: TileTensor[out_type, Storage=d_tile.Storage, address_space=d_tile.address_space, linear_idx_type=d_tile.linear_idx_type, element_size=d_tile.element_size], valid_rows: Int, valid_cols: Int)

Stores accumulators where (row < valid_rows) and (col < valid_cols).

Assumes row-major d_tile; for col-major, mirror _do_load's swap.
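The store predicate can be modeled as follows. This Python sketch assumes a row-major destination held in a flat list; the function and parameter names are hypothetical, and it ignores how accumulator fragments are distributed across lanes.

```python
def store_bounded_model(dst, dst_row_stride, tile, valid_rows, valid_cols):
    """Write tile[row][col] only where row < valid_rows and col < valid_cols."""
    for row, tile_row in enumerate(tile):
        for col, value in enumerate(tile_row):
            if row < valid_rows and col < valid_cols:
                # Row-major addressing: consecutive columns are adjacent.
                dst[row * dst_row_stride + col] = value


dst = [0] * 6                    # 2x3 row-major destination
tile = [[1, 2, 3],
        [4, 5, 6]]
store_bounded_model(dst, 3, tile, valid_rows=1, valid_cols=2)
print(dst)                       # [1, 2, 0, 0, 0, 0]
```

For a col-major destination the addressing would swap the roles of `row` and `col`, which is the kind of swap the note above points to.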