Mojo struct

QuadrantMmaOp

struct QuadrantMmaOp[out_type: DType, in_type: DType, shape: IndexList[3], k_group_size: Int, num_k_groups: Int, num_m_mmas: Int, num_n_mmas: Int, swizzle: Optional[Swizzle] = None]

MMA operator for AMD matmul ping-pong schedule.

Owns the A, B, and C register tiles in the LOCAL address space and provides quadrant-level load and compute methods for the ping-pong double-buffering schedule: load_a_quadrant and load_b_quadrant each fill half of the corresponding register tile from shared memory (SMEM) via load_lds_fragment, and mma_quadrant then computes on the loaded halves.
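The load-then-compute dependency described here can be modeled in a few lines of plain Python. This is only an illustrative sketch, not the Mojo API: the class QuadrantScheduleModel, the function run_one_k_tile, and the particular call order are hypothetical, and the sketch assumes that which, which_a, and which_b take the values 0 and 1 (suggested by quad_m = num_m_mmas // 2, but not stated on this page). The real kernel's interleaving may differ.

```python
class QuadrantScheduleModel:
    """Toy bookkeeping model (hypothetical): tracks which register quadrants hold data."""

    def __init__(self):
        self.a_loaded = set()  # A quadrants currently resident in registers
        self.b_loaded = set()  # B quadrants currently resident in registers
        self.computed = []     # (which_a, which_b) pairs, in issue order

    def load_a_quadrant(self, which):
        assert which in (0, 1)  # assumption: two halves per operand
        self.a_loaded.add(which)

    def load_b_quadrant(self, which):
        assert which in (0, 1)  # assumption: two halves per operand
        self.b_loaded.add(which)

    def mma_quadrant(self, which_a, which_b):
        # An MMA may only consume quadrants that were loaded beforehand.
        if which_a not in self.a_loaded or which_b not in self.b_loaded:
            raise RuntimeError("quadrant used before it was loaded")
        self.computed.append((which_a, which_b))


def run_one_k_tile():
    """One possible ordering that interleaves loads with compute."""
    op = QuadrantScheduleModel()
    op.load_a_quadrant(0)
    op.load_b_quadrant(0)
    op.mma_quadrant(0, 0)
    op.load_b_quadrant(1)  # could overlap with the MMA above on real hardware
    op.mma_quadrant(0, 1)
    op.load_a_quadrant(1)
    op.mma_quadrant(1, 0)
    op.mma_quadrant(1, 1)
    return op.computed
```

In this ordering every (which_a, which_b) pair is issued exactly once, and no mma_quadrant call runs before both of its input quadrants have been loaded.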

Parameters

  • out_type (DType): Accumulator data type (typically float32).
  • in_type (DType): Input element data type (bfloat16 or float8).
  • shape (IndexList[3]): MMA instruction shape [M, N, K].
  • k_group_size (Int): Number of MMA K-steps in one k-group, that is, per fragment load.
  • num_k_groups (Int): Number of k-groups across the full BK dimension.
  • num_m_mmas (Int): MMA tiles along M within the warp tile.
  • num_n_mmas (Int): MMA tiles along N within the warp tile.
  • swizzle (Optional[Swizzle]): Optional SMEM swizzle for load helpers.

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

accum_width

comptime accum_width = QuadrantMmaOp[out_type, in_type, shape, k_group_size, num_k_groups, num_m_mmas, num_n_mmas, swizzle].c_frag_size

BK

comptime BK = (QuadrantMmaOp[out_type, in_type, shape, k_group_size, num_k_groups, num_m_mmas, num_n_mmas, swizzle].num_k_mmas * QuadrantMmaOp[out_type, in_type, shape, k_group_size, num_k_groups, num_m_mmas, num_n_mmas, swizzle].MMA_K)

c_frag_size

comptime c_frag_size = num_matrix_reg[QuadrantMmaOp[out_type, in_type, shape, k_group_size, num_k_groups, num_m_mmas, num_n_mmas, swizzle].MMA_M, QuadrantMmaOp[out_type, in_type, shape, k_group_size, num_k_groups, num_m_mmas, num_n_mmas, swizzle].MMA_N]()

MMA_K

comptime MMA_K = shape[2]

MMA_M

comptime MMA_M = shape[0]

MMA_N

comptime MMA_N = shape[1]

num_k_mmas

comptime num_k_mmas = (num_k_groups * k_group_size)

quad_m

comptime quad_m = (num_m_mmas // 2)

quad_n

comptime quad_n = (num_n_mmas // 2)

quad_WM

comptime quad_WM = (QuadrantMmaOp[out_type, in_type, shape, k_group_size, num_k_groups, num_m_mmas, num_n_mmas, swizzle].quad_m * QuadrantMmaOp[out_type, in_type, shape, k_group_size, num_k_groups, num_m_mmas, num_n_mmas, swizzle].MMA_M)

quad_WN

comptime quad_WN = (QuadrantMmaOp[out_type, in_type, shape, k_group_size, num_k_groups, num_m_mmas, num_n_mmas, swizzle].quad_n * QuadrantMmaOp[out_type, in_type, shape, k_group_size, num_k_groups, num_m_mmas, num_n_mmas, swizzle].MMA_N)

WM

comptime WM = (num_m_mmas * QuadrantMmaOp[out_type, in_type, shape, k_group_size, num_k_groups, num_m_mmas, num_n_mmas, swizzle].MMA_M)

WN

comptime WN = (num_n_mmas * QuadrantMmaOp[out_type, in_type, shape, k_group_size, num_k_groups, num_m_mmas, num_n_mmas, swizzle].MMA_N)
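The derived members above are simple arithmetic on the struct parameters. As a rough illustration, the plain-Python function below mirrors those formulas; derived_constants is a hypothetical helper, not part of the Mojo API, and the example parameter values are made up for the calculation. c_frag_size and accum_width are left out because they depend on num_matrix_reg, whose definition is not given on this page.

```python
def derived_constants(shape, k_group_size, num_k_groups, num_m_mmas, num_n_mmas):
    """Mirror the comptime formulas listed above (illustrative model only)."""
    mma_m, mma_n, mma_k = shape              # MMA_M, MMA_N, MMA_K
    num_k_mmas = num_k_groups * k_group_size
    quad_m = num_m_mmas // 2                 # integer division, as in the listing
    quad_n = num_n_mmas // 2
    return {
        "MMA_M": mma_m,
        "MMA_N": mma_n,
        "MMA_K": mma_k,
        "num_k_mmas": num_k_mmas,
        "BK": num_k_mmas * mma_k,
        "quad_m": quad_m,
        "quad_n": quad_n,
        "quad_WM": quad_m * mma_m,
        "quad_WN": quad_n * mma_n,
        "WM": num_m_mmas * mma_m,
        "WN": num_n_mmas * mma_n,
    }


# Hypothetical parameter values, chosen only to exercise the formulas.
example = derived_constants(
    shape=(16, 16, 32),
    k_group_size=2,
    num_k_groups=2,
    num_m_mmas=8,
    num_n_mmas=4,
)
```

With these made-up values, example["BK"] is 4 * 32 = 128, the warp tile is WM x WN = 128 x 64, and each quadrant spans quad_WM x quad_WN = 64 x 32.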

Methods

__init__

def __init__(out self)

load_a_quadrant

def load_a_quadrant[which: Int](self, smem_tile: TileTensor[in_type, address_space=AddressSpace.SHARED])

Load the A quadrant selected by which from an SMEM sub-tile into registers.

Tiles a_reg as [quad_m, reg_cols](which, 0) to get the register sub-tile for this quadrant, then loads it from smem_tile via load_lds_fragment.
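To make the tiling expression concrete, the sketch below models a_reg as a list of rows and picks out the tile of shape [quad_m, reg_cols] at tile coordinate (which, 0). It assumes the usual tile-indexing convention, in which tile row index which starts at row which * quad_m; select_quadrant is a hypothetical helper, and unlike a register-tile view in Mojo it returns a copy.

```python
def select_quadrant(a_reg, which, quad_m, reg_cols):
    """Return the [quad_m, reg_cols] tile at tile coordinate (which, 0).

    a_reg is modeled as a list of rows. Assumed convention: tile (i, j)
    starts at row i * quad_m and column j * reg_cols; here j is always 0.
    """
    row_start = which * quad_m
    rows = a_reg[row_start:row_start + quad_m]
    return [row[:reg_cols] for row in rows]  # a copy, not a view


# A 4 x 3 stand-in for a_reg, with entries encoding (row, column) as row*10 + column.
a_reg_model = [[r * 10 + c for c in range(3)] for r in range(4)]
first_half = select_quadrant(a_reg_model, which=0, quad_m=2, reg_cols=3)
second_half = select_quadrant(a_reg_model, which=1, quad_m=2, reg_cols=3)
```

Here first_half holds rows 0 and 1 of the model tile and second_half holds rows 2 and 3, so the two values of which together cover the whole register tile.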

load_b_quadrant

def load_b_quadrant[which: Int](self, smem_tile: TileTensor[in_type, address_space=AddressSpace.SHARED])

Load the B quadrant selected by which from an SMEM sub-tile into registers.

mma_quadrant

def mma_quadrant[which_a: Int, which_b: Int](self)

Execute MMA for quadrant (which_a, which_b) via TiledMma.

Slices A/B/C register tiles to the quadrant and delegates to TiledMma for stateless computation.
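The quadrant decomposition can be checked numerically with a small plain-Python model. The sketch works at element granularity rather than in units of MMA tiles, stores B as a K x N matrix (the kernel's actual register layout may differ), and assumes that which_a and which_b each take the values 0 and 1. All function names here are hypothetical and stand in for the Mojo methods only loosely.

```python
def matmul_reference(a, b):
    """Plain M x K times K x N matrix product."""
    k = len(b)
    return [[sum(row[p] * b[p][j] for p in range(k)) for j in range(len(b[0]))]
            for row in a]


def accumulate_block(c, a, b, row0, rows, col0, cols):
    """c[row0:row0+rows][col0:col0+cols] += the matching block of a @ b."""
    k = len(b)
    for i in range(row0, row0 + rows):
        for j in range(col0, col0 + cols):
            c[i][j] += sum(a[i][p] * b[p][j] for p in range(k))


def mma_by_quadrants(a, b):
    """Compute a @ b as four quadrant updates of the accumulator c."""
    m, n = len(a), len(b[0])
    assert m % 2 == 0 and n % 2 == 0  # quadrants need even M and N
    half_m, half_n = m // 2, n // 2
    c = [[0] * n for _ in range(m)]   # zero-initialized accumulator
    for which_a in (0, 1):
        for which_b in (0, 1):
            # Quadrant (which_a, which_b) uses half of A's rows and half of B's columns.
            accumulate_block(c, a, b,
                             which_a * half_m, half_m,
                             which_b * half_n, half_n)
    return c


a_model = [[1, 2], [3, 4], [5, 6], [7, 8]]   # 4 x 2
b_model = [[1, 0, 2, 1], [0, 1, 1, 3]]       # 2 x 4
```

Each quadrant update writes a disjoint block of the accumulator, so issuing all four (which_a, which_b) pairs reproduces the full product regardless of the order in which they run.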

accum_tile

def accum_tile(self) -> TileTensor[out_type, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

Return the accumulator register tile.

Returns:

TileTensor[out_type, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]