Skip to main content

Mojo struct

AMDMatmul

struct AMDMatmul[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None]

Pure TileTensor structured matmul for AMD GPUs.

Schedule-driven single-buffer pipeline. All data movement uses TileTensor — no LayoutTensor anywhere.

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

accum_type

comptime accum_type = get_accum_type[a_type]()

BK

comptime BK = (config.block_tile_shape[2] * Int(config.num_warp_k_partitions))

BM

comptime BM = config.block_tile_shape[0]

BN

comptime BN = config.block_tile_shape[1]

c_frag_size

comptime c_frag_size = ((AMDMatmul[a_type, b_type, c_type, transpose_b, config, elementwise_lambda_fn].MMA_M * AMDMatmul[a_type, b_type, c_type, transpose_b, config, elementwise_lambda_fn].MMA_N) // WARP_SIZE)

frag_size

comptime frag_size = ((AMDMatmul[a_type, b_type, c_type, transpose_b, config, elementwise_lambda_fn].MMA_M * AMDMatmul[a_type, b_type, c_type, transpose_b, config, elementwise_lambda_fn].MMA_K) // WARP_SIZE)

k_group_size

comptime k_group_size = (AMDMatmul[a_type, b_type, c_type, transpose_b, config, elementwise_lambda_fn].simd_width // AMDMatmul[a_type, b_type, c_type, transpose_b, config, elementwise_lambda_fn].frag_size)

k_tile_size

comptime k_tile_size = (AMDMatmul[a_type, b_type, c_type, transpose_b, config, elementwise_lambda_fn].MMA_K * AMDMatmul[a_type, b_type, c_type, transpose_b, config, elementwise_lambda_fn].k_group_size)

MMA_K

comptime MMA_K = config.mma_shape[2]

MMA_M

comptime MMA_M = config.mma_shape[0]

MMA_N

comptime MMA_N = config.mma_shape[1]

num_k_mmas

comptime num_k_mmas = (AMDMatmul[a_type, b_type, c_type, transpose_b, config, elementwise_lambda_fn].WK // AMDMatmul[a_type, b_type, c_type, transpose_b, config, elementwise_lambda_fn].MMA_K)

num_k_tiles

comptime num_k_tiles = (AMDMatmul[a_type, b_type, c_type, transpose_b, config, elementwise_lambda_fn].WK // AMDMatmul[a_type, b_type, c_type, transpose_b, config, elementwise_lambda_fn].k_tile_size)

num_m_mmas

comptime num_m_mmas = (AMDMatmul[a_type, b_type, c_type, transpose_b, config, elementwise_lambda_fn].WM // AMDMatmul[a_type, b_type, c_type, transpose_b, config, elementwise_lambda_fn].MMA_M)

num_n_mmas

comptime num_n_mmas = (AMDMatmul[a_type, b_type, c_type, transpose_b, config, elementwise_lambda_fn].WN // AMDMatmul[a_type, b_type, c_type, transpose_b, config, elementwise_lambda_fn].MMA_N)

num_warps_k

comptime num_warps_k = (AMDMatmul[a_type, b_type, c_type, transpose_b, config, elementwise_lambda_fn].BK // AMDMatmul[a_type, b_type, c_type, transpose_b, config, elementwise_lambda_fn].WK)

num_warps_m

comptime num_warps_m = (AMDMatmul[a_type, b_type, c_type, transpose_b, config, elementwise_lambda_fn].BM // AMDMatmul[a_type, b_type, c_type, transpose_b, config, elementwise_lambda_fn].WM)

num_warps_n

comptime num_warps_n = (AMDMatmul[a_type, b_type, c_type, transpose_b, config, elementwise_lambda_fn].BN // AMDMatmul[a_type, b_type, c_type, transpose_b, config, elementwise_lambda_fn].WN)

simd_width

comptime simd_width = simd_width_of[a_type]()

WK

comptime WK = config.warp_tile_shape[2]

WM

comptime WM = config.warp_tile_shape[0]

WN

comptime WN = config.warp_tile_shape[1]

Methods

make_mma_swizzle

static make_mma_swizzle() -> Swizzle

Swizzle for blocked-product SMEM layout (LDS bank conflict avoidance).

The blocked-product layout stores k-tile elements in contiguous blocks. The MMA distribute reads these in col_major[MMA_M, WARP_SIZE/MMA_M] order, giving WARP_SIZE/MMA_M vector columns per block. The swizzle XORs enough bits to spread those column groups across LDS banks.

Unlike the ping-pong make_mma_swizzle (element-space for row-major SMEM with base/shift derived from fragment bytes), this operates in the vector-index space of each blocked-product chunk (base=0, shift=1).

Returns:

Swizzle: Swizzle for bank-conflict-free blocked-product LDS access.

run

static run[c_layout: TensorLayout, a_layout: TensorLayout, b_layout: TensorLayout](c: TileTensor[c_type, c_layout, MutAnyOrigin], a: TileTensor[a_type, a_layout, ImmutAnyOrigin], b: TileTensor[b_type, b_layout, ImmutAnyOrigin])

TileTensor GEMM matching original kernel config exactly.

Uses StructuredMmaOp with per-k-tile load_frag/mma dispatch, original warp index order, and schedule-driven pipeline.

Was this page helpful?