For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).
Mojo struct
AMDPingPongMatmul
struct AMDPingPongMatmul[a_type: DType, b_type: DType, c_type: DType, config: KernelConfig, /, enable_swizzle: Bool, elementwise_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> None] = None]
Structured ping-pong matmul for AMD MI355X.
8-warp double-buffered kernel with register-based DRAM→SMEM path.
Parameters
- a_type (DType): Input A element type.
- b_type (DType): Input B element type.
- c_type (DType): Output C element type.
- config (KernelConfig): KernelConfig with block/warp/mma shapes.
- enable_swizzle (Bool): Enable LDS bank conflict avoidance.
- elementwise_lambda_fn (Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> None]): Optional epilogue.
Implemented traits
comptime members
accum_dtype
comptime accum_dtype = get_accum_type[c_type]()
accum_width
comptime accum_width = (Int((mul config.mma_shape[Int(0)], config.mma_shape[Int(1)])) // _resolve_warp_size())
BK
comptime BK = config.block_shape[Int(2)]
BM
comptime BM = config.block_shape[Int(0)]
BN
comptime BN = config.block_shape[Int(1)]
byte_swizzle
comptime byte_swizzle = Optional(Swizzle(Int((add log2_floor((config.mma_shape[Int(2)] // Int(32))), 1)), log2_floor(Int((mul size_of[a_type](), Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 78) and (eq config.mma_shape[Int(0)], 16) and (eq config.mma_shape[Int(2)], 128) else (Int((mul config.mma_shape[Int(0)], config.mma_shape[Int(2)])) // _resolve_warp_size())))) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 78) 
else Int((add log2_floor((Int((mul simd_width_of[a_type](), 4)) // Int(2))), log2_floor(size_of[a_type]()))), Int(4))) if enable_swizzle else Optional()
c_frag_size
comptime c_frag_size = (Int((mul config.mma_shape[Int(0)], config.mma_shape[Int(1)])) // _resolve_warp_size())
half_BM
comptime half_BM = AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WM
half_BN
comptime half_BN = (config.block_shape[Int(1)] // Int(2))
in_type
comptime in_type = a_type
LGKM_PER_LOAD_A
comptime LGKM_PER_LOAD_A = (Int((mul (config.block_shape[Int(2)] // config.mma_shape[Int(2)]), ((config.warp_shape[Int(0)] // config.mma_shape[Int(0)]) // Int(2)), ((Int((mul config.mma_shape[Int(0)], config.mma_shape[Int(2)])) // _resolve_warp_size()) // Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 78) and (eq config.mma_shape[Int(0)], 16) and (eq config.mma_shape[Int(2)], 128) else (Int((mul config.mma_shape[Int(0)], config.mma_shape[Int(2)])) // _resolve_warp_size())))) * ceildiv(Int((mul size_of[a_type](), Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 76) or (eq 
#pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 78) and (eq config.mma_shape[Int(0)], 16) and (eq config.mma_shape[Int(2)], 128) else (Int((mul config.mma_shape[Int(0)], config.mma_shape[Int(2)])) // _resolve_warp_size()))), Int(16)))
LGKM_PER_LOAD_B
comptime LGKM_PER_LOAD_B = (Int((mul (config.block_shape[Int(2)] // config.mma_shape[Int(2)]), ((config.warp_shape[Int(1)] // config.mma_shape[Int(1)]) // Int(2)), ((Int((mul config.mma_shape[Int(0)], config.mma_shape[Int(2)])) // _resolve_warp_size()) // Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 78) and (eq config.mma_shape[Int(0)], 16) and (eq config.mma_shape[Int(2)], 128) else (Int((mul config.mma_shape[Int(0)], config.mma_shape[Int(2)])) // _resolve_warp_size())))) * ceildiv(Int((mul size_of[a_type](), Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 76) or (eq 
#pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 78) and (eq config.mma_shape[Int(0)], 16) and (eq config.mma_shape[Int(2)], 128) else (Int((mul config.mma_shape[Int(0)], config.mma_shape[Int(2)])) // _resolve_warp_size()))), Int(16)))
loads_per_row
comptime loads_per_row = (config.block_shape[Int(2)] // simd_width_of[a_type]())
MMA_K
comptime MMA_K = config.mma_shape[Int(2)]
MMA_M
comptime MMA_M = config.mma_shape[Int(0)]
MMA_N
comptime MMA_N = config.mma_shape[Int(1)]
mma_swizzle
comptime mma_swizzle = Optional(AMDPingPongMatmul.make_mma_swizzle()) if enable_swizzle else Optional()
mma_tile_m
comptime mma_tile_m = (config.warp_shape[Int(0)] // Int(2))
mma_tile_n
comptime mma_tile_n = (config.warp_shape[Int(1)] // Int(2))
num_k_mmas
comptime num_k_mmas = (config.block_shape[Int(2)] // config.mma_shape[Int(2)])
num_m_mmas
comptime num_m_mmas = (config.warp_shape[Int(0)] // config.mma_shape[Int(0)])
num_n_mmas
comptime num_n_mmas = (config.warp_shape[Int(1)] // config.mma_shape[Int(1)])
num_warps_m
comptime num_warps_m = (config.block_shape[Int(0)] // config.warp_shape[Int(0)])
num_warps_n
comptime num_warps_n = (config.block_shape[Int(1)] // config.warp_shape[Int(1)])
quadrant_m_mmas
comptime quadrant_m_mmas = ((config.warp_shape[Int(0)] // config.mma_shape[Int(0)]) // Int(2))
quadrant_n_mmas
comptime quadrant_n_mmas = ((config.warp_shape[Int(1)] // config.mma_shape[Int(1)]) // Int(2))
rows_per_iter_8warp
comptime rows_per_iter_8warp = (Int((mul _resolve_warp_size(), 8)) // (config.block_shape[Int(2)] // simd_width_of[a_type]()))
simd_width
comptime simd_width = simd_width_of[AMDPingPongMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type]()
total_warps
comptime total_warps = ((config.block_shape[Int(0)] // config.warp_shape[Int(0)]) * (config.block_shape[Int(1)] // config.warp_shape[Int(1)]))
VMCNT_PER_LOAD_A
comptime VMCNT_PER_LOAD_A = (config.warp_shape[Int(0)] // (Int((mul _resolve_warp_size(), 8)) // (config.block_shape[Int(2)] // simd_width_of[a_type]())))
VMCNT_PER_LOAD_B
comptime VMCNT_PER_LOAD_B = ((config.block_shape[Int(1)] // Int(2)) // (Int((mul _resolve_warp_size(), 8)) // (config.block_shape[Int(2)] // simd_width_of[a_type]())))
WM
comptime WM = config.warp_shape[Int(0)]
WN
comptime WN = config.warp_shape[Int(1)]
Methods
make_mma_swizzle
static def make_mma_swizzle() -> Swizzle
Consumer swizzle for MMA LDS reads (element-space).
The AMD MI355X has 64 LDS banks of 4 bytes each. Without swizzling, the MMA thread access pattern causes 4-way bank conflicts. The swizzle XORs high-order address bits into the bank selection bits to distribute accesses across banks.
Swizzle parameters:
- log_tile: Number of bits to XOR, scales with MMA_K.
- base: Log2 of read granularity in bytes (lds_frag_width * elem_size).
- shift: Fixed at 4 for AMD LDS bank geometry.
Configuration examples:
- BF16 16x16x32: lds_frag=8, bytes=16 -> Swizzle(1, 4, 4)
- FP8 16x16x128: lds_frag=16, bytes=16 -> Swizzle(3, 4, 4)
- FP8 32x32x64: lds_frag=32, bytes=32 -> Swizzle(2, 5, 4)
Returns:
Swizzle: Swizzle pattern for bank-conflict-free LDS access.
validate_config
static def validate_config()
run
static def run[a_layout: TensorLayout, b_layout: TensorLayout, c_layout: TensorLayout](a: TileTensor[a_type, a_layout, ImmutAnyOrigin], b: TileTensor[b_type, b_layout, ImmutAnyOrigin], c: TileTensor[c_type, c_layout, MutAnyOrigin])
Structured ping-pong GEMM kernel entry point.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!