Mojo struct
AMD4WaveMatmul
struct AMD4WaveMatmul[a_type: DType, b_type: DType, c_type: DType, config: KernelConfig, /, enable_swizzle: Bool, elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None]
Hand-written 4-warp 2x2 inline-MMA matmul for AMD MI355X.
Line-by-line port of HipKittens FP8_4wave's matmul_device_1024
(BM=64) and matmul_device_2048 (BM=128). No declarative pipeline
framework: explicit waits, explicit barriers, and explicit register
rotation matching the source's k+2 prefetch pattern.
Parameters

- a_type (DType): Input A element type.
- b_type (DType): Input B element type.
- c_type (DType): Output C element type.
- config (KernelConfig): KernelConfig with block/warp/mma shapes.
- enable_swizzle (Bool): Enable LDS bank conflict avoidance.
- elementwise_lambda_fn (Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None]): Optional epilogue.
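For illustration, a minimal binding of these parameters for the BM=128 FP8 variant described above. This is a sketch only: the `KernelConfig` constructor's field names, the 16x16x128 MMA shape, and the imports for `KernelConfig`/`AMD4WaveMatmul` (not shown on this page) are assumptions.

```mojo
from utils.index import Index

# Hypothetical config for the BM=128 variant (field names are assumed):
comptime cfg = KernelConfig(
    block_shape=Index(128, 128, 128),  # (BM, BN, BK)
    warp_shape=Index(64, 64),          # (WM, WN) -> 2x2 warp grid, 4 waves
    mma_shape=Index(16, 16, 128),      # (MMA_M, MMA_N, MMA_K), assumed FP8 MFMA
)

# Bind the struct parameters; elementwise_lambda_fn keeps its None default.
comptime Kernel = AMD4WaveMatmul[
    DType.float8_e4m3fn,  # a_type
    DType.float8_e4m3fn,  # b_type (A and B share a type for FP8)
    DType.bfloat16,       # c_type
    cfg,
    enable_swizzle=True,
]
```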
Implemented traits

AnyType, ImplicitlyDestructible
comptime members
accum_dtype
comptime accum_dtype = get_accum_type[c_type]()
Accumulator dtype derived from c_type.
BK
comptime BK = config.block_shape[2]
Workgroup K-tile size.
BM
comptime BM = config.block_shape[0]
Workgroup M-tile size.
BN
comptime BN = config.block_shape[1]
Workgroup N-tile size.
byte_swizzle
comptime byte_swizzle = Optional(Swizzle(log2_floor(Self.MMA_K // 32) + 1, log2_floor(16 if Self.in_type.is_float8() and Self.MMA_M == 16 and Self.MMA_K == 128 else (Self.MMA_M * Self.MMA_K // WARP_SIZE) * size_of[Self.in_type]()) if Self.in_type.is_float8() else (log2_floor(4 * Self.simd_width // 2) + log2_floor(size_of[Self.in_type]())), 4)) if enable_swizzle else Optional()
Optional producer-side (byte-space) SMEM-store swizzle.
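As a worked example of how the expression above collapses, assuming FP8 inputs with MMA_M=16 and MMA_K=128 (the `Swizzle` argument roles are not documented on this page, so only the values are shown):

```mojo
# enable_swizzle=True, FP8 in_type, MMA_M=16, MMA_K=128:
#   first arg:  log2_floor(128 // 32) + 1 = 2 + 1 = 3
#   second arg: FP8 MMA_M==16, MMA_K==128 special case -> log2_floor(16) = 4
#   third arg:  4
# byte_swizzle == Optional(Swizzle(3, 4, 4))
```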
c_frag_size
comptime c_frag_size = (Self.MMA_M * Self.MMA_N) // WARP_SIZE
Per-lane output-fragment width for one MMA.
half_BM
comptime half_BM = Self.BM // 2
Half of BM: one M-subtile per SMEM stage.
half_BN
comptime half_BN = Self.BN // 2
Half of BN: one N-subtile per SMEM stage.
in_type
comptime in_type = a_type
Input element type (A and B share a type for FP8).
MMA_K
comptime MMA_K = config.mma_shape[2]
Single-MMA K dimension.
MMA_M
comptime MMA_M = config.mma_shape[0]
Single-MMA M dimension.
MMA_N
comptime MMA_N = config.mma_shape[1]
Single-MMA N dimension.
mma_swizzle
comptime mma_swizzle = Optional(Swizzle(log2_floor(Self.MMA_K // 32) + 1, log2_floor(16 if Self.in_type.is_float8() and Self.MMA_M == 16 and Self.MMA_K == 128 else (Self.MMA_M * Self.MMA_K // WARP_SIZE) * size_of[Self.in_type]()), 4)) if enable_swizzle else Optional()
Optional consumer-side MMA swizzle, populated when swizzling is on.
mma_tile_m
comptime mma_tile_m = Self.WM // 2
Per-quadrant M-tile size consumed by an MMA load.
mma_tile_n
comptime mma_tile_n = Self.WN // 2
Per-quadrant N-tile size consumed by an MMA load.
num_k_mmas
comptime num_k_mmas = Self.BK // Self.MMA_K
Number of MMAs per warp in the K dimension.
num_m_mmas
comptime num_m_mmas = Self.WM // Self.MMA_M
Number of MMAs per warp in the M dimension.
num_n_mmas
comptime num_n_mmas = Self.WN // Self.MMA_N
Number of MMAs per warp in the N dimension.
num_warps_m
comptime num_warps_m = Self.BM // Self.WM
Number of warps in the M dimension of the workgroup grid.
num_warps_n
comptime num_warps_n = Self.BN // Self.WN
Number of warps in the N dimension of the workgroup grid.
quadrant_m_mmas
comptime quadrant_m_mmas = Self.num_m_mmas // 2
Number of MMAs per quadrant in the M dimension.
quadrant_n_mmas
comptime quadrant_n_mmas = Self.num_n_mmas // 2
Number of MMAs per quadrant in the N dimension.
simd_width
comptime simd_width = 16 // size_of[Self.in_type]()
SIMD lane width matched to AMD buffer_load_lds's 16-byte transaction.
Target-independent (unlike simd_width_of[in_type](), which would
return the host's SIMD width when this struct is instantiated from a
comptime context running on the CPU host). validate_config asserts
that size_of[in_type] divides 16 evenly; for any dtype that does
(FP8/BF16/FP16/FP32) this gives the same value as simd_width_of
inside an AMDGPU kernel.
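Worked values under that invariant (size_of[in_type] must divide 16):

```mojo
# simd_width = 16 // size_of[in_type]():
#   FP8  (1 byte)         -> 16 elements per 16-byte buffer_load_lds transaction
#   BF16 / FP16 (2 bytes) -> 8
#   FP32 (4 bytes)        -> 4
```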
SMEM_BYTES
comptime SMEM_BYTES = 2 * (Self.BM + Self.BN) * Self.BK * size_of[Self.in_type]()
SMEM footprint per workgroup (bytes), derived from BM/BN/BK/in_type.
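A worked evaluation, assuming the BM=128 variant with BN=BK=128 and 1-byte FP8 inputs (the BN/BK values are assumptions for the sketch):

```mojo
# SMEM_BYTES = 2 * (BM + BN) * BK * size_of[in_type]()
#            = 2 * (128 + 128) * 128 * 1
#            = 65536 bytes (64 KiB)
```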
total_warps
comptime total_warps = Self.num_warps_m * Self.num_warps_n
Total warps per workgroup (must be 4 for this kernel).
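Worked numbers for the 2x2 warp grid named in the struct description (WM=64, WN=64 is an assumption consistent with BM=BN=128 and the 4-wave requirement):

```mojo
# num_warps_m = BM // WM = 128 // 64 = 2
# num_warps_n = BN // WN = 128 // 64 = 2
# total_warps = num_warps_m * num_warps_n = 4  -- the required 4-wave shape
```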
VMCNT_PER_LOAD_A
comptime VMCNT_PER_LOAD_A = AMD4WaveMatmul._build_geometry().vm_per_load_a
Global-load vmcnt cost per A prefetch (from KernelGeometry).
VMCNT_PER_LOAD_B
comptime VMCNT_PER_LOAD_B = AMD4WaveMatmul._build_geometry().vm_per_load_b
Global-load vmcnt cost per B prefetch (from KernelGeometry).
WM
comptime WM = config.warp_shape[0]
Per-warp M-tile size.
WN
comptime WN = config.warp_shape[1]
Per-warp N-tile size.
Methods
is_valid_config
static is_valid_config() -> Bool
Returns whether the kernel's tile shapes are a viable config.
Pure predicate: does not raise. Use this from autotune drivers
and dispatcher fallbacks to filter out impossible (BM, BN, BK,
dtype) combinations before instantiating the kernel. The full
per-check set is in validate_config; this is its non-throwing
counterpart.
Returns:
Bool: True if every structural and resource invariant holds.
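A sketch of the pre-flight pattern this enables. `cfg`, `enqueue_4wave`, and `fallback_matmul` are hypothetical names, not part of this API; `run` itself is the device-side body, reached here through an assumed launch helper.

```mojo
fn dispatch_fp8[
    cfg: KernelConfig,
    a_layout: TensorLayout, b_layout: TensorLayout, c_layout: TensorLayout,
](
    a: TileTensor[DType.float8_e4m3fn, a_layout, ImmutAnyOrigin],
    b: TileTensor[DType.float8_e4m3fn, b_layout, ImmutAnyOrigin],
    c: TileTensor[DType.bfloat16, c_layout, MutAnyOrigin],
):
    comptime Kernel = AMD4WaveMatmul[
        DType.float8_e4m3fn, DType.float8_e4m3fn, DType.bfloat16, cfg, True
    ]

    @parameter
    if Kernel.is_valid_config():
        # Any instantiation reached here also passes validate_config
        # (called from run), so the throwing path cannot fire.
        enqueue_4wave[Kernel](a, b, c)  # hypothetical launch helper
    else:
        fallback_matmul(a, b, c)  # hypothetical fallback path
```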
validate_config
static validate_config()
Asserts that the kernel's tile shapes meet 4-wave invariants.
Throws (via comptime assert) on the first failing invariant,
with a check-specific message. Called from run so any kernel
instantiation that compiles is guaranteed valid. For non-
throwing tests (autotune sweeps, dispatcher pre-flight), use
is_valid_config instead.
run
static run[a_layout: TensorLayout, b_layout: TensorLayout, c_layout: TensorLayout, *, use_framework_schedule: Bool = False, num_splits: Int = 1](a: TileTensor[a_type, a_layout, ImmutAnyOrigin], b: TileTensor[b_type, b_layout, ImmutAnyOrigin], c: TileTensor[c_type, c_layout, MutAnyOrigin])
Runs the 4-wave GEMM kernel for one workgroup tile.
Body strategy is selected at comptime via use_framework_schedule:

- False (default): hand-written `_run_iter` body. Explicit waits, barriers, and cross-stage register rotation matching HipKittens' FP8_4wave reference. Currently the perf champion (~516 TFLOPS at FP8 M=256, N=K=4096 on MI355X).
- True: framework-driven body via `Pipeline4Wave` under `SchedulingStrategy.IDENTITY` + `minimal_barriers` + `omit_mma_set_prio`. The framework consumes the 24-op cross-stage-rotation body verbatim (no CSP/double-buffer reorder). ~3% slower than the hand-written body on FP8 BM=128; the gap is structurally tied to op ordering inside each mini-iter (pre-barrier ds_reads vs. post-barrier).
SMEM/loader/MMA/output-store scaffolding is shared between both paths; only the prologue/main-loop/epilogue body differs.
Parameters:

- a_layout (TensorLayout): Logical layout of `a`.
- b_layout (TensorLayout): Logical layout of `b`.
- c_layout (TensorLayout): Logical layout of `c`.
- use_framework_schedule (Bool): When True, uses the framework-driven schedule body; otherwise emits the hand-written body.
- num_splits (Int): Split-K factor (1 means no split).
Args:

- a (TileTensor[a_type, a_layout, ImmutAnyOrigin]): Input tile-tensor for A.
- b (TileTensor[b_type, b_layout, ImmutAnyOrigin]): Input tile-tensor for B.
- c (TileTensor[c_type, c_layout, MutAnyOrigin]): Output tile-tensor for C (or workspace for split-K).
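A sketch of how the comptime knobs combine at a call site. Launch plumbing is elided, and `Kernel` is a bound alias as in the earlier binding sketch:

```mojo
# Default: hand-written body, no split-K.
Kernel.run[a_layout, b_layout, c_layout](a, b, c)

# Framework-driven body with split-K factor 2. With num_splits > 1,
# c is the split-K workspace rather than the final output.
Kernel.run[
    a_layout, b_layout, c_layout,
    use_framework_schedule=True,
    num_splits=2,
](a, b, c)
```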