Mojo struct
AMDMatmul
struct AMDMatmul[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], elementwise_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> None] = None]
Pure TileTensor structured matmul for AMD GPUs.
Uses a schedule-driven, single-buffer pipeline. All data movement uses TileTensor; LayoutTensor is not used anywhere.
Implemented traits
comptime members
accum_type
comptime accum_type = get_accum_type[a_type]()
BK
comptime BK = (config.block_tile_shape[Int(2)] * config)
BM
comptime BM = config.block_tile_shape[Int(0)]
BN
comptime BN = config.block_tile_shape[Int(1)]
c_frag_size
comptime c_frag_size = ((config.mma_shape[Int(0)] * config.mma_shape[Int(1)]) // _resolve_warp_size())
frag_size
comptime frag_size = ((config.mma_shape[Int(0)] * config.mma_shape[Int(2)]) // _resolve_warp_size())
k_group_size
comptime k_group_size = (simd_width_of[a_type]() // ((config.mma_shape[Int(0)] * config.mma_shape[Int(2)]) // _resolve_warp_size()))
k_tile_size
comptime k_tile_size = (config.mma_shape[Int(2)] * (simd_width_of[a_type]() // ((config.mma_shape[Int(0)] * config.mma_shape[Int(2)]) // _resolve_warp_size())))
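These fragment members chain together: c_frag_size and frag_size are the per-lane element counts of one MMA's MMA_M × MMA_N accumulator tile and MMA_M × MMA_K operand tile, k_group_size is how many such operand fragments fit in one SIMD vector, and k_tile_size is the k extent spanned by that group of MMAs. The Python sketch below simply evaluates the formulas above. The 16×16×16 and 32×32×8 MMA shapes, the 64-lane wavefront, and the SIMD width of 8 are hypothetical inputs, not values taken from this page, and `MmaParams`/`fragment_members` are illustrative names only.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MmaParams:
    # Hypothetical inputs; in the kernel these come from config.mma_shape,
    # _resolve_warp_size() and simd_width_of[a_type]().
    mma_m: int
    mma_n: int
    mma_k: int
    warp_size: int = 64
    simd_width: int = 8


def fragment_members(p: MmaParams) -> dict:
    # Accumulator elements per lane for one MMA_M x MMA_N output tile.
    c_frag_size = (p.mma_m * p.mma_n) // p.warp_size
    # Operand elements per lane for one MMA_M x MMA_K input tile.
    frag_size = (p.mma_m * p.mma_k) // p.warp_size
    # Number of per-MMA fragments that fit in one SIMD vector.
    k_group_size = p.simd_width // frag_size
    # K extent covered by k_group_size consecutive MMAs.
    k_tile_size = p.mma_k * k_group_size
    return {
        "c_frag_size": c_frag_size,
        "frag_size": frag_size,
        "k_group_size": k_group_size,
        "k_tile_size": k_tile_size,
    }


print(fragment_members(MmaParams(16, 16, 16)))
print(fragment_members(MmaParams(32, 32, 8)))
```

Under these assumptions both shapes give k_group_size = 2, so one SIMD vector per lane holds operand data for two consecutive MMAs along K.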
MMA_K
comptime MMA_K = config.mma_shape[Int(2)]
MMA_M
comptime MMA_M = config.mma_shape[Int(0)]
MMA_N
comptime MMA_N = config.mma_shape[Int(1)]
num_k_mmas
comptime num_k_mmas = (config.warp_tile_shape[Int(2)] // config.mma_shape[Int(2)])
num_k_tiles
comptime num_k_tiles = (config.warp_tile_shape[Int(2)] // (config.mma_shape[Int(2)] * (simd_width_of[a_type]() // ((config.mma_shape[Int(0)] * config.mma_shape[Int(2)]) // _resolve_warp_size()))))
num_m_mmas
comptime num_m_mmas = (config.warp_tile_shape[Int(0)] // config.mma_shape[Int(0)])
num_n_mmas
comptime num_n_mmas = (config.warp_tile_shape[Int(1)] // config.mma_shape[Int(1)])
num_warps_k
comptime num_warps_k = ((config.block_tile_shape[Int(2)] * config.num_warp_k_partitions) // config.warp_tile_shape[Int(2)])
num_warps_m
comptime num_warps_m = (config.block_tile_shape[Int(0)] // config.warp_tile_shape[Int(0)])
num_warps_n
comptime num_warps_n = (config.block_tile_shape[Int(1)] // config.warp_tile_shape[Int(1)])
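The count members divide one tile level by the next: the block tile by the warp tile (num_warps_m, num_warps_n, num_warps_k), and the warp tile by the MMA shape (num_m_mmas, num_n_mmas, num_k_mmas) or by k_tile_size (num_k_tiles). The sketch below evaluates them in Python for a hypothetical configuration. The 128×128×64 block tile, 64×64×64 warp tile, 16×16×16 MMA shape, single k partition, and k_tile_size of 32 are illustrative assumptions, and `TileConfig`/`tile_counts` are not part of the Mojo API.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class TileConfig:
    # Hypothetical (M, N, K) shapes standing in for MatmulConfig fields.
    block_tile_shape: Tuple[int, int, int] = (128, 128, 64)
    warp_tile_shape: Tuple[int, int, int] = (64, 64, 64)
    mma_shape: Tuple[int, int, int] = (16, 16, 16)
    num_warp_k_partitions: int = 1
    # Assumed value of MMA_K * k_group_size (see k_tile_size).
    k_tile_size: int = 32


def tile_counts(c: TileConfig) -> dict:
    bm, bn, bk = c.block_tile_shape
    wm, wn, wk = c.warp_tile_shape
    mma_m, mma_n, mma_k = c.mma_shape
    return {
        # Warps along each dimension of the block tile.
        "num_warps_m": bm // wm,
        "num_warps_n": bn // wn,
        "num_warps_k": (bk * c.num_warp_k_partitions) // wk,
        # MMA instructions along each dimension of a warp tile.
        "num_m_mmas": wm // mma_m,
        "num_n_mmas": wn // mma_n,
        "num_k_mmas": wk // mma_k,
        # Vectorized k-tiles per warp tile.
        "num_k_tiles": wk // c.k_tile_size,
    }


print(tile_counts(TileConfig()))
```

Because these are integer divisions, a configuration whose tiles do not divide evenly would silently round down here; a real config is presumably chosen so that each ratio is exact.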
simd_width
comptime simd_width = simd_width_of[a_type]()
WK
comptime WK = config.warp_tile_shape[Int(2)]
WM
comptime WM = config.warp_tile_shape[Int(0)]
WN
comptime WN = config.warp_tile_shape[Int(1)]
Methods
make_mma_swizzle
static def make_mma_swizzle() -> Swizzle
Returns the swizzle for the blocked-product shared-memory (SMEM) layout, used to avoid LDS bank conflicts.
The blocked-product layout stores k-tile elements in contiguous blocks. The MMA distribute step reads these blocks in col_major[MMA_M, WARP_SIZE/MMA_M] order, giving WARP_SIZE/MMA_M vector columns per block. The swizzle XORs enough bits to spread those column groups across LDS banks.
Unlike the ping-pong kernel's make_mma_swizzle, which works in element space for row-major SMEM and derives its base and shift from the fragment size in bytes, this swizzle operates in the vector-index space of each blocked-product chunk (base=0, shift=1).
Returns:
Swizzle: Swizzle for bank-conflict-free blocked-product LDS access.
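To make the XOR concrete, the sketch below assumes the CuTe-style Swizzle(bits, base, shift) convention, in which the `bits` bits starting at position base + shift are read, shifted down by `shift`, and XORed into the index. With base=0 and shift=1, as stated above, higher bits of each vector index are folded into its low bits. The bit count that make_mma_swizzle actually uses is not given on this page, so the values below are hypothetical, and `make_xor_swizzle` is an illustrative helper, not the Mojo Swizzle API.

```python
def make_xor_swizzle(bits: int, base: int, shift: int):
    """Return a CuTe-style XOR swizzle for shift > 0 (assumed convention).

    The `bits` bits starting at position base + shift are read, shifted
    down by `shift`, and XORed into the index.
    """
    if shift <= 0:
        raise ValueError("this sketch only handles shift > 0")
    read_mask = ((1 << bits) - 1) << (base + shift)

    def apply(index: int) -> int:
        return index ^ ((index & read_mask) >> shift)

    return apply


# base=0, shift=1 as described above; bits=1 is a hypothetical choice.
swz = make_xor_swizzle(bits=1, base=0, shift=1)
print([swz(i) for i in range(8)])  # [0, 1, 3, 2, 4, 5, 7, 6]
```

Because each bit that gets flipped depends only on higher bits, the mapping is a permutation: every vector index still lands in a unique LDS slot, while indices that share low bits are pushed apart.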
run
static def run[c_layout: TensorLayout, a_layout: TensorLayout, b_layout: TensorLayout](c: TileTensor[c_type, c_layout, MutAnyOrigin], a: TileTensor[a_type, a_layout, ImmutAnyOrigin], b: TileTensor[b_type, b_layout, ImmutAnyOrigin])
Runs the TileTensor GEMM, matching the original kernel's configuration exactly.
Uses StructuredMmaOp with per-k-tile load_frag/mma dispatch, the original warp index order, and a schedule-driven pipeline.
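When validating run, a simple host-side reference is useful. The sketch below assumes the kernel computes C = A·B, with B read as an N × K matrix when transpose_b is True; that is the conventional meaning of the parameter, not something this page states. It walks the output in BM × BN tiles and K in BK-sized slices to mirror the block-tile structure. The function names and the tiny tile sizes are hypothetical, and this is plain Python, not a GPU implementation.

```python
from typing import List

Matrix = List[List[float]]


def _b_at(b: Matrix, k: int, j: int, transpose_b: bool) -> float:
    # With transpose_b, B is assumed to be stored as N x K.
    return b[j][k] if transpose_b else b[k][j]


def naive_gemm(a: Matrix, b: Matrix, transpose_b: bool = False) -> Matrix:
    m, k = len(a), len(a[0])
    n = len(b) if transpose_b else len(b[0])
    return [[sum(a[i][kk] * _b_at(b, kk, j, transpose_b) for kk in range(k))
             for j in range(n)] for i in range(m)]


def block_tiled_gemm(a, b, transpose_b=False, bm=2, bn=2, bk=2):
    m, k = len(a), len(a[0])
    n = len(b) if transpose_b else len(b[0])
    c = [[0] * n for _ in range(m)]
    # One outer iteration per BM x BN output tile (one GPU block's share).
    for m0 in range(0, m, bm):
        for n0 in range(0, n, bn):
            # Walk K in BK-sized slices, accumulating into the same C tile.
            for k0 in range(0, k, bk):
                for i in range(m0, min(m0 + bm, m)):
                    for j in range(n0, min(n0 + bn, n)):
                        for kk in range(k0, min(k0 + bk, k)):
                            c[i][j] += a[i][kk] * _b_at(b, kk, j, transpose_b)
    return c


a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
b = [[1, 0], [0, 1], [1, 1]]    # K x N
b_t = [[1, 0, 1], [0, 1, 1]]    # the same B stored as N x K
print(block_tiled_gemm(a, b))   # [[4, 5], [10, 11], [16, 17]]
```

The 3 × 3 input with 2-wide tiles deliberately leaves partial edge tiles, which is where tiled kernels most often go wrong.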