Mojo struct
AmdTileOperator
struct AmdTileOperator[InType: DType, OutType: DType, warp_block_layout_a: Layout, warp_block_layout_b: Layout, mma_shape: IndexList[Int(3)], swizzle: Optional[Swizzle] = None, transpose_b: Bool = True]
Manages tensor core operations for matrix multiplication on AMD GPUs.
This operator handles loading matrix fragments from shared memory to registers and performing matrix multiply-accumulate operations using tensor cores.
Requirements:
- warp_block_layout_a.shape[0] must be divisible by mma_shape[0].
- warp_block_layout_b.shape[0] must be divisible by mma_shape[1].
- warp_block_layout_a.shape[1] must be divisible by mma_shape[2].
- warp_block_layout_b.shape[1] must be divisible by mma_shape[2].
- The K dimension must align so that num_k_tiles is divisible by the K group size (k_group_size_a and k_group_size_b).
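The divisibility requirements above can be checked before instantiating the struct. The following Python function is a host-side sketch, not part of the Mojo API. It treats each warp-block layout shape as a flat (rows, cols) pair, whereas the struct itself takes product() of possibly nested shapes. It also assumes that num_matrix_reg[d1, d2]() returns d1 * d2 // warp_size. simd_width and warp_size are hypothetical inputs standing in for simd_width_of[InType]() and the target's warp size.

```python
def check_requirements(a_shape, b_shape, mma_shape, simd_width, warp_size=64):
    """Return the list of violated AmdTileOperator requirements (empty if none).

    a_shape, b_shape: flat (rows, cols) extents of warp_block_layout_a / _b.
    mma_shape: (M, N, K). simd_width and warp_size are assumed inputs.
    """
    m, n, k = mma_shape
    errors = []
    if a_shape[0] % m:
        errors.append("warp_block_layout_a.shape[0] not divisible by mma_shape[0]")
    if b_shape[0] % n:
        errors.append("warp_block_layout_b.shape[0] not divisible by mma_shape[1]")
    if a_shape[1] % k:
        errors.append("warp_block_layout_a.shape[1] not divisible by mma_shape[2]")
    if b_shape[1] % k:
        errors.append("warp_block_layout_b.shape[1] not divisible by mma_shape[2]")
    num_k_tiles = a_shape[1] // k
    # Assumption: num_matrix_reg[d1, d2]() == d1 * d2 // warp_size.
    for name, regs in (("k_group_size_a", m * k // warp_size),
                       ("k_group_size_b", n * k // warp_size)):
        group = simd_width // regs if regs else 0
        if group == 0 or num_k_tiles % group:
            errors.append(f"num_k_tiles ({num_k_tiles}) not divisible by {name} ({group})")
    return errors


print(check_requirements((64, 32), (64, 32), (16, 16, 16), simd_width=8))  # []
```

With a K extent of 48, num_k_tiles is 3. That value is not a multiple of the assumed group size of 2, so both K-group checks fail even though every shape divides evenly by the MMA shape.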
Parameters
- InType (DType): Input data type.
- OutType (DType): Output data type.
- warp_block_layout_a (Layout): Layout for matrix A warp tiles.
- warp_block_layout_b (Layout): Layout for matrix B warp tiles.
- mma_shape (IndexList[Int(3)]): Shape of the MMA operation [M, N, K].
- swizzle (Optional[Swizzle]): Optional swizzle pattern for memory access.
- transpose_b (Bool): Whether matrix B is transposed (defaults to True).
Fields
- out_reg_tile (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].OutRegTile): Accumulator register tile holding the output fragments; reset_accumulator sets it to zero.
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
ARegTile
comptime ARegTile = LayoutTensor[InType, Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape[Int(2)]) // (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(2)]]())) * (product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)])), simd_width_of[InType]()), MutAnyOrigin, address_space=AddressSpace.LOCAL]
BRegTile
comptime BRegTile = LayoutTensor[InType, Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape[Int(2)]) // (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(1)], mma_shape[Int(2)]]())) * (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)])), simd_width_of[InType]()), MutAnyOrigin, address_space=AddressSpace.LOCAL]
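ARegTile and BRegTile pack K-tile fragments into rows of simd_width elements. The Python sketch below models their shapes and takes flat extents: wm = product(warp_block_layout_a.shape[0]), wn = product(warp_block_layout_b.shape[0]), and wk = WK. It assumes that num_matrix_reg[d1, d2]() returns d1 * d2 // warp_size. simd_width is a hypothetical stand-in for simd_width_of[InType](), whose real value depends on the target and the dtype.

```python
def reg_tile_shapes(wm, wn, wk, mma_shape, simd_width, warp_size=64):
    """Model (ARegTile, BRegTile) shapes as (rows, cols) pairs."""
    m, n, k = mma_shape
    num_k_tiles = wk // k
    # Assumption: num_matrix_reg[d1, d2]() == d1 * d2 // warp_size.
    k_group_size_a = simd_width // (m * k // warp_size)
    k_group_size_b = simd_width // (n * k // warp_size)
    # Both tiles use WK (from warp_block_layout_a) for the K extent, as in the aliases.
    a_rows = (num_k_tiles // k_group_size_a) * (wm // m)
    b_rows = (num_k_tiles // k_group_size_b) * (wn // n)
    return (a_rows, simd_width), (b_rows, simd_width)


print(reg_tile_shapes(64, 64, 32, (16, 16, 16), simd_width=8))  # ((4, 8), (4, 8))
```

In this example each tile holds 4 x 8 = 32 elements. That equals num_k_tiles x num_m_mmas x num_matrix_reg[16, 16] = 2 x 4 x 4, which is one fragment per (K tile, MMA tile) pair.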
k_group_size_a
comptime k_group_size_a = (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(2)]]())
k_group_size_b
comptime k_group_size_b = (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(1)], mma_shape[Int(2)]]())
k_tile_fragment_index
comptime k_tile_fragment_index[k_tile_idx: Int] = (k_tile_idx % (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(2)]]()))
Parameters
- k_tile_idx (Int): K-tile index.
k_tile_group_index
comptime k_tile_group_index[k_tile_idx: Int] = (k_tile_idx // (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(2)]]()))
Parameters
- k_tile_idx (Int): K-tile index.
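k_tile_group_index and k_tile_fragment_index split a K-tile index into a group index and a slot within that group. Both are defined using k_group_size_a (the mma_shape[0] variant), so they agree with k_group_size_b only when M == N. The Python sketch below models the split. It assumes num_matrix_reg[d1, d2]() returns d1 * d2 // warp_size and uses a hypothetical simd_width of 8 in place of simd_width_of[InType]().

```python
def k_group_size(mma_m_or_n, mma_k, simd_width, warp_size=64):
    # Assumption: num_matrix_reg[d1, d2]() == d1 * d2 // warp_size.
    num_matrix_reg = mma_m_or_n * mma_k // warp_size
    return simd_width // num_matrix_reg


def split_k_tile(k_tile_idx, group_size):
    """Return (k_tile_group_index, k_tile_fragment_index)."""
    return k_tile_idx // group_size, k_tile_idx % group_size


group = k_group_size(16, 16, simd_width=8)  # 8 // (256 // 64) == 2
for k in range(4):
    print(k, split_k_tile(k, group))
```

With a group size of 2, K tiles 0 and 1 map to group 0 (slots 0 and 1), and K tiles 2 and 3 map to group 1.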
num_k_tiles
comptime num_k_tiles = (product(warp_block_layout_a.shape[1]) // mma_shape[Int(2)])
num_m_mmas
comptime num_m_mmas = (product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)])
num_n_mmas
comptime num_n_mmas = (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)])
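The three counts fix how many MMA tiles a warp covers. If each mma_compute call issues one MMA per (m, n) tile pair, which is an assumption and not stated above, then their product is the number of MMA operations per full pass over WK. The following Python sketch uses hypothetical flat extents:

```python
def mma_counts(wm, wn, wk, mma_shape):
    """Model (num_m_mmas, num_n_mmas, num_k_tiles) from flat warp-tile extents.

    wm = product(warp_block_layout_a.shape[0]),
    wn = product(warp_block_layout_b.shape[0]),
    wk = WK = product(warp_block_layout_a.shape[1]).
    """
    m, n, k = mma_shape
    return wm // m, wn // n, wk // k


num_m_mmas, num_n_mmas, num_k_tiles = mma_counts(64, 64, 32, (16, 16, 16))
print(num_m_mmas, num_n_mmas, num_k_tiles)    # 4 4 2
print(num_m_mmas * num_n_mmas * num_k_tiles)  # 32
```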
out_frag_size
comptime out_frag_size = (mma_shape[Int(0)] * mma_shape[Int(1)]) // _resolve_warp_size()
OutRegTile
comptime OutRegTile = LayoutTensor[OutType, Layout.row_major(Int((mul (product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)]), (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)]))), SIMD(SIMDSize(num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(1)]]()))), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=Int((get_alignof SIMD[InType, SIMDSize(simd_width_of[InType]())], _current_target()))]
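OutRegTile has num_m_mmas x num_n_mmas rows, each of num_matrix_reg[M, N] elements. Under the assumed definition num_matrix_reg[M, N]() == M * N // warp_size, that row width equals out_frag_size. The following Python sketch checks the resulting bookkeeping: per-lane elements times the warp size cover the whole wm x wn output tile. The warp size of 64 and the extents are hypothetical inputs.

```python
def out_reg_tile_shape(wm, wn, mma_shape, warp_size=64):
    """Model OutRegTile as (rows, cols) for flat warp-tile extents wm, wn."""
    m, n, _ = mma_shape
    num_m_mmas, num_n_mmas = wm // m, wn // n
    # out_frag_size = M * N // warp_size; assumed equal to num_matrix_reg[M, N]().
    out_frag_size = m * n // warp_size
    return num_m_mmas * num_n_mmas, out_frag_size


rows, cols = out_reg_tile_shape(64, 64, (16, 16, 16))
print(rows, cols)                   # 16 4
print(rows * cols * 64 == 64 * 64)  # True
```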
OutRegTileFragmentType
comptime OutRegTileFragmentType = LayoutTensor[OutType, LayoutTensor._compute_tile_layout[((product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)]) * (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)])), SIMD(SIMDSize(num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(1)]]()))]()[Int(0)], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major(Int((mul (product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)]), (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)]))), SIMD(SIMDSize(num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(1)]]()))), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major(Int((mul (product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)]), (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)]))), SIMD(SIMDSize(num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(1)]]()))), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major(Int((mul (product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)]), (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)]))), SIMD(SIMDSize(num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(1)]]()))), ((product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)]) * (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)])), SIMD(SIMDSize(num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(1)]]()))](), alignment=Int((get_alignof SIMD[InType, SIMDSize(simd_width_of[InType]())], _current_target()))]
simd_width
comptime simd_width = simd_width_of[InType]()
tensor_core
comptime tensor_core = TensorCore()
total_k_tiles
comptime total_k_tiles = AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles
WK
comptime WK = product(warp_block_layout_a.shape[1])
Methods
__init__
def __init__() -> Self
a_reg_tile
def a_reg_tile(self, k_tile_idx: Int) -> LayoutTensor[InType, LayoutTensor._compute_tile_layout[(product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)]), simd_width_of[InType]()]()[Int(0)], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape[Int(2)]) // (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(2)]]())) * (product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)])), simd_width_of[InType]()), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape[Int(2)]) // (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(2)]]())) * (product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)])), simd_width_of[InType]()), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape[Int(2)]) // (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(2)]]())) * (product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)])), simd_width_of[InType]()), (product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)]), simd_width_of[InType]()]()]
Get the A register tile for a specific K tile.
Returns:
A (num_m_mmas, simd_width) sub-tile of ARegTile in the local address space.
b_reg_tile
def b_reg_tile(self, k_tile_idx: Int) -> LayoutTensor[InType, LayoutTensor._compute_tile_layout[(product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)]), simd_width_of[InType]()]()[Int(0)], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape[Int(2)]) // (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(1)], mma_shape[Int(2)]]())) * (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)])), simd_width_of[InType]()), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape[Int(2)]) // (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(1)], mma_shape[Int(2)]]())) * (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)])), simd_width_of[InType]()), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape[Int(2)]) // (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(1)], mma_shape[Int(2)]]())) * (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)])), simd_width_of[InType]()), (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)]), simd_width_of[InType]()]()]
Get the B register tile for a specific K tile.
Returns:
A (num_n_mmas, simd_width) sub-tile of BRegTile in the local address space.
reset_accumulator
def reset_accumulator(self)
Reset the accumulator to zero for a new tile computation.
load_tile_fragment
def load_tile_fragment[k_tile_idx: Int](self, smem_tile_a: LayoutTensor[address_space=smem_tile_a.address_space, element_layout=smem_tile_a.element_layout, layout_int_type=smem_tile_a.layout_int_type, linear_idx_type=smem_tile_a.linear_idx_type, masked=smem_tile_a.masked, alignment=smem_tile_a.alignment], smem_tile_b: LayoutTensor[address_space=smem_tile_b.address_space, element_layout=smem_tile_b.element_layout, layout_int_type=smem_tile_b.layout_int_type, linear_idx_type=smem_tile_b.linear_idx_type, masked=smem_tile_b.masked, alignment=smem_tile_b.alignment])
Load fragments from shared memory to registers for a specific K tile.
Parameters:
- k_tile_idx (Int): K-tile index (0 to total_k_tiles - 1).
Args:
- smem_tile_a (LayoutTensor[address_space=smem_tile_a.address_space, element_layout=smem_tile_a.element_layout, layout_int_type=smem_tile_a.layout_int_type, linear_idx_type=smem_tile_a.linear_idx_type, masked=smem_tile_a.masked, alignment=smem_tile_a.alignment]): Shared memory tile for matrix A.
- smem_tile_b (LayoutTensor[address_space=smem_tile_b.address_space, element_layout=smem_tile_b.element_layout, layout_int_type=smem_tile_b.layout_int_type, linear_idx_type=smem_tile_b.linear_idx_type, masked=smem_tile_b.masked, alignment=smem_tile_b.alignment]): Shared memory tile for matrix B.
mma_compute
def mma_compute[k_tile_idx: Int](self)
Perform matrix multiply-accumulate for a specific K tile.
This method assumes fragments are already loaded via load_tile_fragment.
Parameters:
- k_tile_idx (Int): K-tile index (0 to total_k_tiles - 1).
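A typical call sequence would presumably be: reset_accumulator() once, then load_tile_fragment[k] followed by mma_compute[k] for each k in range(total_k_tiles). This is an assumption based on the method descriptions, not a documented contract. The Python reference below models only the arithmetic result of that sequence for one warp tile. It does not model the register layout, swizzling, or the AMD MMA instructions. It assumes transpose_b=True, with B stored as wn x wk, which is consistent with the requirement that warp_block_layout_b.shape[0] be a multiple of mma_shape[1]. It also assumes that K tile k covers columns k * K to (k + 1) * K - 1. Nested lists stand in for the shared-memory tiles.

```python
def warp_tile_matmul(a, b_t, mma_k):
    """Reference for C = A @ B^T, accumulated one K tile at a time.

    a:   wm x wk nested list (stand-in for the shared-memory A tile).
    b_t: wn x wk nested list (B stored transposed, as with transpose_b=True).
    """
    wm, wk, wn = len(a), len(a[0]), len(b_t)
    assert wk % mma_k == 0, "WK must be a multiple of mma_shape[2]"
    c = [[0.0] * wn for _ in range(wm)]  # stands in for reset_accumulator()
    for k_tile_idx in range(wk // mma_k):  # total_k_tiles iterations
        k0 = k_tile_idx * mma_k
        # Stands in for load_tile_fragment[k_tile_idx]: take this K slice of A and B.
        a_k = [row[k0:k0 + mma_k] for row in a]
        b_k = [row[k0:k0 + mma_k] for row in b_t]
        # Stands in for mma_compute[k_tile_idx]: add the partial product into C.
        for i in range(wm):
            for j in range(wn):
                c[i][j] += sum(x * y for x, y in zip(a_k[i], b_k[j]))
    return c


a = [[1, 2, 3, 4], [5, 6, 7, 8]]    # wm = 2, wk = 4
b_t = [[1, 0, 1, 0], [0, 1, 0, 1]]  # wn = 2, wk = 4
print(warp_tile_matmul(a, b_t, mma_k=2))  # [[4.0, 6.0], [12.0, 14.0]]
```

The result does not depend on the K-tile size, because each K tile only contributes a partial sum to the same accumulator. That is why mma_compute can accumulate into out_reg_tile across calls rather than overwrite it.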