
Mojo struct

TiledMma

struct TiledMma[out_type: DType, in_type: DType, shape: IndexList[3], group_size: Int]

Stateless MMA computation on TileTensors.

Direct TileTensor port of TiledTensorCore.mma. It iterates over group_size k-steps, indexes the A/B register tiles for each step, and calls gpu_mma. It owns no registers and performs no shared-memory (SMEM) loads; it is pure computation.
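The following plain-Python reference model makes the k-step iteration concrete. It is not the Mojo API, and `mma_instruction` and `tiled_mma_reference` are hypothetical names. The model works on whole logical tiles and ignores how fragments are distributed across the lanes of a warp. It shows that running group_size k-steps, each an MMA_M x MMA_N x MMA_K product accumulated into the same C tile, is equivalent to one matrix product over a reduction dimension of group_size * MMA_K.

```python
# Hypothetical reference model (plain Python, not the Mojo API): what one
# TiledMma.mma call computes for a single output tile, ignoring how the data
# is split across the lanes of a warp.

def mma_instruction(a, b, c):
    """One MMA step: c += a @ b, where a is M x K and b is K x N."""
    m, k = len(a), len(a[0])
    n = len(b[0])
    for i in range(m):
        for j in range(n):
            acc = c[i][j]
            for p in range(k):
                acc += a[i][p] * b[p][j]
            c[i][j] = acc


def tiled_mma_reference(a_steps, b_steps, c):
    """Apply group_size k-steps in order, accumulating every step into c."""
    assert len(a_steps) == len(b_steps)  # one A slice and one B slice per k-step
    for a_k, b_k in zip(a_steps, b_steps):
        mma_instruction(a_k, b_k, c)
    return c


# Example: shape [2, 2, 2] with group_size = 2.
a_steps = [
    [[1, 0], [0, 1]],  # k-step 0: A slice (2 x 2)
    [[1, 1], [1, 1]],  # k-step 1: A slice (2 x 2)
]
b_steps = [
    [[1, 2], [3, 4]],  # k-step 0: B slice (2 x 2)
    [[1, 0], [0, 1]],  # k-step 1: B slice (2 x 2)
]
result = tiled_mma_reference(a_steps, b_steps, [[0.0, 0.0], [0.0, 0.0]])
print(result)  # [[2.0, 3.0], [4.0, 5.0]]
```

The two accumulated steps produce the same result as a single 2 x 4 by 4 x 2 product over the concatenated K dimension, which is the property the group_size loop relies on.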

Parameters

  • out_type (DType): Accumulator data type (typically float32).
  • in_type (DType): Input element data type (bfloat16 or float8).
  • shape (IndexList): MMA instruction shape [M, N, K].
  • group_size (Int): Number of k-steps per mma() call.

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

a_frag_size

comptime a_frag_size = ((TiledMma[out_type, in_type, shape, group_size].MMA_M * TiledMma[out_type, in_type, shape, group_size].MMA_K) // WARP_SIZE)

c_frag_size

comptime c_frag_size = ((TiledMma[out_type, in_type, shape, group_size].MMA_M * TiledMma[out_type, in_type, shape, group_size].MMA_N) // WARP_SIZE)

MMA_K

comptime MMA_K = shape[2]

MMA_M

comptime MMA_M = shape[0]

MMA_N

comptime MMA_N = shape[1]
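The fragment-size members are integer arithmetic over the instruction shape. Each lane of the warp holds an equal share of the A tile (MMA_M * MMA_K elements) and of the C tile (MMA_M * MMA_N elements). The plain-Python helper below (`frag_sizes` is a hypothetical name, not part of the Mojo API) reproduces the two formulas so you can check per-lane register counts for a given shape. The shapes and warp sizes used are illustrative assumptions, not a list of instructions TiledMma supports.

```python
# Hypothetical helper mirroring the comptime members above in plain Python.
# WARP_SIZE is 32 on NVIDIA GPUs and 64 on AMD CDNA GPUs (wavefront size).

def frag_sizes(shape, warp_size):
    """Return (a_frag_size, c_frag_size) for an MMA shape [M, N, K]."""
    mma_m, mma_n, mma_k = shape  # MMA_M = shape[0], MMA_N = shape[1], MMA_K = shape[2]
    a_frag_size = (mma_m * mma_k) // warp_size  # A elements held per lane
    c_frag_size = (mma_m * mma_n) // warp_size  # accumulator elements per lane
    return a_frag_size, c_frag_size


print(frag_sizes([16, 16, 16], 64))  # (4, 4)
print(frag_sizes([32, 32, 8], 64))   # (4, 16)
```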

Methods

mma

static mma[a_layout: TensorLayout, b_layout: TensorLayout, c_layout: TensorLayout](a_reg: TileTensor[in_type, a_layout, a_reg.origin, address_space=AddressSpace.LOCAL], b_reg: TileTensor[in_type, b_layout, b_reg.origin, address_space=AddressSpace.LOCAL], c_reg: TileTensor[out_type, c_layout, MutExternalOrigin, address_space=AddressSpace.LOCAL])

Execute group_size k-steps of MMA operations along the K dimension, accumulating into c_reg.

Mirrors TiledTensorCore.mma: it iterates over group_size k-steps, tiles the A/B registers for each step via vectorize, and accumulates the results into C.

Parameters:

  • a_layout (TensorLayout): Inferred layout of A register tile.
  • b_layout (TensorLayout): Inferred layout of B register tile.
  • c_layout (TensorLayout): Inferred layout of C register tile.

Args:

  • a_reg (TileTensor): A fragments [num_m_mmas, group_size * a_frag_size].
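  • b_reg (TileTensor): B fragments [num_n_mmas, group_size * a_frag_size]. There is no separate b_frag_size member; a_frag_size equals the per-step B fragment size (MMA_N * MMA_K // WARP_SIZE) when MMA_M == MMA_N.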
  • c_reg (TileTensor): Accumulator [num_m_mmas, num_n_mmas * c_frag_size].
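The Args shapes imply how one call walks its register tiles. For each k-step, a contiguous slice of a_frag_size columns is taken from each row of a_reg and b_reg (b_reg also uses a_frag_size, per the list above). Each (m, n) pair accumulates into a c_frag_size slice of row m of c_reg. The hypothetical Python helper below enumerates those index ranges. The k-outer, then m, then n loop order is an assumption for illustration, and the Mojo implementation may issue the MMAs in a different order.

```python
# Hypothetical model of how one TiledMma.mma call walks its register tiles.
# Assumptions (not taken from the Mojo source): k-outer, then m, then n loop
# order; each k-step reads one contiguous a_frag_size column slice from a_reg
# and b_reg, and each n-tile owns one c_frag_size column slice of c_reg.

def mma_schedule(num_m_mmas, num_n_mmas, group_size, a_frag_size, c_frag_size):
    """Return (k, a_slice, b_slice, c_slice) tuples, one per MMA issued.

    Each slice is (row, (col_start, col_end)) with col_end exclusive.
    """
    steps = []
    for k in range(group_size):
        k_cols = (k * a_frag_size, (k + 1) * a_frag_size)
        for m in range(num_m_mmas):
            for n in range(num_n_mmas):
                c_cols = (n * c_frag_size, (n + 1) * c_frag_size)
                steps.append((k, (m, k_cols), (n, k_cols), (m, c_cols)))
    return steps


schedule = mma_schedule(num_m_mmas=2, num_n_mmas=2, group_size=2,
                        a_frag_size=4, c_frag_size=4)
print(len(schedule))  # 8
print(schedule[-1])   # (1, (1, (4, 8)), (1, (4, 8)), (1, (4, 8)))
```

Under these assumptions, one call issues group_size * num_m_mmas * num_n_mmas MMA instructions, and every k-step revisits the full set of c_reg slices.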
