IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.
Skip to main content
For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).

Mojo struct

AMDPingPongMatmul

struct AMDPingPongMatmul[a_type: DType, b_type: DType, c_type: DType, config: KernelConfig, /, enable_swizzle: Bool, elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None]

Structured ping-pong matmul for AMD MI355X.

8-warp double-buffered kernel with register-based DRAM→SMEM path.

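Parameters

  • a_type (DType): Element type of the `a` input tensor; also exposed as `in_type`.
  • b_type (DType): Element type of the `b` input tensor.
  • c_type (DType): Element type of the `c` output tensor; the accumulator type is derived from it via `get_accum_type[c_type]()`.
  • config (KernelConfig): Kernel configuration supplying `block_shape` (BM, BN, BK), `warp_shape` (WM, WN), and `mma_shape` (MMA_M, MMA_N, MMA_K).
  • enable_swizzle (Bool): When true, `byte_swizzle` and `mma_swizzle` hold a `Swizzle`; otherwise both are empty `Optional`s.
  • elementwise_lambda_fn (Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None]): Optional capturing function that receives a 2-D index and a SIMD value. Defaults to None. (Presumably applied to output elements; confirm against the kernel source.)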

Implemented traits​

AnyType, ImplicitlyDeletable

comptime members​

accum_dtype​

comptime accum_dtype = get_accum_type[c_type]()

accum_width​

comptime accum_width = ((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_N) // WARP_SIZE)

BK​

comptime BK = config.block_shape[2]

BM​

comptime BM = config.block_shape[0]

BN​

comptime BN = config.block_shape[1]

byte_swizzle​

comptime byte_swizzle = Optional(Swizzle((log2_floor((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K // 32)) + 1), log2_floor((16 if AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type.is_float8() and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M == 16) and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K == 128) else ((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K) // WARP_SIZE) * size_of[AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type]())) if AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type.is_float8() else (log2_floor(((4 * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].simd_width) // 2)) + log2_floor(size_of[AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type]())), 4)) if enable_swizzle else Optional()

c_frag_size​

comptime c_frag_size = ((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_N) // WARP_SIZE)

half_BM​

comptime half_BM = AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WM

half_BN​

comptime half_BN = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].BN // 2)

in_type​

comptime in_type = a_type

LGKM_PER_LOAD_A​

comptime LGKM_PER_LOAD_A = (((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].quadrant_m_mmas * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].num_k_mmas) * (((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K) // WARP_SIZE) // 16 if AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type.is_float8() and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M == 16) and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K == 128) else ((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K) // WARP_SIZE))) * ceildiv((16 if AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type.is_float8() and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M == 16) and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K == 128) else ((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K) // WARP_SIZE) * size_of[AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type]()), 16))

LGKM_PER_LOAD_B​

comptime LGKM_PER_LOAD_B = (((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].quadrant_n_mmas * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].num_k_mmas) * (((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K) // WARP_SIZE) // 16 if AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type.is_float8() and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M == 16) and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K == 128) else ((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K) // WARP_SIZE))) * ceildiv((16 if AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type.is_float8() and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M == 16) and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K == 128) else ((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K) // WARP_SIZE) * size_of[AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type]()), 16))

loads_per_row​

comptime loads_per_row = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].BK // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].simd_width)

MMA_K​

comptime MMA_K = config.mma_shape[2]

MMA_M​

comptime MMA_M = config.mma_shape[0]

MMA_N​

comptime MMA_N = config.mma_shape[1]

mma_swizzle​

comptime mma_swizzle = Optional(AMDPingPongMatmul.make_mma_swizzle()) if enable_swizzle else Optional()

mma_tile_m​

comptime mma_tile_m = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WM // 2)

mma_tile_n​

comptime mma_tile_n = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WN // 2)

num_k_mmas​

comptime num_k_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].BK // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K)

num_m_mmas​

comptime num_m_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WM // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M)

num_n_mmas​

comptime num_n_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WN // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_N)

num_warps_m​

comptime num_warps_m = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].BM // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WM)

num_warps_n​

comptime num_warps_n = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].BN // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WN)

quadrant_m_mmas​

comptime quadrant_m_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].num_m_mmas // 2)

quadrant_n_mmas​

comptime quadrant_n_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].num_n_mmas // 2)

rows_per_iter_8warp​

comptime rows_per_iter_8warp = ((8 * WARP_SIZE) // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].loads_per_row)

simd_width​

comptime simd_width = simd_width_of[AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type]()

total_warps​

comptime total_warps = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].num_warps_m * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].num_warps_n)

VMCNT_PER_LOAD_A​

comptime VMCNT_PER_LOAD_A = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].half_BM // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].rows_per_iter_8warp)

VMCNT_PER_LOAD_B​

comptime VMCNT_PER_LOAD_B = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].half_BN // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].rows_per_iter_8warp)

WM​

comptime WM = config.warp_shape[0]

WN​

comptime WN = config.warp_shape[1]

Methods​

make_mma_swizzle​

static def make_mma_swizzle() -> Swizzle

Consumer swizzle for MMA LDS reads (element-space).

The AMD MI355X has 64 LDS banks of 4 bytes each. Without swizzling, the MMA thread access pattern causes 4-way bank conflicts. The swizzle XORs high-order address bits into the bank-selection bits to distribute accesses across banks.

Swizzle parameters:

  • log_tile: Number of bits to XOR, scales with MMA_K.
  • base: Log2 of read granularity in bytes (lds_frag_width * elem_size).
  • shift: Fixed at 4 for AMD LDS bank geometry.

Configuration examples:

  • BF16 16x16x32: lds_frag=8, bytes=16 -> Swizzle(1, 4, 4)
  • FP8 16x16x128: lds_frag=16, bytes=16 -> Swizzle(3, 4, 4)
  • FP8 32x32x64: lds_frag=32, bytes=32 -> Swizzle(2, 5, 4)

Returns:

Swizzle: Swizzle pattern for bank-conflict-free LDS access.

validate_config​

static def validate_config()

run​

static def run[a_layout: TensorLayout, b_layout: TensorLayout, c_layout: TensorLayout](a: TileTensor[a_type, a_layout, ImmutAnyOrigin], b: TileTensor[b_type, b_layout, ImmutAnyOrigin], c: TileTensor[c_type, c_layout, MutAnyOrigin])

Structured ping-pong GEMM kernel entry point.