Mojo struct
TiledMma
struct TiledMma[out_type: DType, in_type: DType, shape: IndexList[3], group_size: Int]
Stateless MMA computation on TileTensors.
Direct TileTensor port of TiledTensorCore.mma. Iterates group_size k-steps, indexes A/B register tiles per step, and calls gpu_mma. No register ownership and no SMEM loading: pure computation.
Parameters
- out_type (DType): Accumulator data type (typically float32).
- in_type (DType): Input element data type (bfloat16 or float8).
- shape (IndexList[3]): MMA instruction shape [M, N, K].
- group_size (Int): Number of k-steps per mma() call.
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
a_frag_size
comptime a_frag_size = ((TiledMma[out_type, in_type, shape, group_size].MMA_M * TiledMma[out_type, in_type, shape, group_size].MMA_K) // WARP_SIZE)
c_frag_size
comptime c_frag_size = ((TiledMma[out_type, in_type, shape, group_size].MMA_M * TiledMma[out_type, in_type, shape, group_size].MMA_N) // WARP_SIZE)
MMA_K
comptime MMA_K = shape[2]
MMA_M
comptime MMA_M = shape[0]
MMA_N
comptime MMA_N = shape[1]
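The fragment-size formulas above are plain integer arithmetic, so a small worked example may help. The sketch below is ordinary Python rather than Mojo, and it only mirrors the a_frag_size and c_frag_size expressions. The helper name frag_sizes, the example shapes, and the warp sizes (64 and 32) are assumptions chosen for illustration. The real WARP_SIZE depends on the target GPU.

```python
def frag_sizes(shape, warp_size):
    """Return (a_frag_size, c_frag_size) for an MMA shape [M, N, K].

    Mirrors the comptime expressions:
      a_frag_size = (MMA_M * MMA_K) // WARP_SIZE
      c_frag_size = (MMA_M * MMA_N) // WARP_SIZE
    """
    mma_m, mma_n, mma_k = shape
    a_frag_size = (mma_m * mma_k) // warp_size
    c_frag_size = (mma_m * mma_n) // warp_size
    return a_frag_size, c_frag_size


# Assumed example values; WARP_SIZE is target-dependent.
print(frag_sizes((16, 16, 16), 64))  # (4, 4)
print(frag_sizes((32, 32, 8), 64))   # (4, 16)
print(frag_sizes((16, 8, 16), 32))   # (8, 4)
```

Each result is the number of elements a single lane holds for one MMA instruction: the A tile (M x K) or the C tile (M x N) divided evenly across the lanes of a warp.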
Methods
mma
static mma[a_layout: TensorLayout, b_layout: TensorLayout, c_layout: TensorLayout](a_reg: TileTensor[in_type, a_layout, address_space=AddressSpace.LOCAL], b_reg: TileTensor[in_type, b_layout, address_space=AddressSpace.LOCAL], c_reg: TileTensor[out_type, c_layout, MutExternalOrigin, address_space=AddressSpace.LOCAL])
Execute group_size MMA operations across the K dimension.
Mirrors TiledTensorCore.mma: iterates group_size k-steps, tiles A/B registers per step via vectorize, and accumulates into C.
Parameters:
- a_layout (TensorLayout): Inferred layout of A register tile.
- b_layout (TensorLayout): Inferred layout of B register tile.
- c_layout (TensorLayout): Inferred layout of C register tile.
Args:
- a_reg (TileTensor[in_type, a_layout, address_space=AddressSpace.LOCAL]): A fragments [num_m_mmas, group_size * a_frag_size].
- b_reg (TileTensor[in_type, b_layout, address_space=AddressSpace.LOCAL]): B fragments [num_n_mmas, group_size * a_frag_size].
- c_reg (TileTensor[out_type, c_layout, MutExternalOrigin, address_space=AddressSpace.LOCAL]): Accumulator [num_m_mmas, num_n_mmas * c_frag_size].
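To make the register-tile shapes concrete, the following plain-Python model sketches how the k-step iteration could index into a_reg, b_reg, and c_reg. It is not the Mojo implementation. The function mma_model, the loop nesting order, and the mma_op callback (a placeholder for gpu_mma) are all assumptions for illustration, and the placeholder arithmetic does not reproduce real per-lane tensor-core semantics. Only the slicing pattern follows the documented shapes.

```python
def mma_model(a_reg, b_reg, c_reg, group_size, a_frag_size, c_frag_size, mma_op):
    """Index model of a grouped MMA over register tiles (hypothetical sketch).

    a_reg: num_m_mmas rows, each of length group_size * a_frag_size
    b_reg: num_n_mmas rows, each of length group_size * a_frag_size
    c_reg: num_m_mmas rows, each of length num_n_mmas * c_frag_size
    mma_op(a_frag, b_frag, c_frag) -> updated c_frag (stand-in for gpu_mma)
    """
    num_m_mmas = len(a_reg)
    num_n_mmas = len(b_reg)
    for k in range(group_size):
        # Each k-step selects one a_frag_size-wide slice of every A/B row.
        lo, hi = k * a_frag_size, (k + 1) * a_frag_size
        for m in range(num_m_mmas):
            a_frag = a_reg[m][lo:hi]
            for n in range(num_n_mmas):
                b_frag = b_reg[n][lo:hi]
                # The (m, n) accumulator fragment lives at column n * c_frag_size.
                c_lo = n * c_frag_size
                c_frag = c_reg[m][c_lo:c_lo + c_frag_size]
                c_reg[m][c_lo:c_lo + c_frag_size] = mma_op(a_frag, b_frag, c_frag)


def toy_op(a_frag, b_frag, c_frag):
    # Placeholder arithmetic: add the dot product of the fragments to each
    # accumulator element. Real MMA hardware semantics differ.
    dot = sum(x * y for x, y in zip(a_frag, b_frag))
    return [c + dot for c in c_frag]


a = [[1, 2, 3, 4], [5, 6, 7, 8]]   # num_m_mmas = 2, group_size * a_frag_size = 4
b = [[1, 1, 1, 1], [2, 2, 2, 2]]   # num_n_mmas = 2
c = [[0, 0], [0, 0]]               # num_n_mmas * c_frag_size = 2
mma_model(a, b, c, group_size=2, a_frag_size=2, c_frag_size=1, mma_op=toy_op)
print(c)  # [[10, 20], [26, 52]]
```

Because the accumulator is updated in place across all k-steps, a single call covers the whole K extent of the group, which matches the description of mma() as accumulating into C.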