
Mojo function

structured_4wave_matmul

structured_4wave_matmul[a_type: DType, b_type: DType, c_type: DType, //, enable_swizzle: Bool = True, block_m_override: Int = 0, block_n_override: Int = 0, block_k_override: Int = 0, dump_asm_path: StringSlice[StaticConstantOrigin] = StringSlice(""), elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](a: TileTensor[a_type, address_space=a.address_space, linear_idx_type=a.linear_idx_type, element_size=a.element_size], b: TileTensor[b_type, address_space=b.address_space, linear_idx_type=b.linear_idx_type, element_size=b.element_size], c: TileTensor[c_type, address_space=c.address_space, linear_idx_type=c.linear_idx_type, element_size=c.element_size], ctx: DeviceContext)

Canonical 4-wave matmul launcher (mirror of amd_ping_pong_matmul).

Single black-box entry point for all dtypes (FP8 + BF16 + FP16) and all matmul shape regimes. Internally dispatches on (dtype, M) to pick the right (BM, BN, BK, MMA shape) config, and always routes through the framework-scheduled body of AMD4WaveMatmul.run.

Production callers go through a higher-level dispatcher (e.g. AMDMatmul) that knows the shape, dtype, and which other kernels are available; that dispatcher should set block_{m,n,k}_override explicitly according to its own policy. The internal auto-pick below is a convenience default for direct, ad hoc, and benchmark callers; do not rely on it from a production dispatcher.
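A minimal Python sketch of how a dispatcher's pinned overrides could combine with the launcher's auto-pick. It assumes that an override of 0 means "keep the auto-picked value" and that a zero BN override makes BN follow the resolved BM (one reading of the block_n_override description). The Tiles type and resolve_tiles helper are hypothetical and are not part of the Mojo API.

```python
from typing import NamedTuple


class Tiles(NamedTuple):
    bm: int
    bn: int
    bk: int


def resolve_tiles(auto: Tiles, block_m_override: int = 0,
                  block_n_override: int = 0, block_k_override: int = 0) -> Tiles:
    """Apply dispatcher overrides on top of an auto-picked tile shape.

    Hypothetical helper: 0 means "not overridden". Assumption: when BN is
    not overridden it follows the resolved BM.
    """
    bm = block_m_override if block_m_override > 0 else auto.bm
    bn = block_n_override if block_n_override > 0 else bm
    bk = block_k_override if block_k_override > 0 else auto.bk
    return Tiles(bm, bn, bk)


# A production dispatcher pins all three values, so the auto-pick is irrelevant.
auto = Tiles(bm=128, bn=128, bk=64)   # e.g. the bf16 large-M default
pinned = resolve_tiles(auto, 128, 256, 64)
print(pinned)  # Tiles(bm=128, bn=256, bk=64)
```

Because a production dispatcher passes every override, its launch never depends on the internal auto-pick, which is why the advice above steers it away from relying on that default.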

Recommended tile shapes (measured on MI355X, bf16 N=K=8192):

  • M = 1: use a dedicated GEMV kernel (linalg/gemv.mojo), not 4-wave; single-row matvec has its own hardware-aligned dispatch.
  • 2 ≤ M ≤ 64: use amd_4wave_split_k_matmul with num_splits=4 and BK=128 (the plain kernel under-fills the GPU because there are too few M-blocks).
  • 64 < M ≤ 512: BM=128, BN=128, BK=128 (~1.2× hipBLASLt at M=128; within ~15% at M=512).
  • M ≥ 1024: BM=128, BN=128, BK=64 (at large M, register pressure with BK=128 limits ILP; BK=64 leaves ~88 VGPRs of scheduling headroom).
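The recommendations above can be restated as a small routing function. This Python sketch only encodes the measured bf16 guidance (MI355X, N=K=8192); the kernel labels are descriptive strings rather than real entry points, and route_bf16 is a hypothetical helper. The range 512 < M < 1024 is reported as unmeasured because the list does not cover it.

```python
from typing import Optional, Tuple

# (kernel label, (BM, BN, BK) or None when no 4-wave tile shape is given)
Route = Tuple[str, Optional[Tuple[int, int, int]]]


def route_bf16(m: int) -> Route:
    """Look up the documented bf16 recommendation for a given M."""
    if m < 1:
        raise ValueError("M must be >= 1")
    if m == 1:
        return ("gemv", None)                # dedicated GEMV kernel
    if m <= 64:
        return ("4wave_split_k", None)       # num_splits=4, BK=128
    if m <= 512:
        return ("4wave", (128, 128, 128))
    if m >= 1024:
        return ("4wave", (128, 128, 64))     # smaller BK eases register pressure
    return ("unmeasured", None)              # 512 < M < 1024 is not in the list


for m in (1, 32, 128, 768, 4096):
    print(m, route_bf16(m))
```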

These were chosen on raw matmul throughput. The dispatcher may legitimately route through 4-wave even when raw matmul is a bit slower than a non-fusable vendor BLAS, because elementwise_lambda_fn can fuse epilogues (bias, GELU, scale, residual) and save a separate elementwise kernel launch.
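To illustrate the epilogue contract documented for elementwise_lambda_fn, here is a Python emulation. Instead of writing the accumulator into c, the store loop passes each fragment's global (m, n) coordinates and a vector of `width` values to a callback, which applies bias plus GELU and writes the result itself. store_tile, make_bias_gelu, and the erf-based GELU are illustrative assumptions; a tuple stands in for IndexList[2] and a list slice for SIMD[dtype, width].

```python
import math
from typing import Callable, List, Sequence, Tuple

# Stand-in for def(IndexList[2], SIMD[dtype, width]) capturing -> None.
Epilogue = Callable[[Tuple[int, int], Sequence[float]], None]


def store_tile(acc: List[List[float]], m0: int, n0: int, width: int,
               epilogue: Epilogue) -> None:
    """Emulate the output store of one accumulator tile whose origin is (m0, n0).

    Each call receives the global coordinates of the first element and
    `width` consecutive values along N.
    """
    for i, row in enumerate(acc):
        for j in range(0, len(row), width):
            epilogue((m0 + i, n0 + j), row[j:j + width])


def gelu(x: float) -> float:
    # Exact (erf-based) GELU.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def make_bias_gelu(out: List[List[float]], bias: Sequence[float]) -> Epilogue:
    """Build a fused bias + GELU epilogue that writes into `out`."""
    def epilogue(coords: Tuple[int, int], vec: Sequence[float]) -> None:
        m, n = coords
        for k, v in enumerate(vec):
            out[m][n + k] = gelu(v + bias[n + k])
    return epilogue


acc = [[1.0, -1.0, 0.0, 2.0],
       [0.5, 0.5, 0.5, 0.5]]
bias = [0.0, 1.0, 0.0, -2.0]
out = [[0.0] * 4 for _ in range(2)]
store_tile(acc, 0, 0, 2, make_bias_gelu(out, bias))
```

Without fusion, the bias and GELU would run as a separate elementwise kernel that reads C back from memory and writes it again; the callback applies them while the values are still in registers.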

FP8 keeps the original auto-pick (BM=64 for M ≤ 512, BM=128 otherwise) and BK=128 / mma_shape=(16,16,128).
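A sketch of the FP8 auto-pick rule as stated. Since the rule names only BM, it assumes BN follows BM (the block_n_override default). fp8_auto_pick is a hypothetical helper, not the launcher's code.

```python
from typing import Dict, Tuple, Union


def fp8_auto_pick(m: int) -> Dict[str, Union[int, Tuple[int, int, int]]]:
    """Restate the documented FP8 default tile choice for a given M."""
    bm = 64 if m <= 512 else 128
    mma_shape = (16, 16, 128)
    bk = 128
    return {
        "BM": bm,
        "BN": bm,  # assumption: BN defaults to BM
        "BK": bk,
        "mma_shape": mma_shape,
        # BK equals the MMA's K extent, so one K-step of a 16x16x128 MMA
        # covers a whole BK slice for each 16x16 output fragment.
        "mma_k_steps_per_tile": bk // mma_shape[2],
    }


print(fp8_auto_pick(256)["BM"], fp8_auto_pick(2048)["BM"])  # 64 128
```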

Parameters:

  • a_type (DType): Element type of a.
  • b_type (DType): Element type of b.
  • c_type (DType): Element type of c.
  • enable_swizzle (Bool): Enable LDS swizzling to avoid bank conflicts.
  • block_m_override (Int): If > 0, force BM to this value (must be 64 or 128). Set this and block_n_override together to pin the tile shape from a dispatcher.
  • block_n_override (Int): If > 0, force BN to this value (must be 64, 128, or 256). The default, 0, sets BN equal to BM.
  • block_k_override (Int): If > 0, force BK to this value. BK must be a multiple of MMA_K (32 for bf16/fp16, 128 for FP8) and must satisfy K % (2*BK) == 0. Valid values: 32, 64, or 128 for bf16/fp16; 128 for FP8.
  • dump_asm_path (StringSlice[StaticConstantOrigin]): If non-empty, dump the compiled GCN assembly to this file path. Intended only for ASM-level diff debugging.
  • elementwise_lambda_fn (Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None]): Optional fused epilogue. When set, the kernel's output store calls the lambda with the global (m_global, n_global) coordinates and a SIMD vector of width c_frag_size, instead of writing to c directly.
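The override constraints listed above can be checked before launch, for example in a dispatcher. This Python sketch only restates the documented rules; the dtype keys, check_overrides, and the error strings are hypothetical, and it says nothing about how the launcher itself reacts to a violation.

```python
from typing import List

# MMA_K per dtype and the BK values listed for block_k_override.
MMA_K = {"bf16": 32, "fp16": 32, "fp8": 128}
VALID_BK = {"bf16": (32, 64, 128), "fp16": (32, 64, 128), "fp8": (128,)}


def check_overrides(dtype: str, k: int, bm: int = 0, bn: int = 0,
                    bk: int = 0) -> List[str]:
    """Return the violated constraints; an empty list means the overrides are OK.

    An override of 0 means "not set" and is skipped.
    """
    errors: List[str] = []
    if bm > 0 and bm not in (64, 128):
        errors.append(f"BM={bm} must be 64 or 128")
    if bn > 0 and bn not in (64, 128, 256):
        errors.append(f"BN={bn} must be 64, 128, or 256")
    if bk > 0:
        if bk % MMA_K[dtype] != 0:
            errors.append(f"BK={bk} is not a multiple of MMA_K={MMA_K[dtype]}")
        if bk not in VALID_BK[dtype]:
            errors.append(f"BK={bk} is not a listed value for {dtype}")
        if k % (2 * bk) != 0:
            errors.append(f"K={k} is not divisible by 2*BK={2 * bk}")
    return errors


print(check_overrides("bf16", 8192, bm=128, bn=128, bk=64))  # []
print(check_overrides("bf16", 192, bk=128))  # ['K=192 is not divisible by 2*BK=256']
```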

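Args:

  • a (TileTensor): Input matrix A, with element type a_type.
  • b (TileTensor): Input matrix B, with element type b_type.
  • c (TileTensor): Output matrix C, with element type c_type. Not written directly when elementwise_lambda_fn is set.
  • ctx (DeviceContext): Device context on which the kernel is enqueued.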

Raises:

An error if device enqueue fails.