Mojo struct
AMDPingPongMatmul
struct AMDPingPongMatmul[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, config: KernelConfig, /, enable_swizzle: Bool]
An 8-warp ping-pong matrix-multiplication (matmul) kernel for the AMD MI355X GPU.
The eight warps are split into two groups of four. The groups alternate between load and compute phases, so one group's memory transfers overlap with the other group's computation. LDS (local data share) is double-buffered, and accesses use a swizzled pattern to avoid bank conflicts.
Key features:
- load_to_lds for direct DRAM→LDS transfer (bypasses L1/L2)
- Swizzle pattern for bank-conflict-free LDS access
- Fine-grained `s_waitcnt` lgkmcnt/vmcnt waits to maximize overlap between memory operations and compute
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
__del__is_trivial
comptime __del__is_trivial = True
accum_dtype
comptime accum_dtype = get_accum_type[c_type]()
accum_width
comptime accum_width = ((AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].MMA_N) // WARP_SIZE)
BK
comptime BK = config.block_shape.__getitem__[3, DType.int64, Int](2)
BM
comptime BM = config.block_shape.__getitem__[3, DType.int64, Int](0)
BN
comptime BN = config.block_shape.__getitem__[3, DType.int64, Int](1)
LGKM_PER_LOAD_A
comptime LGKM_PER_LOAD_A = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].quadrant_m_mmas * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].num_k_mmas)
LGKM_PER_LOAD_AB
comptime LGKM_PER_LOAD_AB = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].LGKM_PER_LOAD_A + AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].LGKM_PER_LOAD_B)
LGKM_PER_LOAD_B
comptime LGKM_PER_LOAD_B = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].quadrant_n_mmas * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].num_k_mmas)
load_width
comptime load_width = simd_width_of[a_type]()
loading_threads_4warp
comptime loading_threads_4warp = (4 * WARP_SIZE)
loading_threads_8warp
comptime loading_threads_8warp = (8 * WARP_SIZE)
loads_per_row
comptime loads_per_row = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].BK // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].load_width)
MMA_K
comptime MMA_K = config.mma_shape.__getitem__[3, DType.int64, Int](2)
MMA_M
comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)
MMA_N
comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)
num_accums
comptime num_accums = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].num_m_mmas * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].num_n_mmas)
num_k_mmas
comptime num_k_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].WK // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].MMA_K)
num_m_mmas
comptime num_m_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].WM // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].MMA_M)
num_n_mmas
comptime num_n_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].WN // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].MMA_N)
num_warps_m
comptime num_warps_m = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].BM // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].WM)
num_warps_n
comptime num_warps_n = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].BN // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].WN)
ping_pong_stages
comptime ping_pong_stages = 2
quadrant_m_mmas
comptime quadrant_m_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].num_m_mmas // 2)
quadrant_n_mmas
comptime quadrant_n_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].num_n_mmas // 2)
rows_per_iter_4warp
comptime rows_per_iter_4warp = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].loading_threads_4warp // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].loads_per_row)
rows_per_iter_8warp
comptime rows_per_iter_8warp = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].loading_threads_8warp // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].loads_per_row)
total_smem_a
comptime total_smem_a = ((2 * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].BM) * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].BK)
total_smem_b
comptime total_smem_b = ((2 * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].BN) * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].BK)
total_warps
comptime total_warps = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].num_warps_m * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].num_warps_n)
VMCNT_PER_LOAD_A
comptime VMCNT_PER_LOAD_A = ((AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].BM // 2) // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].rows_per_iter_8warp)
VMCNT_PER_LOAD_A_4WARP
comptime VMCNT_PER_LOAD_A_4WARP = ((AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].BM // 2) // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].rows_per_iter_4warp)
VMCNT_PER_LOAD_B
comptime VMCNT_PER_LOAD_B = ((AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].BN // 2) // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].rows_per_iter_8warp)
VMCNT_PER_LOAD_B_4WARP
comptime VMCNT_PER_LOAD_B_4WARP = ((AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].BN // 2) // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_swizzle].rows_per_iter_4warp)
WK
comptime WK = config.warp_shape.__getitem__[3, DType.int64, Int](2)
WM
comptime WM = config.warp_shape.__getitem__[3, DType.int64, Int](0)
WN
comptime WN = config.warp_shape.__getitem__[3, DType.int64, Int](1)
Methods
validate_config
static validate_config()
Validate the kernel configuration.
matmul_ping_pong
static matmul_ping_pong(a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], c: LayoutTensor[c_type, c_layout, MutAnyOrigin])
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!