Mojo struct
AMDPingPongMatmul
struct AMDPingPongMatmul[a_type: DType, b_type: DType, c_type: DType, config: KernelConfig, /, enable_swizzle: Bool, elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None]
Structured ping-pong matmul for AMD MI355X.
8-warp double-buffered kernel with register-based DRAM→SMEM path.
Parameters
- a_type (DType): Input A element type.
- b_type (DType): Input B element type.
- c_type (DType): Output C element type.
- config (KernelConfig): KernelConfig with block/warp/mma shapes.
- enable_swizzle (Bool): Enable LDS bank conflict avoidance.
- elementwise_lambda_fn (Optional): Optional elementwise epilogue function applied to the output.
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
accum_dtype
comptime accum_dtype = get_accum_type[c_type]()
accum_width
comptime accum_width = ((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_N) // WARP_SIZE)
BK
comptime BK = config.block_shape[2]
BM
comptime BM = config.block_shape[0]
BN
comptime BN = config.block_shape[1]
byte_swizzle
comptime byte_swizzle = Optional(Swizzle((log2_floor((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K // 32)) + 1), log2_floor((16 if AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type.is_float8() and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M == 16) and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K == 128) else ((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K) // WARP_SIZE) * size_of[AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type]())) if AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type.is_float8() else (log2_floor(((4 * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].simd_width) // 2)) + log2_floor(size_of[AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type]())), 4)) if enable_swizzle else Optional()
c_frag_size
comptime c_frag_size = ((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_N) // WARP_SIZE)
half_BM
comptime half_BM = AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WM
half_BN
comptime half_BN = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].BN // 2)
in_type
comptime in_type = a_type
LGKM_PER_LOAD_A
comptime LGKM_PER_LOAD_A = (((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].quadrant_m_mmas * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].num_k_mmas) * (((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K) // WARP_SIZE) // 16 if AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type.is_float8() and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M == 16) and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K == 128) else ((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K) // WARP_SIZE))) * ceildiv((16 if AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type.is_float8() and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M == 16) and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K == 128) else ((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K) // WARP_SIZE) * size_of[AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type]()), 16))
LGKM_PER_LOAD_B
comptime LGKM_PER_LOAD_B = (((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].quadrant_n_mmas * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].num_k_mmas) * (((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K) // WARP_SIZE) // 16 if AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type.is_float8() and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M == 16) and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K == 128) else ((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K) // WARP_SIZE))) * ceildiv((16 if AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type.is_float8() and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M == 16) and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K == 128) else ((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K) // WARP_SIZE) * size_of[AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type]()), 16))
loads_per_row
comptime loads_per_row = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].BK // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].simd_width)
MMA_K
comptime MMA_K = config.mma_shape[2]
MMA_M
comptime MMA_M = config.mma_shape[0]
MMA_N
comptime MMA_N = config.mma_shape[1]
mma_swizzle
comptime mma_swizzle = Optional(AMDPingPongMatmul.make_mma_swizzle()) if enable_swizzle else Optional()
mma_tile_m
comptime mma_tile_m = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WM // 2)
mma_tile_n
comptime mma_tile_n = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WN // 2)
num_k_mmas
comptime num_k_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].BK // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K)
num_m_mmas
comptime num_m_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WM // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M)
num_n_mmas
comptime num_n_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WN // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_N)
num_warps_m
comptime num_warps_m = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].BM // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WM)
num_warps_n
comptime num_warps_n = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].BN // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WN)
quadrant_m_mmas
comptime quadrant_m_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].num_m_mmas // 2)
quadrant_n_mmas
comptime quadrant_n_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].num_n_mmas // 2)
rows_per_iter_8warp
comptime rows_per_iter_8warp = ((8 * WARP_SIZE) // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].loads_per_row)
simd_width
comptime simd_width = simd_width_of[AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type]()
total_warps
comptime total_warps = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].num_warps_m * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].num_warps_n)
VMCNT_PER_LOAD_A
comptime VMCNT_PER_LOAD_A = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].half_BM // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].rows_per_iter_8warp)
VMCNT_PER_LOAD_B
comptime VMCNT_PER_LOAD_B = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].half_BN // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].rows_per_iter_8warp)
WM
comptime WM = config.warp_shape[0]
WN
comptime WN = config.warp_shape[1]
Methods
make_mma_swizzle
static make_mma_swizzle() -> Swizzle
Consumer swizzle for MMA LDS reads (element-space).
The AMD MI355X has 64 LDS banks × 4 bytes each. Without swizzling, the MMA thread access pattern causes 4-way bank conflicts. The swizzle XORs high-order address bits into the bank selection bits to distribute accesses across banks.
Swizzle parameters:
- log_tile: Number of bits to XOR, scales with MMA_K.
- base: Log2 of read granularity in bytes (lds_frag_width * elem_size).
- shift: Fixed at 4 for AMD LDS bank geometry.
Configuration examples:
- BF16 16x16x32: lds_frag=8, bytes=16 -> Swizzle(1, 4, 4)
- FP8 16x16x128: lds_frag=16, bytes=16 -> Swizzle(3, 4, 4)
- FP8 32x32x64: lds_frag=32, bytes=32 -> Swizzle(2, 5, 4)
Returns:
Swizzle: Swizzle pattern for bank-conflict-free LDS access.
validate_config
static validate_config()
run
static run[a_layout: TensorLayout, b_layout: TensorLayout, c_layout: TensorLayout](a: TileTensor[a_type, a_layout, ImmutAnyOrigin], b: TileTensor[b_type, b_layout, ImmutAnyOrigin], c: TileTensor[c_type, c_layout, MutAnyOrigin])
Structured ping-pong GEMM kernel entry point.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!