Mojo function
gemm_kernel_rdna
gemm_kernel_rdna[c_type: DType, a_type: DType, b_type: DType, c_layout: TensorLayout, a_layout: TensorLayout, b_layout: TensorLayout, transpose_b: Bool = True, elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, s_type: DType = get_accum_type[c_type](), BLOCK_K: Int = 16, BLOCK_M: Int = 128, BLOCK_N: Int = 128, WARPS_M: Int = 8, WARPS_N: Int = 2, WARP_TILE_M: Int = 1, WARP_TILE_N: Int = 4](c: TileTensor[c_type, c_layout, MutAnyOrigin], a: TileTensor[a_type, a_layout, ImmutAnyOrigin], b: TileTensor[b_type, b_layout, ImmutAnyOrigin], m: Int, n: Int, k: Int)
GEMM kernel for AMD RDNA GPUs.
On RDNA 3+ (gfx11xx/gfx12xx), uses 16x16x16 WMMA instructions with shared-memory tiling, double-buffered shared memory, and coalesced loads. On older RDNA (gfx10xx), which lacks WMMA, it falls back to a naive per-thread matmul.
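The blocking scheme implied by the `BLOCK_M`/`BLOCK_N`/`BLOCK_K` parameters can be sketched on the CPU. This is not the kernel itself (no WMMA, no shared memory, no double buffering), just a pure-Python illustration of how the output is partitioned into `BLOCK_M` x `BLOCK_N` tiles that accumulate over `BLOCK_K`-wide slices of A and B; tile sizes are shrunk for readability, while the kernel's defaults are 128 x 128 x 16.

```python
def blocked_gemm(a, b, m, n, k, BLOCK_M=4, BLOCK_N=4, BLOCK_K=2):
    """Blocked C = A @ B over row-major nested lists, mirroring the
    tiling structure of gemm_kernel_rdna (illustrative only)."""
    c = [[0.0] * n for _ in range(m)]
    for m0 in range(0, m, BLOCK_M):          # output tile rows
        for n0 in range(0, n, BLOCK_N):      # output tile columns
            # Accumulate the tile over K slices, as a workgroup would
            # after staging each slice of A and B in shared memory.
            for k0 in range(0, k, BLOCK_K):
                for i in range(m0, min(m0 + BLOCK_M, m)):
                    for j in range(n0, min(n0 + BLOCK_N, n)):
                        for p in range(k0, min(k0 + BLOCK_K, k)):
                            c[i][j] += a[i][p] * b[p][j]
    return c
```

Because every (i, j, p) triple is visited exactly once, the result is identical to a naive triple loop; the blocking only changes the traversal order to improve data reuse.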