
Mojo struct

AMD4WaveMatmul

struct AMD4WaveMatmul[a_type: DType, b_type: DType, c_type: DType, config: KernelConfig, /, enable_swizzle: Bool, elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None]

Hand-written 4-warp 2x2 inline-MMA matmul for AMD MI355X.

Line-by-line port of HipKittens FP8_4wave's matmul_device_1024 (BM=64) and matmul_device_2048 (BM=128). No declarative pipeline framework: explicit waits, explicit barriers, and explicit register rotation matching the source's k+2 prefetch pattern.
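
The k+2 prefetch pattern can be modeled in miniature. The following Python sketch is an illustrative model of the schedule's shape, not the Mojo kernel itself: two loads are always in flight, and each iteration waits on the oldest load, issues the load for tile k+2, and rotates its "registers" by one stage.

```python
# Illustrative model (Python, not Mojo) of the k+2 prefetch /
# register-rotation schedule: two tiles are always in flight, and
# registers rotate one stage per iteration instead of being re-read.
def pipelined_sum(tiles):
    n = len(tiles)
    if n == 0:
        return 0
    # Prologue: issue loads for tiles 0 and 1 (two stages in flight).
    in_flight = [tiles[0], tiles[1] if n > 1 else None]
    acc = 0
    for k in range(n):
        cur = in_flight[0]                  # "wait" on the oldest load
        nxt = tiles[k + 2] if k + 2 < n else None  # prefetch tile k+2
        in_flight = [in_flight[1], nxt]     # rotate registers one stage
        acc += cur                          # stand-in for the MMA work
    return acc

print(pipelined_sum(list(range(8))))  # 28, same as an unpipelined sum
```

The pipelined schedule changes only when data is fetched, never what is computed, which is why the result matches the naive loop.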

Parameters

  • a_type (DType): Element dtype of the A operand.
  • b_type (DType): Element dtype of the B operand (shares a type with A for FP8).
  • c_type (DType): Element dtype of the C output; the accumulator dtype is derived from it.
  • config (KernelConfig): Tile geometry: block_shape (BM, BN, BK), warp_shape (WM, WN), and mma_shape (MMA_M, MMA_N, MMA_K).
  • enable_swizzle (Bool): When True, populates the producer-side byte_swizzle and consumer-side mma_swizzle.
  • elementwise_lambda_fn (Optional): Optional elementwise epilogue applied when storing output fragments.

Implemented traits​

AnyType, ImplicitlyDestructible

comptime members​

accum_dtype​

comptime accum_dtype = get_accum_type[c_type]()

Accumulator dtype derived from c_type.

BK​

comptime BK = config.block_shape[2]

Workgroup K-tile size.

BM​

comptime BM = config.block_shape[0]

Workgroup M-tile size.

BN​

comptime BN = config.block_shape[1]

Workgroup N-tile size.

byte_swizzle​

comptime byte_swizzle = Optional(Swizzle((log2_floor((AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K // 32)) + 1), log2_floor((16 if AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type.is_float8() and (AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M == 16) and (AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K == 128) else ((AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K) // WARP_SIZE) * size_of[AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type]())) if AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type.is_float8() else (log2_floor(((4 * AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].simd_width) // 2)) + log2_floor(size_of[AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type]())), 4)) if enable_swizzle else Optional()

Optional producer-side (byte-space) SMEM-store swizzle.
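
As a concrete reading of the expression above, the following Python sketch evaluates the swizzle parameters for the FP8 MMA_M=16, MMA_K=128 fast path. It assumes log2_floor computes floor(log2(x)); the Swizzle(bits, base, shift) naming mirrors the comptime expression.

```python
# Sketch of the byte_swizzle parameter computation for the FP8
# MMA_M=16, MMA_K=128 fast path, mirroring the comptime expression.
import math

def log2_floor(x):
    return int(math.floor(math.log2(x)))

MMA_K = 128
bits = log2_floor(MMA_K // 32) + 1  # 3: swizzle pattern spans 2^3 rows
base = log2_floor(16)               # 4: 16-byte chunks for this FP8 shape
shift = 4                           # fixed third Swizzle argument
print(bits, base, shift)            # Swizzle(3, 4, 4)
```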

c_frag_size​

comptime c_frag_size = ((AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_N) // WARP_SIZE)

Per-lane output-fragment width for one MMA.

half_BM​

comptime half_BM = (AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].BM // 2)

Half of BM; one M-subtile per SMEM stage.

half_BN​

comptime half_BN = (AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].BN // 2)

Half of BN; one N-subtile per SMEM stage.

in_type​

comptime in_type = a_type

Input element type (A and B share a type for FP8).

MMA_K​

comptime MMA_K = config.mma_shape[2]

Single-MMA K dimension.

MMA_M​

comptime MMA_M = config.mma_shape[0]

Single-MMA M dimension.

MMA_N​

comptime MMA_N = config.mma_shape[1]

Single-MMA N dimension.

mma_swizzle​

comptime mma_swizzle = Optional(Swizzle((log2_floor((AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K // 32)) + 1), log2_floor((16 if AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type.is_float8() and (AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M == 16) and (AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K == 128) else ((AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M * AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K) // WARP_SIZE) * size_of[AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type]())), 4)) if enable_swizzle else Optional()

Optional consumer-side MMA swizzle, populated when swizzling is on.

mma_tile_m​

comptime mma_tile_m = (AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WM // 2)

Per-quadrant M-tile size consumed by an MMA load.

mma_tile_n​

comptime mma_tile_n = (AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WN // 2)

Per-quadrant N-tile size consumed by an MMA load.

num_k_mmas​

comptime num_k_mmas = (AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].BK // AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_K)

Number of MMAs per warp in the K dimension.

num_m_mmas​

comptime num_m_mmas = (AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WM // AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_M)

Number of MMAs per warp in the M dimension.

num_n_mmas​

comptime num_n_mmas = (AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WN // AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].MMA_N)

Number of MMAs per warp in the N dimension.

num_warps_m​

comptime num_warps_m = (AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].BM // AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WM)

Number of warps in the M dimension of the workgroup grid.

num_warps_n​

comptime num_warps_n = (AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].BN // AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].WN)

Number of warps in the N dimension of the workgroup grid.

quadrant_m_mmas​

comptime quadrant_m_mmas = (AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].num_m_mmas // 2)

Number of MMAs per quadrant in the M dimension.

quadrant_n_mmas​

comptime quadrant_n_mmas = (AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].num_n_mmas // 2)

Number of MMAs per quadrant in the N dimension.

simd_width​

comptime simd_width = (16 // size_of[AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type]())

SIMD lane width matched to AMD buffer_load_lds's 16-byte transaction.

Target-independent (unlike simd_width_of[in_type](), which would return the host's SIMD width when this struct is instantiated from a comptime context running on the CPU host). validate_config asserts that size_of[in_type] divides 16 evenly; for any dtype that does (FP8/BF16/FP16/FP32) this gives the same value as simd_width_of inside an AMDGPU kernel.
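
The 16-byte rule makes the width a pure function of element size. A quick Python sketch of the arithmetic (element sizes in bytes are assumptions for illustration, matching the dtypes named above):

```python
# simd_width is fixed by the 16-byte buffer_load_lds transaction,
# not by the host CPU: 16 bytes // size_of(in_type).
elem_bytes = {"float8": 1, "bfloat16": 2, "float16": 2, "float32": 4}
simd_width = {name: 16 // nbytes for name, nbytes in elem_bytes.items()}
print(simd_width)  # {'float8': 16, 'bfloat16': 8, 'float16': 8, 'float32': 4}
```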

SMEM_BYTES​

comptime SMEM_BYTES = (((2 * (AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].BM + AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].BN)) * AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].BK) * size_of[AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].in_type]())

SMEM footprint per workgroup (bytes), derived from BM/BN/BK/in_type.
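
To make the formula tangible, this Python sketch plugs in hypothetical FP8 tile shapes (BM=BN=BK=128 are illustrative values, not the kernel's defaults); the factor of 2 covers double-buffered A and B stages.

```python
# SMEM_BYTES = 2 * (BM + BN) * BK * size_of(in_type), for a
# hypothetical FP8 config (1-byte elements).
BM, BN, BK = 128, 128, 128   # hypothetical tile shapes
sizeof_in = 1                # FP8 element size in bytes
SMEM_BYTES = 2 * (BM + BN) * BK * sizeof_in
print(SMEM_BYTES)            # 65536 bytes = 64 KiB
```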

total_warps​

comptime total_warps = (AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].num_warps_m * AMD4WaveMatmul[a_type, b_type, c_type, config, enable_swizzle, elementwise_lambda_fn].num_warps_n)

Total warps per workgroup (must be 4 for this kernel).
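
The derived members above form one arithmetic chain from config. This Python sketch walks the whole chain for a hypothetical config (block_shape 128x128x128, warp_shape 64x64, mma_shape 16x16x128 are illustrative FP8 values); WARP_SIZE=64 is the AMD wavefront size.

```python
# Derived-geometry chain for a hypothetical 4-wave FP8 config.
WARP_SIZE = 64                              # AMD wavefront size
BM, BN, BK = 128, 128, 128                  # hypothetical block_shape
WM, WN = 64, 64                             # hypothetical warp_shape
MMA_M, MMA_N, MMA_K = 16, 16, 128           # hypothetical mma_shape

num_warps_m, num_warps_n = BM // WM, BN // WN   # 2 x 2 warp grid
total_warps = num_warps_m * num_warps_n         # must be 4 for this kernel
num_m_mmas, num_n_mmas = WM // MMA_M, WN // MMA_N
num_k_mmas = BK // MMA_K
quadrant_m_mmas, quadrant_n_mmas = num_m_mmas // 2, num_n_mmas // 2
half_BM, half_BN = BM // 2, BN // 2
c_frag_size = (MMA_M * MMA_N) // WARP_SIZE      # per-lane output fragment

print(total_warps, num_m_mmas, num_k_mmas, c_frag_size)  # 4 4 1 4
```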

VMCNT_PER_LOAD_A​

comptime VMCNT_PER_LOAD_A = AMD4WaveMatmul._build_geometry().vm_per_load_a

Global-load vmcnt cost per A prefetch (from KernelGeometry).

VMCNT_PER_LOAD_B​

comptime VMCNT_PER_LOAD_B = AMD4WaveMatmul._build_geometry().vm_per_load_b

Global-load vmcnt cost per B prefetch (from KernelGeometry).

WM​

comptime WM = config.warp_shape[0]

Per-warp M-tile size.

WN​

comptime WN = config.warp_shape[1]

Per-warp N-tile size.

Methods​

is_valid_config​

static is_valid_config() -> Bool

Returns whether the kernel's tile shapes are a viable config.

Pure predicate; does not raise. Use this from autotune drivers and dispatcher fallbacks to filter out impossible (BM, BN, BK, dtype) combinations before instantiating the kernel. The full per-check set is in validate_config; this is its non-throwing counterpart.

Returns:

Bool: True if every structural and resource invariant holds.
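
A sketch of the intended usage pattern, in Python. The predicate below is hypothetical: its checks only illustrate the kind of structural filtering described above (warp grid tiles the block, exactly 4 warps, SMEM fits) and are not the kernel's real invariant set.

```python
# Hypothetical autotune-style filter using a non-throwing validity
# predicate; the checks are illustrative, not the kernel's real set.
def is_valid_config(BM, BN, BK, WM=64, WN=64, smem_limit=160 * 1024):
    if BM % WM or BN % WN:
        return False                          # warp grid must tile the block
    if (BM // WM) * (BN // WN) != 4:
        return False                          # 4-wave kernel: exactly 4 warps
    return 2 * (BM + BN) * BK <= smem_limit   # FP8 SMEM footprint must fit

candidates = [(64, 128, 128), (128, 128, 128), (256, 256, 128)]
viable = [c for c in candidates if is_valid_config(*c)]
print(viable)  # [(128, 128, 128)]
```

Filtering with the predicate keeps an autotune sweep from ever triggering the throwing path in validate_config.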

validate_config​

static validate_config()

Asserts that the kernel's tile shapes meet 4-wave invariants.

Throws (via comptime assert) on the first failing invariant, with a check-specific message. Called from run so any kernel instantiation that compiles is guaranteed valid. For non-throwing tests (autotune sweeps, dispatcher pre-flight), use is_valid_config instead.

run​

static run[a_layout: TensorLayout, b_layout: TensorLayout, c_layout: TensorLayout, *, use_framework_schedule: Bool = False, num_splits: Int = 1](a: TileTensor[a_type, a_layout, ImmutAnyOrigin], b: TileTensor[b_type, b_layout, ImmutAnyOrigin], c: TileTensor[c_type, c_layout, MutAnyOrigin])

Runs the 4-wave GEMM kernel for one workgroup tile.

Body strategy is selected at comptime via use_framework_schedule:

  • False (default): hand-written _run_iter body. Explicit waits, barriers, and cross-stage register rotation matching HipKittens' FP8_4wave reference. Currently the perf champion (~516 TFLOPS at FP8 M=256 N=K=4096 on MI355X).
  • True: framework-driven body via Pipeline4Wave under SchedulingStrategy.IDENTITY + minimal_barriers + omit_mma_set_prio. The framework consumes the 24-op cross-stage-rotation body verbatim (no CSP/double-buffer reorder). ~3% slower than the hand-written body on FP8 BM=128; gap is structurally tied to op ordering inside each mini-iter (pre-barrier ds_reads vs post-barrier).

SMEM/loader/MMA/output-store scaffolding is shared between both paths; only the prologue/main-loop/epilogue body differs.
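
The num_splits parameter below is a split-K factor. As a reference model only (Python, not the kernel's actual reduction scheme), split-K partitions the K dimension across splits and sums the partial products, leaving the result unchanged:

```python
# Illustrative split-K model: partition K, accumulate partial products.
def matmul_split_k(A, B, num_splits):
    M, K, N = len(A), len(A[0]), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    chunk = K // num_splits
    for s in range(num_splits):  # each split owns one K-slice
        k0 = s * chunk
        k1 = (s + 1) * chunk if s < num_splits - 1 else K
        for i in range(M):
            for j in range(N):
                C[i][j] += sum(A[i][k] * B[k][j] for k in range(k0, k1))
    return C

A, B = [[1, 2], [3, 4]], [[5, 6], [7, 8]]
print(matmul_split_k(A, B, 2))  # [[19.0, 22.0], [43.0, 50.0]]
```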

Parameters:

  • ​a_layout (TensorLayout): Logical layout of a.
  • ​b_layout (TensorLayout): Logical layout of b.
  • ​c_layout (TensorLayout): Logical layout of c.
  • ​use_framework_schedule (Bool): When True, uses the framework-driven schedule body; otherwise emits the hand-written body.
  • ​num_splits (Int): Split-K factor (1 means no split).

Args: