Mojo struct

AMDPingPongMatmul

struct AMDPingPongMatmul[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, config: KernelConfig, /, enable_l2_cache_optimization: Bool, enable_swizzle: Bool, use_transpose_load: Bool]

High-level ping-pong matmul implementation for AMD GPUs.

This implements the 8-warp ping-pong pattern, in which warps alternate between loading data and computing so that memory transfers overlap with MMA work; a schematic sketch follows.
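The alternation can be pictured with a short sketch. This is illustrative pseudocode, not the kernel's implementation: `load_tile_to_lds` and `compute_tile` are hypothetical stand-ins for the real load and MMA routines, and the split into two 4-warp groups is an assumption suggested by the `loading_threads_4warp`/`loading_threads_8warp` members below.

```mojo
from gpu import barrier

# Hypothetical stand-ins for the kernel's real load/MMA routines.
fn load_tile_to_lds(stage: Int):
    pass  # global -> LDS copy into ping-pong buffer `stage`

fn compute_tile(stage: Int):
    pass  # MMAs over the tile resident in buffer `stage`

fn ping_pong_main_loop(num_k_tiles: Int):
    # Assumed grouping: two sets of 4 warps run this loop half a step out
    # of phase, so while one set computes on buffer `stage`, the other
    # fills buffer `1 - stage`.
    var stage = 0
    load_tile_to_lds(stage)           # prologue: fill the first buffer
    barrier()
    for _ in range(num_k_tiles - 1):
        load_tile_to_lds(1 - stage)   # prefetch the next K tile
        compute_tile(stage)           # MMAs overlap the load above
        barrier()
        stage = 1 - stage
    compute_tile(stage)               # epilogue: drain the final tile
```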

Memory Layout Strategy for Bank Conflict Avoidance:

  1. Shared Memory Organization (AMD MI355 has 64 banks, 4 bytes each):

    • Uses double-buffered shared memory (ping-pong buffers)
    • Each buffer holds BM×BK elements for A, BN×BK for B
  2. Bank Conflict Avoidance Pattern (see the offset sketch after this list):

    • Bank index = (address / 4) % 64
    • Swizzled access pattern distributes consecutive thread accesses across banks
    • Column swizzle: (lane_id % 4) * load_width spreads accesses within a 32-byte span
    • Row stride: (lane_id // 4) * K maps different rows to different banks
    • Warp-level offsets further distribute accesses
  3. Load Pattern (Global → Shared Memory):

    • Uses AMD's load_to_lds instruction for direct DRAM→LDS transfer
    • Bypasses L1/L2 caches for lower latency
    • Coalesced global memory access (consecutive threads → consecutive addresses)
    • Bank-conflict-free shared memory writes via swizzled offsets
  4. MMA Access Pattern (Shared Memory → Registers):

    • Optimized for AMD's matrix cores (4 per CU on MI355)
    • 16×4 thread layout within each warp for MMA fragments
    • Ensures all 4 matrix cores stay busy throughout execution
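To make the bank math concrete, here is a small, self-contained sketch that replays the documented offset formulas for one warp's first write and prints which bank each lane lands in. The parameter values (`load_width = 8`, `K = 64`, 2-byte elements) are assumptions for illustration, not values read from a real `KernelConfig`.

```mojo
# Illustration only: replays the documented offset formulas for one warp.
# load_width, K, and the 2-byte element size are assumed values.
fn smem_write_offset(lane_id: Int, load_width: Int, K: Int) -> Int:
    var col = (lane_id % 4) * load_width  # column swizzle within a row
    var row = lane_id // 4                # row stride term
    return row * K + col                  # element offset into the A tile

fn bank_index(byte_address: Int) -> Int:
    # MI355 LDS: 64 banks, 4 bytes each.
    return (byte_address // 4) % 64

fn main():
    comptime load_width = 8  # e.g. simd_width_of[DType.bfloat16]()
    comptime K = 64          # assumed BK
    for lane in range(16):
        var elem = smem_write_offset(lane, load_width, K)
        # bf16 assumed: byte address = 2 * element offset
        print("lane", lane, "-> bank", bank_index(2 * elem))
```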

Implemented traits

AnyType, UnknownDestructibility

comptime members

__del__is_trivial

comptime __del__is_trivial = True

accum_dtype

comptime accum_dtype = get_accum_type[c_type]()

accum_width

comptime accum_width = (Self.MMA_M * Self.MMA_N) // WARP_SIZE

BK

comptime BK = config.block_shape[2]

BM

comptime BM = config.block_shape[0]

BN

comptime BN = config.block_shape[1]

LGKM_PER_LOAD_A

comptime LGKM_PER_LOAD_A = Self.quadrant_m_mmas * Self.num_k_mmas

LGKM_PER_LOAD_AB

comptime LGKM_PER_LOAD_AB = Self.LGKM_PER_LOAD_A + Self.LGKM_PER_LOAD_B

LGKM_PER_LOAD_B

comptime LGKM_PER_LOAD_B = Self.quadrant_n_mmas * Self.num_k_mmas

load_width

comptime load_width = simd_width_of[a_type]()

loading_threads_4warp

comptime loading_threads_4warp = (4 * WARP_SIZE)

loading_threads_8warp

comptime loading_threads_8warp = (8 * WARP_SIZE)

loads_per_row

comptime loads_per_row = Self.BK // Self.load_width

MMA_K

comptime MMA_K = config.mma_shape[2]

MMA_M

comptime MMA_M = config.mma_shape[0]

MMA_N

comptime MMA_N = config.mma_shape[1]

num_accums

comptime num_accums = Self.num_m_mmas * Self.num_n_mmas

num_k_mmas

comptime num_k_mmas = Self.WK // Self.MMA_K

num_m_mmas

comptime num_m_mmas = Self.WM // Self.MMA_M

num_n_mmas

comptime num_n_mmas = Self.WN // Self.MMA_N

num_warps_m

comptime num_warps_m = Self.BM // Self.WM

num_warps_n

comptime num_warps_n = Self.BN // Self.WN

ping_pong_stages

comptime ping_pong_stages = 2

quadrant_m_mmas

comptime quadrant_m_mmas = Self.num_m_mmas // 2

quadrant_n_mmas

comptime quadrant_n_mmas = Self.num_n_mmas // 2

rows_per_iter_4warp

comptime rows_per_iter_4warp = Self.loading_threads_4warp // Self.loads_per_row

rows_per_iter_8warp

comptime rows_per_iter_8warp = Self.loading_threads_8warp // Self.loads_per_row

total_smem_a

comptime total_smem_a = 2 * Self.BM * Self.BK

total_smem_b

comptime total_smem_b = 2 * Self.BN * Self.BK

total_warps

comptime total_warps = Self.num_warps_m * Self.num_warps_n

VMCNT_PER_LOAD_A

comptime VMCNT_PER_LOAD_A = (Self.BM // 2) // Self.rows_per_iter_8warp

VMCNT_PER_LOAD_A_4WARP

comptime VMCNT_PER_LOAD_A_4WARP = (Self.BM // 2) // Self.rows_per_iter_4warp

VMCNT_PER_LOAD_B

comptime VMCNT_PER_LOAD_B = (Self.BN // 2) // Self.rows_per_iter_8warp

VMCNT_PER_LOAD_B_4WARP

comptime VMCNT_PER_LOAD_B_4WARP = (Self.BN // 2) // Self.rows_per_iter_4warp

WK

comptime WK = config.warp_shape[2]

WM

comptime WM = config.warp_shape[0]

WN

comptime WN = config.warp_shape[1]

Methods

validate_config

static validate_config()

Validate the kernel configuration.

matmul_demo_ping_pong

static matmul_demo_ping_pong(a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], c: LayoutTensor[c_type, c_layout, MutAnyOrigin])
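A possible host-side launch looks like the following. This is a sketch under stated assumptions: the import path of `AMDPingPongMatmul` and `KernelConfig` is not shown on this page, the parameter choices and grid/block sizing are illustrative, and the problem sizes `M`/`N` are passed in explicitly. Only `validate_config`, `matmul_demo_ping_pong`, and the struct's comptime members come from this page.

```mojo
from math import ceildiv
from gpu import WARP_SIZE
from gpu.host import DeviceContext
from layout import Layout, LayoutTensor

fn launch_ping_pong[
    a_type: DType, b_type: DType, c_type: DType,
    a_layout: Layout, b_layout: Layout, c_layout: Layout,
    cfg: KernelConfig,  # assumed in scope; import path not shown here
](
    ctx: DeviceContext,
    a: LayoutTensor[a_type, a_layout, MutAnyOrigin],
    b: LayoutTensor[b_type, b_layout, MutAnyOrigin],
    c: LayoutTensor[c_type, c_layout, MutAnyOrigin],
    M: Int, N: Int,  # problem sizes, passed explicitly for this sketch
) raises:
    comptime Kernel = AMDPingPongMatmul[
        a_type, b_type, c_type, a_layout, b_layout, c_layout, cfg,
        enable_l2_cache_optimization=True,
        enable_swizzle=True,
        use_transpose_load=False,
    ]
    Kernel.validate_config()  # fail fast on an inconsistent config
    # 8 warps per block for the ping-pong schedule; one block per C tile.
    ctx.enqueue_function[Kernel.matmul_demo_ping_pong](
        a, b, c,
        grid_dim=(ceildiv(N, Kernel.BN), ceildiv(M, Kernel.BM)),
        block_dim=8 * WARP_SIZE,
    )
```

Calling `validate_config()` before launch mirrors the struct's own expectation that the block, warp, and MMA shapes are mutually consistent.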
