
Mojo struct

AMD4WaveMatmul

struct AMD4WaveMatmul[a_type: DType, b_type: DType, c_type: DType, config: MatmulKernelConfig, /, enable_swizzle: Bool, elementwise_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> None] = None]

Hand-written 4-warp 2x2 inline-MMA matmul for AMD MI355X.

Line-by-line port of HipKittens FP8_4wave's matmul_device_1024 (BM=64) and matmul_device_2048 (BM=128). No declarative pipeline framework: explicit waits, explicit barriers, and explicit register rotation matching the source's k+2 prefetch pattern.

Parameters​

Implemented traits​

AnyType, ImplicitlyDeletable

comptime members​

accum_dtype​

comptime accum_dtype = get_accum_type[c_type]()

Accumulator dtype derived from c_type.

BK​

comptime BK = config.block_shape[Int(2)]

Workgroup K-tile size.

BM​

comptime BM = config.block_shape[Int(0)]

Workgroup M-tile size.

BN​

comptime BN = config.block_shape[Int(1)]

Workgroup N-tile size.

byte_swizzle​

comptime byte_swizzle = Optional(Swizzle(Int((add log2_floor((config.mma_shape[Int(2)] // Int(32))), 1)), log2_floor(Int((mul size_of[a_type](), Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 78) and (eq config.mma_shape[Int(0)], 16) and (eq config.mma_shape[Int(2)], 128) else (Int((mul config.mma_shape[Int(0)], config.mma_shape[Int(2)])) // _resolve_warp_size())))) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 78) 
else Int((add log2_floor((Int((mul (Int(16) // size_of[a_type]()), 4)) // Int(2))), log2_floor(size_of[a_type]()))), Int(4))) if enable_swizzle else Optional()

Optional producer-side (byte-space) SMEM-store swizzle.

c_frag_size​

comptime c_frag_size = (Int((mul config.mma_shape[Int(0)], config.mma_shape[Int(1)])) // _resolve_warp_size())

Per-lane output-fragment width for one MMA.
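As a worked instance of the formula above, the following Python sketch evaluates the per-lane fragment width for two MMA shapes. It assumes a 64-lane wavefront, which is the usual value on AMD CDNA hardware; the helper name and the shapes are illustrative and not taken from a specific config.

```python
WARP_SIZE = 64  # assumption: 64-lane wavefront on AMD CDNA hardware


def c_frag_size(mma_m: int, mma_n: int, warp_size: int = WARP_SIZE) -> int:
    """Output elements each lane holds for one MMA_M x MMA_N tile."""
    return (mma_m * mma_n) // warp_size


print(c_frag_size(16, 16))  # 256 outputs spread over 64 lanes
print(c_frag_size(32, 32))  # 1024 outputs spread over 64 lanes
```

Under these assumptions, a 16x16 output tile leaves 4 accumulator elements per lane and a 32x32 tile leaves 16.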

half_BM​

comptime half_BM = (config.block_shape[Int(0)] // Int(2))

Half of BM: one M-subtile per SMEM stage.

half_BN​

comptime half_BN = (config.block_shape[Int(1)] // Int(2))

Half of BN: one N-subtile per SMEM stage.

in_type​

comptime in_type = a_type

Input element type (A and B share a type for FP8).

MMA_K​

comptime MMA_K = config.mma_shape[Int(2)]

Single-MMA K dimension.

MMA_M​

comptime MMA_M = config.mma_shape[Int(0)]

Single-MMA M dimension.

MMA_N​

comptime MMA_N = config.mma_shape[Int(1)]

Single-MMA N dimension.

mma_swizzle​

comptime mma_swizzle = Optional(Swizzle(Int((add log2_floor((config.mma_shape[Int(2)] // Int(32))), 1)), log2_floor(Int((mul size_of[a_type](), Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 78) and (eq config.mma_shape[Int(0)], 16) and (eq config.mma_shape[Int(2)], 128) else (Int((mul config.mma_shape[Int(0)], config.mma_shape[Int(2)])) // _resolve_warp_size())))), Int(4))) if enable_swizzle else Optional()

Optional consumer-side MMA swizzle, populated when swizzling is on.

mma_tile_m​

comptime mma_tile_m = (config.warp_shape[Int(0)] // Int(2))

Per-quadrant M-tile size consumed by an MMA load.

mma_tile_n​

comptime mma_tile_n = (config.warp_shape[Int(1)] // Int(2))

Per-quadrant N-tile size consumed by an MMA load.

num_k_mmas​

comptime num_k_mmas = (config.block_shape[Int(2)] // config.mma_shape[Int(2)])

Number of MMAs per warp in the K dimension.

num_m_mmas​

comptime num_m_mmas = (config.warp_shape[Int(0)] // config.mma_shape[Int(0)])

Number of MMAs per warp in the M dimension.

num_n_mmas​

comptime num_n_mmas = (config.warp_shape[Int(1)] // config.mma_shape[Int(1)])

Number of MMAs per warp in the N dimension.

num_warps_m​

comptime num_warps_m = (config.block_shape[Int(0)] // config.warp_shape[Int(0)])

Number of warps in the M dimension of the workgroup grid.

num_warps_n​

comptime num_warps_n = (config.block_shape[Int(1)] // config.warp_shape[Int(1)])

Number of warps in the N dimension of the workgroup grid.

quadrant_m_mmas​

comptime quadrant_m_mmas = ((config.warp_shape[Int(0)] // config.mma_shape[Int(0)]) // Int(2))

Number of MMAs per quadrant in the M dimension.

quadrant_n_mmas​

comptime quadrant_n_mmas = ((config.warp_shape[Int(1)] // config.mma_shape[Int(1)]) // Int(2))

Number of MMAs per quadrant in the N dimension.
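The tile-count members above are all integer divisions of the block, warp, and MMA shapes. The Python sketch below evaluates them for one illustrative configuration. `TileConfig` is a hypothetical stand-in for `MatmulKernelConfig`, and the shapes are chosen only so that the warp grid comes out as 2x2; they are not claimed to be a shipped config.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TileConfig:
    """Hypothetical stand-in for MatmulKernelConfig; each shape is (M, N, K)."""
    block_shape: tuple
    warp_shape: tuple
    mma_shape: tuple


def derive_tiling(cfg: TileConfig) -> dict:
    """Evaluates the comptime members listed above with integer division."""
    BM, BN, BK = cfg.block_shape
    WM, WN = cfg.warp_shape[0], cfg.warp_shape[1]
    MMA_M, MMA_N, MMA_K = cfg.mma_shape
    num_m_mmas = WM // MMA_M
    num_n_mmas = WN // MMA_N
    return {
        "num_warps_m": BM // WM,
        "num_warps_n": BN // WN,
        "num_m_mmas": num_m_mmas,
        "num_n_mmas": num_n_mmas,
        "num_k_mmas": BK // MMA_K,
        "mma_tile_m": WM // 2,
        "mma_tile_n": WN // 2,
        "quadrant_m_mmas": num_m_mmas // 2,
        "quadrant_n_mmas": num_n_mmas // 2,
    }


# Illustrative shapes only: a 128x128x128 block split into 64x64 warp tiles
# built from 16x16x128 MMAs.
cfg = TileConfig((128, 128, 128), (64, 64, 128), (16, 16, 128))
tiling = derive_tiling(cfg)
print(tiling)
```

With these shapes the warp grid is 2x2 (four warps), each warp issues 4x4 MMAs per K step, and each quadrant covers a 32x32 region built from 2x2 MMAs.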

simd_width​

comptime simd_width = (Int(16) // size_of[a_type]())

SIMD lane width matched to AMD buffer_load_lds's 16-byte transaction.

Target-independent (unlike simd_width_of[in_type](), which would return the host's SIMD width when this struct is instantiated from a comptime context running on the CPU host). validate_config asserts that size_of[in_type] divides 16 evenly; for any dtype that does (FP8/BF16/FP16/FP32) this gives the same value as simd_width_of inside an AMDGPU kernel.
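To make the 16-byte rule concrete, here is a small Python sketch that maps element size to lane count. The dtype labels and the `simd_width` helper are illustrative names for this example only; the divisibility check mirrors the invariant the text attributes to validate_config.

```python
TRANSACTION_BYTES = 16  # one 16-byte load transaction, as stated above

# Illustrative element sizes in bytes for the dtype families named above.
ELEMENT_BYTES = {"fp8": 1, "bf16": 2, "fp16": 2, "fp32": 4}


def simd_width(dtype: str) -> int:
    """Elements per 16-byte transaction for the given dtype label."""
    size = ELEMENT_BYTES[dtype]
    if TRANSACTION_BYTES % size != 0:
        raise ValueError(f"{dtype}: element size {size} does not divide 16")
    return TRANSACTION_BYTES // size


for name in ELEMENT_BYTES:
    print(name, simd_width(name))
```

This gives 16 lanes for FP8, 8 for BF16 and FP16, and 4 for FP32.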

SMEM_BYTES​

comptime SMEM_BYTES = (Int((add (mul config.block_shape[Int(2)], config.block_shape[Int(0)], 2), (mul config.block_shape[Int(2)], config.block_shape[Int(1)], 2))) * size_of[a_type]())

SMEM footprint per workgroup (bytes), derived from BM/BN/BK/in_type.
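The footprint formula can be checked by hand. The Python sketch below transcribes the expression as written, keeping the factor of 2 on each operand without interpreting it, and evaluates it for two illustrative FP8 tile shapes. The shapes and the helper name are assumptions for this example.

```python
def smem_bytes(bm: int, bn: int, bk: int, elem_bytes: int) -> int:
    """Direct transcription of SMEM_BYTES: (BK*BM*2 + BK*BN*2) * size_of[a_type]."""
    return (bk * bm * 2 + bk * bn * 2) * elem_bytes


# Illustrative 1-byte (FP8) element size with BK = 128.
print(smem_bytes(bm=128, bn=128, bk=128, elem_bytes=1))
print(smem_bytes(bm=64, bn=128, bk=128, elem_bytes=1))
```

For these inputs the result is 65536 bytes (64 KiB) with BM = 128 and 49152 bytes (48 KiB) with BM = 64.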

total_warps​

comptime total_warps = ((config.block_shape[Int(0)] // config.warp_shape[Int(0)]) * (config.block_shape[Int(1)] // config.warp_shape[Int(1)]))

Total warps per workgroup (must be 4 for this kernel).

VMCNT_PER_LOAD_A​

comptime VMCNT_PER_LOAD_A = AMD4WaveMatmul._build_geometry().vm_per_load_a

Global-load vmcnt cost per A prefetch (from KernelGeometry).

VMCNT_PER_LOAD_B​

comptime VMCNT_PER_LOAD_B = AMD4WaveMatmul._build_geometry().vm_per_load_b

Global-load vmcnt cost per B prefetch (from KernelGeometry).

WM​

comptime WM = config.warp_shape[Int(0)]

Per-warp M-tile size.

WN​

comptime WN = config.warp_shape[Int(1)]

Per-warp N-tile size.

Methods​

is_valid_config​

static def is_valid_config() -> Bool

Returns whether the kernel's tile shapes are a viable config.

Pure predicate; does not raise. Use this from autotune drivers and dispatcher fallbacks to filter out impossible (BM, BN, BK, dtype) combinations before instantiating the kernel. The full per-check set is in validate_config; this is its non-throwing counterpart.

Returns:

Bool: True if every structural and resource invariant holds.

validate_config​

static def validate_config()

Asserts that the kernel's tile shapes meet 4-wave invariants.

Throws (via comptime assert) on the first failing invariant, with a check-specific message. Called from run, so any kernel instantiation that compiles is guaranteed valid. For non-throwing tests (autotune sweeps, dispatcher pre-flight), use is_valid_config instead.
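The relationship between the throwing and non-throwing checks can be sketched in Python. This is a runtime model, whereas the real checks run at compile time. It covers only the two invariants this page states explicitly (four warps in total, and an element size that divides 16); the full check set is not listed here, and the function signatures and error messages are hypothetical.

```python
def config_errors(block_shape, warp_shape, elem_bytes):
    """Returns one message per failed invariant, in check order."""
    errors = []
    total_warps = (block_shape[0] // warp_shape[0]) * (block_shape[1] // warp_shape[1])
    if total_warps != 4:
        errors.append(f"total_warps must be 4, got {total_warps}")
    if 16 % elem_bytes != 0:
        errors.append(f"element size {elem_bytes} must divide 16")
    return errors


def is_valid_config(block_shape, warp_shape, elem_bytes) -> bool:
    """Non-throwing predicate: True when no invariant fails."""
    return not config_errors(block_shape, warp_shape, elem_bytes)


def validate_config(block_shape, warp_shape, elem_bytes) -> None:
    """Throwing counterpart: raises on the first failing invariant."""
    errors = config_errors(block_shape, warp_shape, elem_bytes)
    if errors:
        raise ValueError(errors[0])


# Filter candidate block shapes the way an autotune sweep might.
candidates = [(128, 128, 128), (256, 128, 128), (64, 128, 128)]
viable = [c for c in candidates if is_valid_config(c, (64, 64, 128), 1)]
print(viable)
```

Sharing one list of checks between both entry points keeps the predicate and the asserting version from drifting apart. Whether the real implementation is organized this way is not stated on this page.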

run​

static def run[a_layout: TensorLayout, b_layout: TensorLayout, c_layout: TensorLayout, *, num_splits: Int = Int(1)](a: TileTensor[a_type, a_layout, ImmutAnyOrigin], b: TileTensor[b_type, b_layout, ImmutAnyOrigin], c: TileTensor[c_type, c_layout, MutAnyOrigin])

Runs the 4-wave GEMM kernel for one workgroup tile.

Emits the framework-driven body via Pipeline4Wave under SchedulingStrategy.IDENTITY + minimal_barriers + omit_mma_set_prio. The framework consumes the 24-op cross-stage-rotation body verbatim (no CSP/double-buffer reorder).

Parameters:

  • ​a_layout (TensorLayout): Logical layout of a.
  • ​b_layout (TensorLayout): Logical layout of b.
  • ​c_layout (TensorLayout): Logical layout of c.
  • ​num_splits (Int): Split-K factor (1 means no split).

Args:

run_conv2d​

static def run_conv2d[conv_config: Conv2DKernelConfig, a_layout: TensorLayout, b_layout: TensorLayout, c_layout: TensorLayout, has_residual: Bool = False, elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None](a: TileTensor[a_type, a_layout, ImmutAnyOrigin], b: TileTensor[b_type, b_layout, ImmutAnyOrigin], c: TileTensor[c_type, c_layout, MutAnyOrigin], source_ptr: UnsafePointer[Scalar[c_type], ImmutAnyOrigin], source_row_stride: Int, beta: Float32)

Runs the 4-wave kernel as a 2D convolution via implicit-GEMM.

Sibling of run() (the matmul entry point); both share the 4-warp 2x2 quadrant layout, the same MFMA shapes, and the same software-pipeline schedule; only the A-operand loader differs. run() uses TileLoaderLDS (linear MK source); this method uses TileLoaderLDSIm2col, which materializes the A operand from a 4D NHWC input via in-line im2col addressing.
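For readers unfamiliar with implicit GEMM, the Python sketch below shows the general idea of im2col addressing: a GEMM coordinate (m, k) of the virtual A matrix is mapped to an element offset in the NHWC input. It is not the TileLoaderLDSIm2col implementation. The row ordering (n, h_out, w_out), the K ordering (r, s, c) with the channel fastest, symmetric zero padding, and all function names are assumptions made for this example.

```python
def out_size(in_size, filt, stride, dilation, pad):
    """Standard convolution output extent along one spatial axis."""
    return (in_size + 2 * pad - dilation * (filt - 1) - 1) // stride + 1


def im2col_offset(m, k, *, H, W, C, R, S, stride=1, dilation=1, pad=0):
    """Maps GEMM coordinate (m, k) to a flat NHWC offset, or None for padding."""
    H_out = out_size(H, R, stride, dilation, pad)
    W_out = out_size(W, S, stride, dilation, pad)
    # Row index -> (batch, output row, output column).
    n, rem = divmod(m, H_out * W_out)
    h_out, w_out = divmod(rem, W_out)
    # Column index -> (filter row, filter column, input channel).
    r, rem_k = divmod(k, S * C)
    s, c = divmod(rem_k, C)
    h = h_out * stride - pad + r * dilation
    w = w_out * stride - pad + s * dilation
    if not (0 <= h < H and 0 <= w < W):
        return None  # falls in the padding halo; contributes zero
    return ((n * H + h) * W + w) * C + c


geom = dict(H=4, W=4, C=2, R=3, S=3, stride=1, dilation=1, pad=1)
print(im2col_offset(0, 0, **geom))  # top-left tap of the first window
print(im2col_offset(0, 8, **geom))  # center tap of the first window
print(im2col_offset(5, 9, **geom))  # center tap, channel 1, output (1, 1)
```

Because the offset is computed on the fly, the im2col matrix never has to be written to memory; only the addressing of the A operand changes relative to a plain matmul.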

Optional in-kernel residual prefetch via has_residual. When True, the kernel bulk-prefetches source (a 2D [M, C_out]-aliased view of an NHWC residual buffer) into VGPRs at the start of the epilogue. By the time the FMA-and-store loop runs, all 32 per-lane residual loads are in flight in parallel, replacing the per-store global_load → wait → store staircase that costs ~24% on memory-bound shapes. The launcher passes the residual pointer, row stride, and beta scale; the kernel applies out = mma + beta * residual in Float32 and casts back to c_type for the store.

Parameters:

  • ​conv_config (Conv2DKernelConfig): Conv geometry (filter shape, stride, dilation, pad, input H/W, runtime-HW flag).
  • ​a_layout (TensorLayout): Logical layout of a (4D NHWC).
  • ​b_layout (TensorLayout): Logical layout of b (2D [C_out, K] filter).
  • ​c_layout (TensorLayout): Logical layout of c (2D [M, C_out] output).
  • has_residual (Bool): When True, prefetch source and apply the beta * source FMA during the epilogue. When False, the residual args are unused and the epilogue is identical to the no-residual kernel (dead-code-eliminated by the compiler).
  • ​elementwise_compute_lambda_fn (Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> SIMD[dtype, width]]): Optional pre-residual fused compute lambda. Fires on the post-cast c_type MMA output BEFORE the residual FMA, matching the SM100 D = lambda(Conv(A,B)) + beta * C ordering. Use for bias / ReLU / SiLU / GELU fusion. Signature (IndexList[2], SIMD) capturing -> SIMD. When unset, the comptime branch dead-code-eliminates.
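The epilogue ordering described above (compute lambda first, then the residual FMA) can be modeled on scalars in Python. This is a sketch of the ordering only: Python floats stand in for both c_type and Float32, so the casts are omitted, and the function and argument names are hypothetical.

```python
def epilogue(coords, acc, *, compute_lambda=None, residual=None, beta=0.0):
    """Models D = lambda(Conv(A, B)) + beta * C for one output element."""
    out = acc  # stands in for the MMA output after the cast to c_type
    if compute_lambda is not None:
        out = compute_lambda(coords, out)  # fused bias / activation
    if residual is not None:
        out = out + beta * residual  # residual FMA runs after the lambda
    return out


def relu(coords, value):
    """Example compute lambda: ignores the coordinates, clamps at zero."""
    return max(value, 0.0)


print(epilogue((0, 0), -2.0, compute_lambda=relu, residual=4.0, beta=0.5))
print(epilogue((0, 0), -2.0, residual=4.0, beta=0.5))
```

The order matters: with the ReLU applied before the residual the first call yields 2.0, whereas skipping the lambda yields 0.0 for the same inputs.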

Args: