
Mojo struct

AMDPingPongMatmul

struct AMDPingPongMatmul[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, config: KernelConfig, /, enable_swizzle: Bool, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None]

8-warp ping-pong matmul for AMD MI355X.

Warps are split into two groups of four that alternate between load and compute phases, overlapping memory traffic with MFMA execution. The kernel uses double-buffered LDS with swizzled access patterns to avoid bank conflicts.
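To make the schedule concrete, here is a minimal host-side sketch of the alternation idea (illustrative only: the group roles, phase ordering, and loop trip count are assumptions, not the kernel's actual control flow):

```mojo
# CPU-side model of the 8-warp ping-pong schedule: two groups of 4 warps
# swap roles each K iteration, so one group's MFMA work hides the other
# group's DRAM->LDS traffic into the opposite double-buffer stage.
fn main():
    comptime ping_pong_stages = 2  # matches the struct's constant
    var num_k_iters = 4  # hypothetical K-loop trip count

    for k_iter in range(num_k_iters):
        var buf = k_iter % ping_pong_stages  # which LDS stage is "ready"
        for group in range(2):
            if (group + k_iter) % 2 == 0:
                print("iter", k_iter, "group", group, "computes from LDS stage", buf)
            else:
                print("iter", k_iter, "group", group, "loads into LDS stage", 1 - buf)
```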

Key features:

  • load_to_lds for direct DRAM→LDS transfer (bypasses L1/L2)
  • Swizzle pattern for bank-conflict-free LDS access (see the sketch after this list)
  • Fine-grained lgkmcnt/vmcnt waits for maximum overlap
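The swizzle bullet refers to the general XOR-permutation technique: each row of a tile permutes its column order so that lanes reading straight down a column hit distinct LDS banks. A minimal sketch of that technique follows; the function name and exact pattern are illustrative assumptions, and the kernel's actual pattern is selected by enable_swizzle:

```mojo
# Generic XOR swizzle: permute the vector-column index within each row so
# that column-wise reads map to different LDS banks on each row. This
# illustrates the technique only, not the kernel's exact pattern.
fn swizzled_offset(row: Int, col: Int, cols_per_row: Int) -> Int:
    return row * cols_per_row + (col ^ (row % cols_per_row))

fn main():
    comptime cols_per_row = 8  # hypothetical BK // load_width
    for row in range(4):
        for col in range(cols_per_row):
            print(row, col, "->", swizzled_offset(row, col, cols_per_row))
```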

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

accum_dtype

comptime accum_dtype = get_accum_type[c_type]()

accum_width

comptime accum_width = (Self.MMA_M * Self.MMA_N) // WARP_SIZE

BK

comptime BK = config.block_shape[2]

BM

comptime BM = config.block_shape[0]

BN

comptime BN = config.block_shape[1]

half_BM

comptime half_BM = config.warp_shape[0]

LGKM_PER_LOAD_A

comptime LGKM_PER_LOAD_A = (Self.quadrant_m_mmas * Self.num_k_mmas * (((Self.MMA_M * Self.MMA_K) // WARP_SIZE) // 16 if a_type.is_float8() and (Self.MMA_M == 16) and (Self.MMA_K == 128) else (Self.MMA_M * Self.MMA_K) // WARP_SIZE)) * ceildiv(16 if a_type.is_float8() and (Self.MMA_M == 16) and (Self.MMA_K == 128) else ((Self.MMA_M * Self.MMA_K) // WARP_SIZE) * size_of[a_type](), 16)

LGKM_PER_LOAD_AB

comptime LGKM_PER_LOAD_AB = Self.LGKM_PER_LOAD_A + Self.LGKM_PER_LOAD_B

LGKM_PER_LOAD_B

comptime LGKM_PER_LOAD_B = (Self.quadrant_n_mmas * Self.num_k_mmas * (((Self.MMA_M * Self.MMA_K) // WARP_SIZE) // 16 if a_type.is_float8() and (Self.MMA_M == 16) and (Self.MMA_K == 128) else (Self.MMA_M * Self.MMA_K) // WARP_SIZE)) * ceildiv(16 if a_type.is_float8() and (Self.MMA_M == 16) and (Self.MMA_K == 128) else ((Self.MMA_M * Self.MMA_K) // WARP_SIZE) * size_of[a_type](), 16)
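These constants count how many LDS-counted operations one A or B tile load contributes, so the kernel can place fine-grained s_waitcnt lgkmcnt(...) waits. A worked example with a hypothetical bfloat16 configuration (every tile and MMA size below is an assumption chosen for illustration, not a shipped config):

```mojo
from math import ceildiv
from sys import size_of

# Worked LGKM accounting for a hypothetical bfloat16 configuration.
fn main():
    comptime WARP_SIZE = 64  # AMD wavefront size
    comptime MMA_M = 16
    comptime MMA_K = 32
    comptime quadrant_m_mmas = 2
    comptime quadrant_n_mmas = 2
    comptime num_k_mmas = 2

    # bfloat16 is not float8, so the fp8 16x128 special case is skipped.
    var frag = (MMA_M * MMA_K) // WARP_SIZE  # 8 elements per lane per MMA
    var loads_per_frag = ceildiv(frag * size_of[DType.bfloat16](), 16)  # 1

    var lgkm_a = quadrant_m_mmas * num_k_mmas * frag * loads_per_frag  # 32
    var lgkm_b = quadrant_n_mmas * num_k_mmas * frag * loads_per_frag  # 32
    print("LGKM_PER_LOAD_A  =", lgkm_a)
    print("LGKM_PER_LOAD_B  =", lgkm_b)
    print("LGKM_PER_LOAD_AB =", lgkm_a + lgkm_b)
```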

load_width

comptime load_width = simd_width_of[a_type]()

loading_threads_8warp

comptime loading_threads_8warp = (8 * WARP_SIZE)

loads_per_row

comptime loads_per_row = Self.BK // Self.load_width

MMA_K

comptime MMA_K = config.mma_shape[2]

MMA_M

comptime MMA_M = config.mma_shape[0]

MMA_N

comptime MMA_N = config.mma_shape[1]

num_accums

comptime num_accums = Self.num_m_mmas * Self.num_n_mmas

num_k_mmas

comptime num_k_mmas = Self.WK // Self.MMA_K

num_m_mmas

comptime num_m_mmas = Self.WM // Self.MMA_M

num_n_mmas

comptime num_n_mmas = Self.WN // Self.MMA_N

num_warps_m

comptime num_warps_m = Self.BM // Self.WM

num_warps_n

comptime num_warps_n = Self.BN // Self.WN

ping_pong_stages

comptime ping_pong_stages = 2

quadrant_m_mmas

comptime quadrant_m_mmas = Self.num_m_mmas // 2

quadrant_n_mmas

comptime quadrant_n_mmas = Self.num_n_mmas // 2
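The block (BM/BN/BK), warp (WM/WN/WK), and MMA tile members form a three-level hierarchy: each block tile is divided among warps, and each warp tile is covered by a grid of MMA instructions. Worked numbers for a hypothetical configuration (all shapes below are illustrative assumptions):

```mojo
# Worked tiling numbers for a hypothetical configuration.
fn main():
    comptime BM = 256
    comptime BN = 256
    comptime WM = 128
    comptime WN = 128
    comptime WK = 64
    comptime MMA_M = 16
    comptime MMA_N = 16
    comptime MMA_K = 32

    print("num_warps_m =", BM // WM)                        # 2
    print("num_warps_n =", BN // WN)                        # 2
    print("total_warps =", (BM // WM) * (BN // WN))         # 4 per ping-pong group
    print("num_m_mmas  =", WM // MMA_M)                     # 8
    print("num_n_mmas  =", WN // MMA_N)                     # 8
    print("num_k_mmas  =", WK // MMA_K)                     # 2
    print("num_accums  =", (WM // MMA_M) * (WN // MMA_N))   # 64
    print("quadrant_m_mmas =", (WM // MMA_M) // 2)          # 4
```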

rows_per_iter_8warp

comptime rows_per_iter_8warp = Self.loading_threads_8warp // Self.loads_per_row

total_smem_a

comptime total_smem_a = 2 * Self.BM * Self.BK

total_smem_b

comptime total_smem_b = 2 * Self.BN * Self.BK
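Both totals count elements across the two ping-pong stages (hence the factor of 2); multiplying by the element size gives the LDS byte footprint. A worked example under the same hypothetical tile sizes as above:

```mojo
from sys import size_of

# LDS footprint under hypothetical tile sizes: total_smem_a/b count
# elements across both ping-pong stages, so bytes = elements * size_of.
fn main():
    comptime BM = 256
    comptime BN = 256
    comptime BK = 64
    var total_smem_a = 2 * BM * BK  # 32768 elements of a_type
    var total_smem_b = 2 * BN * BK  # 32768 elements of b_type
    var lds_bytes = (total_smem_a + total_smem_b) * size_of[DType.bfloat16]()
    print("LDS bytes =", lds_bytes)  # 131072; must fit the GPU's LDS capacity
```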

total_warps

comptime total_warps = Self.num_warps_m * Self.num_warps_n

VMCNT_PER_LOAD_A

comptime VMCNT_PER_LOAD_A = Self.half_BM // Self.rows_per_iter_8warp

VMCNT_PER_LOAD_B

comptime VMCNT_PER_LOAD_B = (Self.BN // 2) // Self.rows_per_iter_8warp
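These constants track how many load iterations one A or B half-tile takes across the 8 loading warps, which is what the fine-grained s_waitcnt vmcnt(...) placement keys off. Worked numbers, reusing the hypothetical tile sizes above and assuming 128-bit vector loads (8 x bfloat16 per load; the load width is an assumption):

```mojo
# Worked VMCNT accounting for a hypothetical bfloat16 configuration.
fn main():
    comptime WARP_SIZE = 64
    comptime BK = 64
    comptime BN = 256
    comptime half_BM = 128  # config.warp_shape[0] in this hypothetical config

    var load_width = 8                         # assumed elements per vector load
    var loads_per_row = BK // load_width       # 8 vector loads cover one row
    var loading_threads_8warp = 8 * WARP_SIZE  # 512 threads issue loads
    var rows_per_iter = loading_threads_8warp // loads_per_row  # 64 rows/iter
    print("VMCNT_PER_LOAD_A =", half_BM // rows_per_iter)    # 2
    print("VMCNT_PER_LOAD_B =", (BN // 2) // rows_per_iter)  # 2
```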

WK

comptime WK = config.warp_shape[2]

WM

comptime WM = config.warp_shape[0]

WN

comptime WN = config.warp_shape[1]

Methods

validate_config

static validate_config()

Validate the kernel configuration.

matmul_ping_pong

static matmul_ping_pong(a: LayoutTensor[a_type, a_layout, ImmutAnyOrigin], b: LayoutTensor[b_type, b_layout, ImmutAnyOrigin], c: LayoutTensor[c_type, c_layout, MutAnyOrigin])
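A minimal sketch of how the struct might be wired into a device-side kernel function. This is an illustration only: the KernelConfig value and the tensor arguments are assumed to come from surrounding host/launch code that is not shown, and imports for KernelConfig and AMDPingPongMatmul are omitted since their module path does not appear on this page.

```mojo
from layout import Layout, LayoutTensor

# Hypothetical device-side wrapper (a sketch, not the library's own
# launch path).
fn ping_pong_kernel[
    a_type: DType, b_type: DType, c_type: DType,
    a_layout: Layout, b_layout: Layout, c_layout: Layout,
    config: KernelConfig,
](
    a: LayoutTensor[a_type, a_layout, ImmutAnyOrigin],
    b: LayoutTensor[b_type, b_layout, ImmutAnyOrigin],
    c: LayoutTensor[c_type, c_layout, MutAnyOrigin],
):
    comptime Matmul = AMDPingPongMatmul[
        a_type, b_type, c_type,
        a_layout, b_layout, c_layout,
        config,
        enable_swizzle=True,
    ]
    Matmul.validate_config()  # sanity-check the configuration up front
    Matmul.matmul_ping_pong(a, b, c)
```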
