Mojo struct
MmaOpApple
struct MmaOpApple[out_type: DType, in_type: DType, num_m_mmas: Int, num_n_mmas: Int, *, b_type: DType = in_type, transpose_a: Bool = False, transpose_b: Bool = False]
Fields
- rb (UInt16)
- cb (UInt16)
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
AccumType
comptime AccumType = InlineArray[SIMD[out_type, 8], MmaOpApple[out_type, in_type, num_m_mmas, num_n_mmas, b_type=b_type, transpose_a=transpose_a, transpose_b=transpose_b].num_accum]
FRAG_SIZE
comptime FRAG_SIZE = 8
MMA_K
comptime MMA_K = 16
MMA_M
comptime MMA_M = 16
MMA_N
comptime MMA_N = 16
num_accum
comptime num_accum = (num_m_mmas * num_n_mmas)
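For example, a 2x2 instantiation (hypothetical type choices, shown for illustration) yields four accumulators, each an 8-wide SIMD fragment:

```mojo
# Hypothetical instantiation: a 2x2 grid of 16x16 MMA tiles,
# accumulating in float32 from bfloat16 inputs.
comptime Op = MmaOpApple[DType.float32, DType.bfloat16, 2, 2]

# num_accum = num_m_mmas * num_n_mmas = 4, so AccumType is
# InlineArray[SIMD[DType.float32, 8], 4]: one FRAG_SIZE=8 fragment
# per (m, n) tile position, covering a 32x32 output block.
```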
Methods
__init__
__init__(out self)
zero_accum
static zero_accum() -> MmaOpApple[out_type, in_type, num_m_mmas, num_n_mmas, b_type=b_type, transpose_a=transpose_a, transpose_b=transpose_b].AccumType
Returns:
A zero-initialized accumulator array (AccumType).
mma
mma[bounded: Bool = False](self, mut accum: InlineArray[SIMD[out_type, 8], MmaOpApple[out_type, in_type, num_m_mmas, num_n_mmas, b_type=b_type, transpose_a=transpose_a, transpose_b=transpose_b].num_accum], a_tile: TileTensor[in_type, a_tile.LayoutType, a_tile.origin, address_space=a_tile.address_space, linear_idx_type=a_tile.linear_idx_type, element_size=a_tile.element_size], b_tile: TileTensor[b_type, b_tile.LayoutType, b_tile.origin, address_space=b_tile.address_space, linear_idx_type=b_tile.linear_idx_type, element_size=b_tile.element_size], a_valid_rows: Int = (num_m_mmas * 16), b_valid_cols: Int = (num_n_mmas * 16), k_valid: Int = LayoutType.static_shape[1])
Process K elements across all M/N tile positions.
The K depth is inferred from the A tile's column dimension and must be a multiple of 16. For K=16 this is one MMA step; for K=32 this is two steps, etc. The struct iterates K internally.
Tiles may be row-major or col-major. The stride layout is detected from static_stride and the hardware transpose flag is derived via XOR with the transpose parameter: hw_flag = is_col_major XOR transpose_param.
Use mma() (the default, bounded=False) for interior tiles where all memory accesses are in-bounds. Use mma[bounded=True] for edge tiles; the bounded path zero-fills out-of-bounds elements. The kernel should perform this check once per simdgroup, not per load.
Args:
- accum (InlineArray): Caller-owned InlineArray of SIMD[out_type, 8] accumulators, one per (num_m_mmas * num_n_mmas) tile.
- a_tile (TileTensor): A operand, shape (num_m_mmas * 16, K).
- b_tile (TileTensor): B operand, shape (K, num_n_mmas * 16), or (num_n_mmas * 16, K) if transpose_b.
- a_valid_rows (Int): Valid rows from the tile origin (bounded path only).
- b_valid_cols (Int): Valid columns from the tile origin (bounded path only).
- k_valid (Int): Valid K elements across all steps (bounded path only). Defaults to the tile's full K dimension.
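A minimal usage sketch of the two mma paths, assuming a kernel where a_tile and b_tile are TileTensor views already staged for this simdgroup (num_k_blocks, rows_left, and cols_left are illustrative names, not part of this API):

```mojo
comptime Op = MmaOpApple[DType.float32, DType.bfloat16, 2, 2]
var op = Op()
var accum = Op.zero_accum()

# Interior tiles: fast path, all loads known in-bounds.
# Each call consumes the full K depth of a_tile (a multiple of 16),
# stepping the 16-element MMA internally.
for _ in range(num_k_blocks):
    op.mma(accum, a_tile, b_tile)

# Edge tiles: bounded path zero-fills out-of-bounds elements.
op.mma[bounded=True](
    accum, a_tile, b_tile,
    a_valid_rows=rows_left, b_valid_cols=cols_left,
)
```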
store
store(self, accum: InlineArray[SIMD[out_type, 8], MmaOpApple[out_type, in_type, num_m_mmas, num_n_mmas, b_type=b_type, transpose_a=transpose_a, transpose_b=transpose_b].num_accum], d_tile: TileTensor[out_type, d_tile.LayoutType, d_tile.origin, address_space=d_tile.address_space, linear_idx_type=d_tile.linear_idx_type, element_size=d_tile.element_size])
Store all accumulators to output tile (unconditional).
Caller guarantees all elements are in-bounds.
store_bounded
store_bounded(self, accum: InlineArray[SIMD[out_type, 8], MmaOpApple[out_type, in_type, num_m_mmas, num_n_mmas, b_type=b_type, transpose_a=transpose_a, transpose_b=transpose_b].num_accum], d_tile: TileTensor[out_type, d_tile.LayoutType, d_tile.origin, address_space=d_tile.address_space, linear_idx_type=d_tile.linear_idx_type, element_size=d_tile.element_size], valid_rows: Int, valid_cols: Int)
Store accumulators with bounds checking.
Only writes elements where (row < valid_rows) and (col < valid_cols).
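Continuing the sketch above, the store choice mirrors the mma choice: one bounds check per simdgroup selects the unconditional or the masked path (rows_left and cols_left are illustrative names for the remaining valid extent of d_tile):

```mojo
if rows_left >= 2 * 16 and cols_left >= 2 * 16:
    # Full 32x32 block: unconditional store, no per-element checks.
    op.store(accum, d_tile)
else:
    # Partial block at the matrix edge: only writes elements where
    # row < rows_left and col < cols_left.
    op.store_bounded(accum, d_tile, rows_left, cols_left)
```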