Mojo struct
AMDMatmul
struct AMDMatmul[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], elementwise_lambda_fn: Optional[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None]
Pure TileTensor structured matmul for AMD GPUs.
A schedule-driven, single-buffer pipeline. All data movement uses TileTensor; no LayoutTensor appears anywhere.
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
accum_type
comptime accum_type = get_accum_type[a_type]()
BK
comptime BK = (config.block_tile_shape[2] * Int(config.num_warp_k_partitions))
BM
comptime BM = config.block_tile_shape[0]
BN
comptime BN = config.block_tile_shape[1]
c_frag_size
comptime c_frag_size = ((Self.MMA_M * Self.MMA_N) // WARP_SIZE)
frag_size
comptime frag_size = ((Self.MMA_M * Self.MMA_K) // WARP_SIZE)
k_group_size
comptime k_group_size = (Self.simd_width // Self.frag_size)
k_tile_size
comptime k_tile_size = (Self.MMA_K * Self.k_group_size)
MMA_K
comptime MMA_K = config.mma_shape[2]
MMA_M
comptime MMA_M = config.mma_shape[0]
MMA_N
comptime MMA_N = config.mma_shape[1]
num_k_mmas
comptime num_k_mmas = (Self.WK // Self.MMA_K)
num_k_tiles
comptime num_k_tiles = (Self.WK // Self.k_tile_size)
num_m_mmas
comptime num_m_mmas = (Self.WM // Self.MMA_M)
num_n_mmas
comptime num_n_mmas = (Self.WN // Self.MMA_N)
num_warps_k
comptime num_warps_k = (Self.BK // Self.WK)
num_warps_m
comptime num_warps_m = (Self.BM // Self.WM)
num_warps_n
comptime num_warps_n = (Self.BN // Self.WN)
simd_width
comptime simd_width = simd_width_of[a_type]()
WK
comptime WK = config.warp_tile_shape[2]
WM
comptime WM = config.warp_tile_shape[0]
WN
comptime WN = config.warp_tile_shape[1]
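The comptime members above form a three-level tiling hierarchy: block tile (BM, BN, BK), warp tile (WM, WN, WK), and MMA instruction tile (MMA_M, MMA_N, MMA_K). A minimal Python sketch of the arithmetic, using purely hypothetical shapes (not from any real MatmulConfig): a 128×128×64 block tile, 64×64×64 warp tile, 16×16×16 MMA, WARP_SIZE = 64 as on AMD CDNA GPUs, and a simd_width of 8 assuming a 16-bit a_type:

```python
# Illustrative tile arithmetic for the comptime members above.
# All shapes here are hypothetical, chosen only to make the numbers concrete.
WARP_SIZE = 64                      # AMD wavefront size
BM, BN, BK = 128, 128, 64           # block tile
WM, WN, WK = 64, 64, 64             # warp tile
MMA_M, MMA_N, MMA_K = 16, 16, 16    # MMA instruction tile
simd_width = 8                      # e.g. 8 x bfloat16 = 16 bytes per load

# Warps per block in each dimension
num_warps_m = BM // WM              # 2
num_warps_n = BN // WN              # 2
num_warps_k = BK // WK              # 1

# MMA instructions per warp tile
num_m_mmas = WM // MMA_M            # 4
num_n_mmas = WN // MMA_N            # 4
num_k_mmas = WK // MMA_K            # 4

# Fragment sizes: elements each lane holds for one MMA tile
frag_size = (MMA_M * MMA_K) // WARP_SIZE     # 4 (A/B operand fragment)
c_frag_size = (MMA_M * MMA_N) // WARP_SIZE   # 4 (accumulator fragment)

# k-grouping: how many MMA fragments one SIMD-wide load covers
k_group_size = simd_width // frag_size       # 2
k_tile_size = MMA_K * k_group_size           # 32
num_k_tiles = WK // k_tile_size              # 2
```

With these numbers a block is covered by num_warps_m * num_warps_n * num_warps_k = 4 warps, and each warp issues num_m_mmas * num_n_mmas * num_k_mmas = 64 MMA instructions per block tile.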
Methods
make_mma_swizzle
static make_mma_swizzle() -> Swizzle
Swizzle for blocked-product SMEM layout (LDS bank conflict avoidance).
The blocked-product layout stores k-tile elements in contiguous blocks. The MMA distribution reads these in col_major[MMA_M, WARP_SIZE/MMA_M] order, yielding WARP_SIZE/MMA_M vector columns per block. The swizzle XORs enough bits to spread those column groups across distinct LDS banks.
Unlike the ping-pong make_mma_swizzle (which operates in element space for row-major SMEM, with base and shift derived from fragment bytes), this swizzle operates in the vector-index space of each blocked-product chunk (base=0, shift=1).
Returns:
Swizzle: Swizzle for bank-conflict-free blocked-product LDS access.
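The XOR pattern can be sketched in Python using the common CUTLASS-style swizzle formulation. The parameters below (bits=2) are illustrative only; the actual values come from make_mma_swizzle, which uses base=0 and shift=1 as described above:

```python
def xor_swizzle(idx: int, bits: int, base: int, shift: int) -> int:
    """CUTLASS-style XOR swizzle: take `bits` bits starting at position
    base + shift and XOR them into the bits at position `base`.
    With base=0 this permutes indices in vector-index space, which
    spreads consecutive column groups across different LDS banks."""
    mask = ((1 << bits) - 1) << (base + shift)
    return idx ^ ((idx & mask) >> shift)

# Any XOR swizzle is a permutation: it remaps indices without collisions,
# so it breaks up bank conflicts without introducing new ones.
swizzled = [xor_swizzle(i, bits=2, base=0, shift=1) for i in range(16)]
assert sorted(swizzled) == list(range(16))
```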
run
static run[c_layout: TensorLayout, a_layout: TensorLayout, b_layout: TensorLayout](c: TileTensor[c_type, c_layout, MutAnyOrigin], a: TileTensor[a_type, a_layout, ImmutAnyOrigin], b: TileTensor[b_type, b_layout, ImmutAnyOrigin])
A TileTensor GEMM that matches the original kernel configuration exactly.
Uses StructuredMmaOp with per-k-tile load_frag/mma dispatch, the original warp index order, and a schedule-driven pipeline.
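Semantically, run computes a plain GEMM. A reference Python sketch of what the kernel produces, with TileTensor, layouts, and the pipeline omitted (transpose_b controls whether b is stored as [N, K] rather than [K, N]):

```python
def reference_gemm(a, b, transpose_b):
    """Reference semantics of run: c[m][n] = sum_k a[m][k] * b_eff[k][n],
    where b_eff reads b transposed when transpose_b is True
    (i.e. b is stored as [N, K])."""
    M, K = len(a), len(a[0])
    if transpose_b:
        N = len(b)
        get_b = lambda k, n: b[n][k]   # b stored [N, K]
    else:
        N = len(b[0])
        get_b = lambda k, n: b[k][n]   # b stored [K, N]
    return [[sum(a[m][k] * get_b(k, n) for k in range(K))
             for n in range(N)]
            for m in range(M)]

a = [[1, 2], [3, 4]]
b_t = [[5, 7], [6, 8]]   # b stored as [N, K], i.e. already transposed
print(reference_gemm(a, b_t, transpose_b=True))   # [[19, 22], [43, 50]]
```

The accumulation runs in accum_type (see the comptime members), which is typically wider than a_type to limit rounding error over the K dimension.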