IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo function

structured_4wave_matmul

def structured_4wave_matmul[a_type: DType, b_type: DType, c_type: DType, //, enable_swizzle: Bool = True, block_m_override: Int = 0, block_n_override: Int = 0, block_k_override: Int = 0, dump_asm_path: StringSlice[StaticConstantOrigin] = StringSlice(""), elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](a: TileTensor[a_type, address_space=a.address_space, linear_idx_type=a.linear_idx_type, element_size=a.element_size], b: TileTensor[b_type, address_space=b.address_space, linear_idx_type=b.linear_idx_type, element_size=b.element_size], c: TileTensor[c_type, address_space=c.address_space, linear_idx_type=c.linear_idx_type, element_size=c.element_size], ctx: DeviceContext)

Canonical 4-wave matmul launcher (mirror of amd_ping_pong_matmul).

Single black-box entry point for all supported dtypes (FP8, BF16, FP16) and all matmul shape regimes. Internally it dispatches on (dtype, M) to pick the (BM, BN, BK, MMA shape) config, and it always routes through the framework-scheduled body of AMD4WaveMatmul.run.

Production callers go through a higher-level dispatcher (e.g. AMDMatmul) that knows the shape, the dtype, and which other kernels are available. That dispatcher should set block_{m,n,k}_override explicitly based on its own policy. The internal auto-pick below is a convenience default for direct, ad-hoc, or benchmark callers; do not rely on it from a production dispatcher.

Recommended tile shapes (measured on MI355X, bf16 N=K=8192):

  • M = 1: use a dedicated GEMV kernel (linalg/gemv.mojo), not 4-wave; single-row matvec has its own hardware-aligned dispatch.
  • 2 ≤ M ≤ 64: use amd_4wave_split_k_matmul with num_splits=4 and BK=128 (the plain kernel under-fills the GPU with too few M-blocks).
  • 64 < M ≤ 512: BM=128, BN=128, BK=128 (~1.2× hipBLASLt at M=128; within ~15% at M=512).
  • M ≥ 1024: BM=128, BN=128, BK=64 (at large M, BK=128 is register-pressure-limited, which reduces ILP; BK=64 leaves ~88 VGPRs of scheduling headroom).
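The recommendations above can be modeled as a small selection function that a dispatcher might consult. This is a plain-Python sketch under stated assumptions, not the kernel's actual auto-pick code. The function name `recommend_bf16_route` and its tuple layout are hypothetical. M values the measurements do not cover (512 < M < 1024) are reported as "unmeasured" rather than guessed.

```python
from typing import Optional, Tuple

# Hypothetical result layout: (route, BM, BN, BK, num_splits).
# None means "not specified / not applicable" for that route.
Route = Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]


def recommend_bf16_route(m: int) -> Route:
    """Encode the documented bf16 recommendations (MI355X, N=K=8192).

    Only the rows of the table are encoded; uncovered M values are
    flagged as 'unmeasured' instead of being extrapolated.
    """
    if m < 1:
        raise ValueError("M must be >= 1")
    if m == 1:
        # Single-row matvec goes to the dedicated GEMV kernel.
        return ("gemv", None, None, None, None)
    if m <= 64:
        # Split-K variant with 4 splits and BK=128; BM/BN are not given.
        return ("amd_4wave_split_k_matmul", None, None, 128, 4)
    if m <= 512:
        return ("structured_4wave_matmul", 128, 128, 128, None)
    if m >= 1024:
        # Smaller BK keeps VGPR headroom for scheduling at large M.
        return ("structured_4wave_matmul", 128, 128, 64, None)
    return ("unmeasured", None, None, None, None)
```

A dispatcher following this table would pass the returned BM, BN, and BK as block_m_override, block_n_override, and block_k_override instead of relying on the internal auto-pick.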

These shapes were chosen for raw matmul throughput. The dispatcher may legitimately route through 4-wave even when its raw matmul is slightly slower than a vendor BLAS that cannot fuse epilogues, because elementwise_lambda_fn can fuse epilogues (bias, GELU, scale, residual) and save a separate elementwise kernel launch.
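The fusion argument can be illustrated with a plain-Python model of a tiled output store. This is a sketch, not the Mojo kernel. The helpers `store_tile` and `make_bias_gelu`, the row-fragment layout, and the bias-plus-GELU (tanh approximation) epilogue are all hypothetical. Only one idea comes from this page: with an epilogue set, the store hands global coordinates and a vector of values to the lambda instead of writing c.

```python
import math
from typing import Callable, List, Optional, Sequence

# Hypothetical epilogue signature: (m_global, n_global, fragment values).
Epilogue = Callable[[int, int, List[float]], None]


def gelu_tanh(x: float) -> float:
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def store_tile(acc: Sequence[Sequence[float]], m0: int, n0: int, frag: int,
               c: List[List[float]], epilogue: Optional[Epilogue] = None) -> None:
    """Store an accumulator tile whose top-left corner is (m0, n0).

    Each row is emitted in contiguous fragments of `frag` values. Without
    an epilogue the fragment is written to c; with one, the fragment and
    its global coordinates are passed to the callback instead.
    """
    for i, row in enumerate(acc):
        for j in range(0, len(row), frag):
            vec = list(row[j:j + frag])
            if epilogue is None:
                c[m0 + i][n0 + j:n0 + j + len(vec)] = vec
            else:
                epilogue(m0 + i, n0 + j, vec)


def make_bias_gelu(bias: Sequence[float], out: List[List[float]]) -> Epilogue:
    # Fused epilogue: bias add + GELU applied during the store itself.
    def epilogue(m: int, n: int, vec: List[float]) -> None:
        for k, v in enumerate(vec):
            out[m][n + k] = gelu_tanh(v + bias[n + k])
    return epilogue
```

In the unfused path, c is stored first and a second pass reads it back to apply bias and GELU. This second pass corresponds to the extra elementwise launch. In the fused path, the same result is produced during the single store.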

FP8 keeps the original auto-pick (BM=64 for M ≤ 512, BM=128 otherwise) with BK=128 and mma_shape=(16,16,128).
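The FP8 default is small enough to state as code. This plain-Python model covers only what the sentence above specifies (BM, BK, MMA shape). The `TileConfig` and `fp8_auto_pick` names are hypothetical, and BN is deliberately not modeled.

```python
from typing import NamedTuple, Tuple


class TileConfig(NamedTuple):
    bm: int
    bk: int
    mma_shape: Tuple[int, int, int]


def fp8_auto_pick(m: int) -> TileConfig:
    """Model of the documented FP8 auto-pick: BM depends on M; BK and MMA shape are fixed."""
    if m < 1:
        raise ValueError("M must be >= 1")
    bm = 64 if m <= 512 else 128
    # The last MMA dimension (128) is the FP8 MMA_K that BK must be a multiple of.
    return TileConfig(bm=bm, bk=128, mma_shape=(16, 16, 128))
```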

Parameters:

  • a_type (DType): Element type of a.
  • b_type (DType): Element type of b.
  • c_type (DType): Element type of c.
  • enable_swizzle (Bool): Enable swizzling for LDS bank-conflict avoidance.
  • block_m_override (Int): If > 0, force BM to this value (must be 64 or 128). Set this and block_n_override together to pin the tile shape from a dispatcher.
  • block_n_override (Int): If > 0, force BN to this value (must be 64, 128, or 256). With the default of 0, BN is set equal to BM.
  • block_k_override (Int): If > 0, force BK to this value. BK must be a multiple of MMA_K (32 for bf16/fp16, 128 for FP8), and K must satisfy K % (2*BK) == 0. Valid values: 32, 64, or 128 for bf16/fp16; 128 for FP8.
  • dump_asm_path (StringSlice[StaticConstantOrigin]): If non-empty, dump the compiled GCN assembly to this file path. Intended only for ASM-level diff debugging.
  • elementwise_lambda_fn (Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None]): Optional fused epilogue. When set, the kernel's output store calls the lambda with the global (m_global, n_global) coordinates and a SIMD vector of c_frag_size elements, instead of writing to c directly.
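A dispatcher that sets the overrides explicitly may want to check them before launching. The following plain-Python sketch encodes the constraints listed above: BM in {64, 128}, BN in {64, 128, 256}, BK from the per-dtype valid list (all multiples of MMA_K), and K % (2*BK) == 0. The function name `check_overrides` and the dtype strings are hypothetical, and 0 means "no override" as in the parameter descriptions.

```python
from typing import Dict, List, Tuple

# Valid BK values per dtype family, as listed for block_k_override.
# Each entry is a multiple of that dtype's MMA_K (32 for bf16/fp16, 128 for FP8).
VALID_BK: Dict[str, Tuple[int, ...]] = {
    "bf16": (32, 64, 128),
    "fp16": (32, 64, 128),
    "fp8": (128,),
}


def check_overrides(dtype: str, k: int, bm: int = 0, bn: int = 0, bk: int = 0) -> List[str]:
    """Return the violated constraints; an empty list means the overrides look valid."""
    if dtype not in VALID_BK:
        raise ValueError(f"unsupported dtype: {dtype}")
    errors: List[str] = []
    if bm > 0 and bm not in (64, 128):
        errors.append("block_m_override must be 64 or 128")
    if bn > 0 and bn not in (64, 128, 256):
        errors.append("block_n_override must be 64, 128, or 256")
    if bk > 0:
        if bk not in VALID_BK[dtype]:
            errors.append(f"block_k_override must be one of {VALID_BK[dtype]} for {dtype}")
        if k % (2 * bk) != 0:
            errors.append("K must satisfy K % (2*BK) == 0")
    return errors
```

For example, check_overrides("bf16", 8192, bm=128, bn=128, bk=64) returns an empty list. With bk=64, the FP8 check reports a single BK violation.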

Args:

Raises:

An error if device enqueue fails.