Mojo struct
QuadrantMmaOp
struct QuadrantMmaOp[out_type: DType, in_type: DType, shape: IndexList[3], k_group_size: Int, num_k_groups: Int, num_m_mmas: Int, num_n_mmas: Int, swizzle: Optional[Swizzle] = None]
MMA operator for AMD matmul ping-pong schedule.
Owns the A/B/C register tiles in LOCAL address space and provides quadrant load/compute methods for the ping-pong double-buffering schedule: `load_a_quadrant`/`load_b_quadrant` fill half of a register tile from SMEM via `load_lds_fragment`, and `mma_quadrant` then computes on it.
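The quadrant split exists so that register loads for one half can overlap compute on the other half. As a rough illustration only (not the kernel's actual instruction order), a Python sketch of one such interleave, with hypothetical action names:

```python
# Hypothetical sketch of a quadrant ping-pong interleave: each A/B
# register half is loaded just before the first MMA that consumes it,
# so shared-memory loads overlap the previous quadrant's compute.
def quadrant_schedule():
    """Return (action, arg) steps for one K-iteration of the schedule."""
    steps = []
    steps.append(("load_a", 0))    # fill first half of A registers
    steps.append(("load_b", 0))    # fill first half of B registers
    steps.append(("mma", (0, 0)))  # compute on quadrant (0, 0)
    steps.append(("load_b", 1))    # load B half 1 while (0, 0) runs
    steps.append(("mma", (0, 1)))
    steps.append(("load_a", 1))    # load A half 1 while (0, 1) runs
    steps.append(("mma", (1, 0)))
    steps.append(("mma", (1, 1)))  # all register halves resident by now
    return steps
```

The invariant the real schedule must also satisfy is that every `mma` quadrant's operand halves were loaded before it executes.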
Parameters

- `out_type` (`DType`): Accumulator data type (typically `float32`).
- `in_type` (`DType`): Input element data type (`bfloat16` or `float8`).
- `shape` (`IndexList[3]`): MMA instruction shape `[M, N, K]`.
- `k_group_size` (`Int`): Number of MMA K-groups per fragment load.
- `num_k_groups` (`Int`): Number of K-groups across the full `BK` dimension.
- `num_m_mmas` (`Int`): MMA tiles along M within the warp tile.
- `num_n_mmas` (`Int`): MMA tiles along N within the warp tile.
- `swizzle` (`Optional[Swizzle]`): Optional SMEM swizzle for load helpers.
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
accum_width
comptime accum_width = Self.c_frag_size
BK
comptime BK = (Self.num_k_mmas * Self.MMA_K)
c_frag_size
comptime c_frag_size = num_matrix_reg[Self.MMA_M, Self.MMA_N]()
MMA_K
comptime MMA_K = shape[2]
MMA_M
comptime MMA_M = shape[0]
MMA_N
comptime MMA_N = shape[1]
num_k_mmas
comptime num_k_mmas = (num_k_groups * k_group_size)
quad_m
comptime quad_m = (num_m_mmas // 2)
quad_n
comptime quad_n = (num_n_mmas // 2)
quad_WM
comptime quad_WM = (Self.quad_m * Self.MMA_M)
quad_WN
comptime quad_WN = (Self.quad_n * Self.MMA_N)
WM
comptime WM = (num_m_mmas * Self.MMA_M)
WN
comptime WN = (num_n_mmas * Self.MMA_N)
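How these derived members relate can be checked with plain arithmetic. A Python sketch using a hypothetical parameterization (16x16x32 MMA shape, a 4x4 grid of MMA tiles, 2 K-groups of size 2 — example values, not defaults from the source):

```python
# Illustrative arithmetic for the derived comptime members above,
# mirroring their definitions one-to-one.
def derived_members(shape, k_group_size, num_k_groups, num_m_mmas, num_n_mmas):
    MMA_M, MMA_N, MMA_K = shape
    num_k_mmas = num_k_groups * k_group_size
    quad_m, quad_n = num_m_mmas // 2, num_n_mmas // 2
    return {
        "num_k_mmas": num_k_mmas,
        "BK": num_k_mmas * MMA_K,    # K extent covered by the SMEM tile
        "WM": num_m_mmas * MMA_M,    # warp tile height
        "WN": num_n_mmas * MMA_N,    # warp tile width
        "quad_WM": quad_m * MMA_M,   # quadrant height (half of WM)
        "quad_WN": quad_n * MMA_N,   # quadrant width (half of WN)
    }

# Example: shape=[16, 16, 32], k_group_size=2, num_k_groups=2,
# num_m_mmas=4, num_n_mmas=4
m = derived_members([16, 16, 32], 2, 2, 4, 4)
```

With these example values, each quadrant covers exactly half of the warp tile in each of M and N.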
Methods
__init__
__init__(out self)
load_a_quadrant
load_a_quadrant[which: Int](self, smem_tile: TileTensor[in_type, address_space=AddressSpace.SHARED])
Load A quadrant `which` from the SMEM sub-tile into registers.
Tiles `a_reg` as `[quad_m, reg_cols](which, 0)` to select the register sub-tile for this quadrant, then loads it via `load_lds_fragment`.
load_b_quadrant
load_b_quadrant[which: Int](self, smem_tile: TileTensor[in_type, address_space=AddressSpace.SHARED])
Load B quadrant `which` from the SMEM sub-tile into registers.
mma_quadrant
mma_quadrant[which_a: Int, which_b: Int](self)
Execute the MMA for quadrant (`which_a`, `which_b`) via `TiledMma`.
Slices the A/B/C register tiles to the quadrant and delegates to `TiledMma` for the stateless computation.
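The quadrant slicing amounts to halving the fragment grid along M and N and selecting one block by index. A Python model of that index arithmetic (names and the list-of-rows representation are hypothetical, not the TileTensor API):

```python
# Hypothetical model of quadrant slicing: the full C register tile holds
# num_m_mmas x num_n_mmas fragments, and quadrant (which_a, which_b)
# selects the (quad_m x quad_n) block at those half-grid coordinates.
def quadrant_slice(c_tile, which_a, which_b, quad_m, quad_n):
    """Return the C sub-tile (as a list of rows) for one quadrant."""
    rows = c_tile[which_a * quad_m:(which_a + 1) * quad_m]
    return [row[which_b * quad_n:(which_b + 1) * quad_n] for row in rows]
```

The four (`which_a`, `which_b`) combinations partition the tile: every fragment belongs to exactly one quadrant.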
accum_tile
accum_tile(self) -> TileTensor[out_type, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
Return the accumulator register tile.
Returns:
TileTensor[out_type, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]