Mojo function
structured_4wave_matmul
structured_4wave_matmul[a_type: DType, b_type: DType, c_type: DType, //, enable_swizzle: Bool = True, block_m_override: Int = 0, block_n_override: Int = 0, block_k_override: Int = 0, dump_asm_path: StringSlice[StaticConstantOrigin] = StringSlice(""), elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](a: TileTensor[a_type, address_space=a.address_space, linear_idx_type=a.linear_idx_type, element_size=a.element_size], b: TileTensor[b_type, address_space=b.address_space, linear_idx_type=b.linear_idx_type, element_size=b.element_size], c: TileTensor[c_type, address_space=c.address_space, linear_idx_type=c.linear_idx_type, element_size=c.element_size], ctx: DeviceContext)
Canonical 4-wave matmul launcher (mirror of amd_ping_pong_matmul).
Single black-box entry point for all dtypes (FP8 + BF16 + FP16) and
all matmul shape regimes. Internally dispatches on (dtype, M) to
pick the right (BM, BN, BK, MMA shape) config, and always routes
through the framework-scheduled body of AMD4WaveMatmul.run.
Production callers go through a higher-level dispatcher (e.g.
AMDMatmul) that knows the shape, dtype, and which other kernels
are available; that dispatcher should set block_{m,n,k}_override
explicitly based on its own policy. The internal auto-pick below is a
convenience default for direct, ad hoc, and benchmark callers; do not
rely on it from a production dispatcher.
Recommended tile shapes (measured on MI355X, bf16, N = K = 8192):
- M = 1: use a dedicated GEMV kernel (linalg/gemv.mojo), not 4-wave; single-row matvec has its own hardware-aligned dispatch.
- 2 ≤ M ≤ 64: use amd_4wave_split_k_matmul with num_splits=4 and BK=128 (the plain kernel under-fills the GPU with too few M-blocks).
- 64 < M ≤ 512: BM=128, BN=128, BK=128 (~1.2× hipBLASLt at M=128; within ~15% of hipBLASLt at M=512).
- M ≥ 1024: BM=128, BN=128, BK=64 (at large M, register pressure with BK=128 limits ILP; BK=64 leaves ~88 VGPRs of scheduling headroom).
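The small-M rows hinge on how many independent output tiles the launch grid exposes. The sketch below is a back-of-the-envelope check, not library code: ceil_div and output_tiles are hypothetical helpers, and it assumes one workgroup per BM×BN output tile, with split-K multiplying that count by num_splits.

```mojo
# Hypothetical helpers (not part of the library): count the independent
# work units a tiled matmul grid would launch, ASSUMING one workgroup per
# BM x BN output tile and one extra copy of the grid per split-K split.


def ceil_div(a: Int, b: Int) -> Int:
    # Integer division rounded up, for positive a and b.
    return (a + b - 1) // b


def output_tiles(m: Int, n: Int, bm: Int, bn: Int, num_splits: Int = 1) -> Int:
    return ceil_div(m, bm) * ceil_div(n, bn) * num_splits


def main():
    # M = 64, N = 8192, BM = BN = 128: a single row of M-blocks.
    print(output_tiles(64, 8192, 128, 128))  # 64
    # Same shape split four ways along K: four times as many work units.
    print(output_tiles(64, 8192, 128, 128, 4))  # 256
```

At M=64 there is only one row of M-blocks whether BM is 64 or 128, so the plain kernel exposes 64 tiles at N=8192, while four K-splits quadruple the available work. Comparing such counts with the compute-unit count of the target GPU is a quick way to spot under-filling.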
These tile shapes were chosen for raw matmul throughput. The dispatcher may
legitimately route through 4-wave even when its raw matmul is somewhat
slower than a non-fusable vendor BLAS, because
elementwise_lambda_fn can fuse epilogues (bias, GELU, scale,
residual) and save a separate elementwise kernel launch.
FP8 keeps the original auto-pick (BM=64 for M ≤ 512, BM=128
otherwise) with BK=128 and mma_shape=(16,16,128).
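The bf16 table and the FP8 rule above can be restated as plain selection logic. This is a minimal sketch with hypothetical helper names (recommended_path_bf16, recommended_bk_bf16, fp8_auto_bm); it mirrors only what the text states, reports 512 < M < 1024 as "unlisted" because the table has no row for that range, and is not the launcher's actual auto-pick code.

```mojo
# Hypothetical helpers that restate the recommended-tile table (bf16,
# N = K = 8192, MI355X) and the FP8 auto-pick rule as code. They are NOT
# library functions; the launcher's real internal auto-pick may differ.


def recommended_path_bf16(m: Int) -> String:
    # Which kernel the bf16 table points to for a given M.
    if m < 1:
        return "invalid"
    if m == 1:
        return "gemv"  # dedicated GEMV kernel (linalg/gemv.mojo)
    if m <= 64:
        return "split_k"  # amd_4wave_split_k_matmul, num_splits=4
    if m <= 512 or m >= 1024:
        return "4wave"  # BM=128, BN=128
    return "unlisted"  # 512 < M < 1024 has no row in the table


def recommended_bk_bf16(m: Int) -> Int:
    # BK from the bf16 table; 0 where the table gives no BK value.
    if 1 < m and m <= 512:
        return 128  # split-K rows and 64 < M <= 512 both use BK=128
    if m >= 1024:
        return 64
    return 0


def fp8_auto_bm(m: Int) -> Int:
    # FP8 auto-pick: BM=64 for M <= 512, else 128 (BK is always 128).
    return 64 if m <= 512 else 128


def main():
    print(recommended_path_bf16(32), recommended_bk_bf16(32))
    print(recommended_path_bf16(256), recommended_bk_bf16(256))
    print(recommended_path_bf16(4096), recommended_bk_bf16(4096))
    print(fp8_auto_bm(256), fp8_auto_bm(4096))
```

Returning a sentinel for the unlisted range, rather than guessing a BK, keeps the sketch from implying a measurement the table does not contain.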
Parameters:
- a_type (DType): Element type of a.
- b_type (DType): Element type of b.
- c_type (DType): Element type of c.
- enable_swizzle (Bool): Enables swizzling to avoid LDS bank conflicts.
- block_m_override (Int): If > 0, forces BM to this value (must be 64 or 128). Set this and block_n_override together to pin the tile shape from a dispatcher.
- block_n_override (Int): If > 0, forces BN to this value (must be 64, 128, or 256). The default, 0, sets BN equal to BM.
- block_k_override (Int): If > 0, forces BK to this value. BK must be a multiple of MMA_K (32 for bf16/fp16, 128 for FP8) and must satisfy K % (2*BK) == 0. Valid values: 32, 64, or 128 for bf16/fp16; 128 for FP8.
- dump_asm_path (StringSlice[StaticConstantOrigin]): If non-empty, dumps the compiled GCN assembly to the given file path. Intended only for ASM-level diff debugging.
- elementwise_lambda_fn (Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None]): Optional fused epilogue. When set, the kernel's output store calls the lambda with the global (m_global, n_global) coordinates and a SIMD vector of c_frag_size elements instead of writing to c directly.
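A dispatcher that sets the overrides can check the documented constraints before launching. This is a minimal sketch with hypothetical helper names; it encodes only the rules listed above (0 meaning no override, the allowed BM/BN/BK values, and K % (2*BK) == 0) and does not call the launcher, which may enforce these rules differently.

```mojo
# Hypothetical pre-flight checks (not library functions) that restate the
# documented constraints on block_{m,n,k}_override.


def valid_bm_override(bm: Int) -> Bool:
    # 0 means "no override"; otherwise BM must be 64 or 128.
    return bm == 0 or bm == 64 or bm == 128


def valid_bn_override(bn: Int) -> Bool:
    # 0 means "no override" (BN follows BM); otherwise 64, 128, or 256.
    return bn == 0 or bn == 64 or bn == 128 or bn == 256


def valid_bk_override(bk: Int, k: Int, is_fp8: Bool) -> Bool:
    if bk == 0:
        return True  # no override: the launcher picks BK itself
    if is_fp8:
        if bk != 128:  # only BK=128 is listed for FP8 (MMA_K = 128)
            return False
    elif bk != 32 and bk != 64 and bk != 128:  # bf16/fp16 (MMA_K = 32)
        return False
    # K must be a multiple of 2*BK.
    return k % (2 * bk) == 0


def main():
    print(valid_bk_override(64, 8192, False))  # True
    print(valid_bk_override(128, 8192, True))  # True
    print(valid_bk_override(64, 8192, True))  # False: FP8 needs BK=128
    print(valid_bk_override(128, 192, False))  # False: 192 % 256 != 0
    print(valid_bm_override(256))  # False
```

Checking these conditions in the dispatcher keeps an invalid tile request from reaching kernel compilation or launch.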
Args:
- a (TileTensor[a_type, address_space=a.address_space, linear_idx_type=a.linear_idx_type, element_size=a.element_size]): Input tile-tensor for A.
- b (TileTensor[b_type, address_space=b.address_space, linear_idx_type=b.linear_idx_type, element_size=b.element_size]): Input tile-tensor for B.
- c (TileTensor[c_type, address_space=c.address_space, linear_idx_type=c.linear_idx_type, element_size=c.element_size]): Output tile-tensor for C.
- ctx (DeviceContext): Device context used to enqueue the kernel.
Raises:
An error if device enqueue fails.