Skip to main content

Mojo struct

AMDPingPongMatmul

struct AMDPingPongMatmul[a_type: DType, b_type: DType, c_type: DType, config: KernelConfig, /, enable_swizzle: Bool, elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None]

Structured ping-pong matmul for AMD MI355X.

8-warp double-buffered kernel with register-based DRAM→SMEM path.

Parameters

  • a_type (DType): Input A element type.
  • b_type (DType): Input B element type.
  • c_type (DType): Output C element type.
  • config (KernelConfig): KernelConfig with block/warp/mma shapes.
  • enable_swizzle (Bool): Enable LDS bank conflict avoidance.
  • elementwise_lambda_fn (Optional): Optional elementwise epilogue function applied to output fragments before they are written to C.

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

accum_dtype

comptime accum_dtype = get_accum_type[c_type]()

accum_width

comptime accum_width = ((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_N) // WARP_SIZE)

BK

comptime BK = config.block_shape[2]

BM

comptime BM = config.block_shape[0]

BN

comptime BN = config.block_shape[1]

byte_swizzle

comptime byte_swizzle = Optional(Swizzle((log2_floor((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K // 32)) + 1), log2_floor((16 if AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type.is_float8() and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M == 16) and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K == 128) else ((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K) // WARP_SIZE) * size_of[AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type]())) if AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type.is_float8() else (log2_floor(((4 * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].simd_width) // 2)) + log2_floor(size_of[AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type]())), 4)) if enable_swizzle else Optional()

c_frag_size

comptime c_frag_size = ((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_N) // WARP_SIZE)

half_BM

comptime half_BM = AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WM

half_BN

comptime half_BN = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].BN // 2)

in_type

comptime in_type = a_type

LGKM_PER_LOAD_A

comptime LGKM_PER_LOAD_A = (((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].quadrant_m_mmas * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].num_k_mmas) * (((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K) // WARP_SIZE) // 16 if AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type.is_float8() and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M == 16) and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K == 128) else ((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K) // WARP_SIZE))) * ceildiv((16 if AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type.is_float8() and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M == 16) and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K == 128) else ((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K) // WARP_SIZE) * size_of[AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type]()), 16))

LGKM_PER_LOAD_B

comptime LGKM_PER_LOAD_B = (((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].quadrant_n_mmas * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].num_k_mmas) * (((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K) // WARP_SIZE) // 16 if AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type.is_float8() and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M == 16) and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K == 128) else ((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K) // WARP_SIZE))) * ceildiv((16 if AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type.is_float8() and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M == 16) and (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K == 128) else ((AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K) // WARP_SIZE) * size_of[AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type]()), 16))

loads_per_row

comptime loads_per_row = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].BK // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].simd_width)

MMA_K

comptime MMA_K = config.mma_shape[2]

MMA_M

comptime MMA_M = config.mma_shape[0]

MMA_N

comptime MMA_N = config.mma_shape[1]

mma_swizzle

comptime mma_swizzle = Optional(AMDPingPongMatmul.make_mma_swizzle()) if enable_swizzle else Optional()

mma_tile_m

comptime mma_tile_m = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WM // 2)

mma_tile_n

comptime mma_tile_n = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WN // 2)

num_k_mmas

comptime num_k_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].BK // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K)

num_m_mmas

comptime num_m_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WM // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M)

num_n_mmas

comptime num_n_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WN // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_N)

num_warps_m

comptime num_warps_m = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].BM // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WM)

num_warps_n

comptime num_warps_n = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].BN // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WN)

quadrant_m_mmas

comptime quadrant_m_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].num_m_mmas // 2)

quadrant_n_mmas

comptime quadrant_n_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].num_n_mmas // 2)

rows_per_iter_8warp

comptime rows_per_iter_8warp = ((8 * WARP_SIZE) // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].loads_per_row)

simd_width

comptime simd_width = simd_width_of[AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type]()

total_warps

comptime total_warps = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].num_warps_m * AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].num_warps_n)

VMCNT_PER_LOAD_A

comptime VMCNT_PER_LOAD_A = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].half_BM // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].rows_per_iter_8warp)

VMCNT_PER_LOAD_B

comptime VMCNT_PER_LOAD_B = (AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].half_BN // AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].rows_per_iter_8warp)

WM

comptime WM = config.warp_shape[0]

WN

comptime WN = config.warp_shape[1]

Methods

make_mma_swizzle

static make_mma_swizzle() -> Swizzle

Consumer swizzle for MMA LDS reads (element-space).

The AMD MI355X has 64 LDS banks of 4 bytes each. Without swizzling, the MMA thread access pattern causes 4-way bank conflicts. The swizzle XORs high-order address bits into the bank-selection bits to distribute accesses across banks.

Swizzle parameters:

  • log_tile: Number of bits to XOR, scales with MMA_K.
  • base: Log2 of read granularity in bytes (lds_frag_width * elem_size).
  • shift: Fixed at 4 for AMD LDS bank geometry.

Configuration examples:

  • BF16 16x16x32: lds_frag=8, bytes=16 -> Swizzle(1, 4, 4)
  • FP8 16x16x128: lds_frag=16, bytes=16 -> Swizzle(3, 4, 4)
  • FP8 32x32x64: lds_frag=32, bytes=32 -> Swizzle(2, 5, 4)

Returns:

Swizzle: Swizzle pattern for bank-conflict-free LDS access.

validate_config

static validate_config()

run

static run[a_layout: TensorLayout, b_layout: TensorLayout, c_layout: TensorLayout](a: TileTensor[a_type, a_layout, ImmutAnyOrigin], b: TileTensor[b_type, b_layout, ImmutAnyOrigin], c: TileTensor[c_type, c_layout, MutAnyOrigin])

Structured ping-pong GEMM kernel entry point.

Was this page helpful?