Mojo struct
MmaOp
struct MmaOp[out_type: DType, in_type: DType, shape: IndexList[Int(3)], k_group_size: Int, num_k_tiles: Int, num_m_mmas: Int, num_n_mmas: Int, swizzle: Optional[Swizzle] = None]
Combines register ownership, SMEM-to-register loading, and the schedule-facing API for AMD matmul.
Owns A/B/C register tiles in LOCAL address space. Provides the schedule-facing API: load_frag[k] loads from SMEM to registers, mma[k] delegates to TiledMma for computation.
Parameters
- out_type (DType): Accumulator data type (typically float32).
- in_type (DType): Input element data type (bfloat16 or float8).
- shape (IndexList[Int(3)]): MMA instruction shape [M, N, K].
- k_group_size (Int): Number of MMA k-steps per fragment load.
- num_k_tiles (Int): Number of k-tiles across the warp K dimension.
- num_m_mmas (Int): MMA tiles along M within the warp tile.
- num_n_mmas (Int): MMA tiles along N within the warp tile.
- swizzle (Optional[Swizzle]): Optional SMEM swizzle for fragment loading.
Implemented traits
comptime members
c_frag_size
comptime c_frag_size = (Int((mul shape[Int(0)], shape[Int(1)])) // _resolve_warp_size())
k_tile_size
comptime k_tile_size = (shape[Int(2)] * k_group_size)
MMA_K
comptime MMA_K = shape[Int(2)]
MMA_M
comptime MMA_M = shape[Int(0)]
MMA_N
comptime MMA_N = shape[Int(1)]
simd_width
comptime simd_width = (k_group_size * (Int((mul shape[Int(0)], shape[Int(2)])) // _resolve_warp_size()))
WM
comptime WM = (num_m_mmas * shape[Int(0)])
WN
comptime WN = (num_n_mmas * shape[Int(1)])
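As a quick sanity check on the derived constants above, the arithmetic can be modeled with plain integers. This is a sketch in Python, not Mojo, and not part of the MmaOp API: the function name `derived_constants`, the example shape [16, 16, 16], and the warp size of 64 (standing in for `_resolve_warp_size()`) are illustrative assumptions.

```python
def derived_constants(shape, k_group_size, num_m_mmas, num_n_mmas, warp_size=64):
    """Mirror the MmaOp comptime member formulas with plain integers.

    warp_size stands in for _resolve_warp_size(); 64 is an assumed value.
    """
    mma_m, mma_n, mma_k = shape  # shape is [M, N, K]
    return {
        "MMA_M": mma_m,
        "MMA_N": mma_n,
        "MMA_K": mma_k,
        # Accumulator elements held per lane for one MMA tile.
        "c_frag_size": (mma_m * mma_n) // warp_size,
        # K extent covered by one fragment load.
        "k_tile_size": mma_k * k_group_size,
        # Input elements per lane for one fragment load.
        "simd_width": k_group_size * ((mma_m * mma_k) // warp_size),
        # Warp tile extents along M and N.
        "WM": num_m_mmas * mma_m,
        "WN": num_n_mmas * mma_n,
    }


# Hypothetical configuration chosen only to make the numbers easy to follow.
consts = derived_constants(
    shape=(16, 16, 16), k_group_size=2, num_m_mmas=4, num_n_mmas=2
)
for name, value in consts.items():
    print(name, "=", value)
```

Under these assumed values, each lane holds 4 accumulator elements per MMA tile and loads 8 input elements per fragment, and the warp tile is 64 x 32 with a k-tile of 32.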
Methods
__init__
def __init__(out self)
accum_tile
def accum_tile(self) -> ref[self._c_reg] TileTensor[out_type, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]
Returns:
ref[self._c_reg] TileTensor[out_type, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]: A reference to the accumulator (C) register tile.
load_frag
def load_frag[k_tile_idx: Int](self, a_smem_warp: TileTensor[in_type, address_space=AddressSpace.SHARED], b_smem_warp: TileTensor[in_type, address_space=AddressSpace.SHARED])
Load A and B MMA fragments for k-tile k_tile_idx from SMEM.
Expects block-local warp tiles of shape WM x k_tile_size for A (or WN x k_tile_size for B), where each k-tile block is contiguous in SMEM (blocked_product layout). Uses direct distribute with swizzle, which is correct because each block starts at a swizzle-aligned offset.
mma
def mma[k_tile_idx: Int](self)
Execute MMA for k-tile k_tile_idx via TiledMma.
Slices A/B registers for this k-tile and delegates to TiledMma.mma for stateless computation.
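The two methods imply one ordering constraint: `load_frag[k]` must run before `mma[k]`, because `mma` consumes the register slices that `load_frag` fills. The sketch below, in plain Python rather than Mojo, shows one way a schedule might interleave the calls. The helper names `prefetch_schedule` and `respects_dependencies` and the one-tile-ahead prefetch policy are illustrative assumptions, not the schedule the library actually uses.

```python
def prefetch_schedule(num_k_tiles):
    """Build an ordered list of (op, k_tile_idx) steps.

    Assumed policy: issue the load for tile k + 1 before the MMA for tile k,
    so the next fragment load can overlap with the current computation.
    """
    if num_k_tiles < 1:
        raise ValueError("num_k_tiles must be at least 1")
    steps = [("load_frag", 0)]  # prologue: first fragment must be resident
    for k in range(num_k_tiles):
        if k + 1 < num_k_tiles:
            steps.append(("load_frag", k + 1))  # prefetch the next k-tile
        steps.append(("mma", k))  # compute on the already-loaded k-tile
    return steps


def respects_dependencies(steps):
    """Check that every mma[k] is preceded by load_frag[k]."""
    loaded = set()
    for op, k in steps:
        if op == "load_frag":
            loaded.add(k)
        elif k not in loaded:
            return False
    return True


schedule = prefetch_schedule(3)
for op, k in schedule:
    print(f"{op}[{k}]")
```

Because `k_tile_idx` is a compile-time parameter in both signatures, each index in such a sequence has to be known at compile time.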