
Mojo struct

AMDMatmul

struct AMDMatmul[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], elementwise_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> None] = None]

Pure TileTensor structured matmul for AMD GPUs.

Schedule-driven single-buffer pipeline. All data movement uses TileTensor; LayoutTensor is not used anywhere.

Implemented traits​

AnyType, ImplicitlyDeletable

comptime members​

accum_type​

comptime accum_type = get_accum_type[a_type]()

BK​

comptime BK = (config.block_tile_shape[Int(2)] * config)

BM​

comptime BM = config.block_tile_shape[Int(0)]

BN​

comptime BN = config.block_tile_shape[Int(1)]

c_frag_size​

comptime c_frag_size = (Int((mul config.mma_shape[Int(0)], config.mma_shape[Int(1)])) // _resolve_warp_size())

frag_size​

comptime frag_size = (Int((mul config.mma_shape[Int(0)], config.mma_shape[Int(2)])) // _resolve_warp_size())

k_group_size​

comptime k_group_size = (simd_width_of[a_type]() // (Int((mul config.mma_shape[Int(0)], config.mma_shape[Int(2)])) // _resolve_warp_size()))

k_tile_size​

comptime k_tile_size = (config.mma_shape[Int(2)] * (simd_width_of[a_type]() // (Int((mul config.mma_shape[Int(0)], config.mma_shape[Int(2)])) // _resolve_warp_size())))
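
The fragment-related members above build on each other: `frag_size` feeds `k_group_size`, which in turn feeds `k_tile_size`. As a worked example, the plain-Python sketch below mirrors that arithmetic. The MMA shape `(16, 16, 16)`, warp size `64`, and SIMD width `8` are illustrative assumptions, not values taken from any particular `MatmulConfig`; in the struct they come from `config.mma_shape`, `_resolve_warp_size()`, and `simd_width_of[a_type]()`. The helper name `fragment_constants` is hypothetical.

```python
def fragment_constants(mma_shape, warp_size, simd_width):
    """Mirror the comptime arithmetic for the fragment-related members.

    All inputs are illustrative stand-ins for config.mma_shape,
    _resolve_warp_size(), and simd_width_of[a_type]().
    """
    mma_m, mma_n, mma_k = mma_shape
    # Elements of the C (accumulator) tile held by each lane of the warp.
    c_frag_size = (mma_m * mma_n) // warp_size
    # Elements of the A operand tile held by each lane of the warp.
    frag_size = (mma_m * mma_k) // warp_size
    # How many fragments fit in one SIMD-width vector.
    k_group_size = simd_width // frag_size
    # K extent covered by one such group of MMA operations.
    k_tile_size = mma_k * k_group_size
    return {
        "c_frag_size": c_frag_size,
        "frag_size": frag_size,
        "k_group_size": k_group_size,
        "k_tile_size": k_tile_size,
    }


# Assumed example values (not from the source page).
consts = fragment_constants(mma_shape=(16, 16, 16), warp_size=64, simd_width=8)
print(consts)
```

With these assumed inputs each lane holds 4 elements per fragment, two fragments fit in one vector, and one k-tile spans 32 elements along K.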

MMA_K​

comptime MMA_K = config.mma_shape[Int(2)]

MMA_M​

comptime MMA_M = config.mma_shape[Int(0)]

MMA_N​

comptime MMA_N = config.mma_shape[Int(1)]

num_k_mmas​

comptime num_k_mmas = (config.warp_tile_shape[Int(2)] // config.mma_shape[Int(2)])

num_k_tiles​

comptime num_k_tiles = (config.warp_tile_shape[Int(2)] // Int((mul config.mma_shape[Int(2)], (simd_width_of[a_type]() // (Int((mul config.mma_shape[Int(0)], config.mma_shape[Int(2)])) // _resolve_warp_size())))))

num_m_mmas​

comptime num_m_mmas = (config.warp_tile_shape[Int(0)] // config.mma_shape[Int(0)])

num_n_mmas​

comptime num_n_mmas = (config.warp_tile_shape[Int(1)] // config.mma_shape[Int(1)])

num_warps_k​

comptime num_warps_k = (Int((mul config.block_tile_shape[Int(2)], config.num_warp_k_partitions)) // config.warp_tile_shape[Int(2)])

num_warps_m​

comptime num_warps_m = (config.block_tile_shape[Int(0)] // config.warp_tile_shape[Int(0)])

num_warps_n​

comptime num_warps_n = (config.block_tile_shape[Int(1)] // config.warp_tile_shape[Int(1)])
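
The `num_*_mmas`, `num_k_tiles`, and `num_warps_*` members above are all integer quotients of one tile level by the next. The Python sketch below mirrors those formulas. The tile shapes, the partition count, and the `k_tile_size` of 32 are assumed example values, and `tiling_counts` is a hypothetical helper, not part of the library.

```python
def tiling_counts(block_tile_shape, warp_tile_shape, mma_shape,
                  num_warp_k_partitions, k_tile_size):
    """Mirror the comptime quotients relating block, warp, and MMA tiles."""
    block_m, block_n, block_k = block_tile_shape
    warp_m, warp_n, warp_k = warp_tile_shape
    mma_m, mma_n, mma_k = mma_shape
    return {
        # MMA operations per warp tile along each dimension.
        "num_m_mmas": warp_m // mma_m,
        "num_n_mmas": warp_n // mma_n,
        "num_k_mmas": warp_k // mma_k,
        # k-tiles per warp tile (k_tile_size = MMA_K * k_group_size).
        "num_k_tiles": warp_k // k_tile_size,
        # Warps per block tile along each dimension.
        "num_warps_m": block_m // warp_m,
        "num_warps_n": block_n // warp_n,
        "num_warps_k": (block_k * num_warp_k_partitions) // warp_k,
    }


# Assumed example configuration (not from the source page).
counts = tiling_counts(
    block_tile_shape=(128, 128, 64),
    warp_tile_shape=(64, 64, 64),
    mma_shape=(16, 16, 16),
    num_warp_k_partitions=1,
    k_tile_size=32,
)
print(counts)
```

Under these assumptions the block is covered by a 2 x 2 x 1 grid of warps, and each warp issues 4 x 4 MMA operations per k-step across 2 k-tiles.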

simd_width​

comptime simd_width = simd_width_of[a_type]()

WK​

comptime WK = config.warp_tile_shape[Int(2)]

WM​

comptime WM = config.warp_tile_shape[Int(0)]

WN​

comptime WN = config.warp_tile_shape[Int(1)]

Methods​

make_mma_swizzle​

static def make_mma_swizzle() -> Swizzle

Swizzle for blocked-product SMEM layout (LDS bank conflict avoidance).

The blocked-product layout stores k-tile elements in contiguous blocks. The MMA distribute reads these in col_major[MMA_M, WARP_SIZE/MMA_M] order, giving WARP_SIZE/MMA_M vector columns per block. The swizzle XORs enough bits to spread those column groups across LDS banks.

Unlike the ping-pong make_mma_swizzle, which works in element space for row-major SMEM with base and shift derived from the fragment size in bytes, this swizzle operates in the vector-index space of each blocked-product chunk (base=0, shift=1).

Returns:

Swizzle: Swizzle for bank-conflict-free blocked-product LDS access.
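
To illustrate what an XOR swizzle with base=0 and shift=1 does to vector indices, here is a minimal Python sketch. It assumes the common CuTe-style convention, in which `bits` source bits starting at bit position `base + shift` are shifted right by `shift` and XORed into the index. The page does not state the number of bits, so `bits=1` is a hypothetical choice for illustration, and `xor_swizzle` is not the library's API.

```python
def xor_swizzle(index, bits, base, shift):
    """XOR swizzle under an assumed CuTe-style (bits, base, shift) convention."""
    # Select `bits` source bits starting at position base + shift.
    mask = ((1 << bits) - 1) << (base + shift)
    # Move them down by `shift` and XOR them into the low bits of the index.
    return index ^ ((index & mask) >> shift)


# bits=1 is a hypothetical value; base=0 and shift=1 follow the text above.
swizzled = [xor_swizzle(i, bits=1, base=0, shift=1) for i in range(8)]
print(swizzled)  # [0, 1, 3, 2, 4, 5, 7, 6]
```

The mapping is a permutation of the index range, so no two vectors collide; it only reorders where neighbouring vector columns land.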

run​

static def run[c_layout: TensorLayout, a_layout: TensorLayout, b_layout: TensorLayout](c: TileTensor[c_type, c_layout, MutAnyOrigin], a: TileTensor[a_type, a_layout, ImmutAnyOrigin], b: TileTensor[b_type, b_layout, ImmutAnyOrigin])

TileTensor GEMM that matches the original kernel's configuration exactly.

Uses StructuredMmaOp with per-k-tile load_frag/mma dispatch, original warp index order, and schedule-driven pipeline.
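
For checking results on small inputs, the product that `run` writes to `c` can be compared against a naive reference. The Python sketch below is such a reference, not the kernel's algorithm. It assumes the usual GEMM convention that `transpose_b=True` means `b` is stored as N x K, so the result is `a` times the transpose of `b`; the page does not confirm this. `reference_matmul` is a hypothetical helper.

```python
def reference_matmul(a, b, transpose_b):
    """Naive triple-loop matmul over nested lists, for validation only.

    Assumes b is laid out as N x K when transpose_b is True and
    as K x N otherwise.
    """
    m = len(a)
    k = len(a[0])
    n = len(b) if transpose_b else len(b[0])
    c = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                b_val = b[j][p] if transpose_b else b[p][j]
                acc += a[i][p] * b_val
            c[i][j] = acc
    return c


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
c_plain = reference_matmul(a, b, transpose_b=False)
c_transposed = reference_matmul(a, b, transpose_b=True)
print(c_plain)
print(c_transposed)
```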