
Mojo struct

AmdTileOperator

struct AmdTileOperator[InType: DType, OutType: DType, warp_block_layout_a: Layout, warp_block_layout_b: Layout, mma_shape: IndexList[Int(3)], swizzle: Optional[Swizzle] = None, transpose_b: Bool = True]

Manages tensor core operations for matrix multiplication on AMD GPUs.

This operator handles loading matrix fragments from shared memory to registers and performing matrix multiply-accumulate operations using tensor cores.

Requirements:

  • warp_block_layout_a.shape[0] must be divisible by mma_shape[0].
  • warp_block_layout_b.shape[0] must be divisible by mma_shape[1].
  • warp_block_layout_a.shape[1] must be divisible by mma_shape[2].
  • warp_block_layout_b.shape[1] must be divisible by mma_shape[2].
  • Along K, num_k_tiles must be divisible by the K group size (k_group_size_a for A and k_group_size_b for B).
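The divisibility rules can be checked ahead of time. The sketch below is a plain-Python model, not part of the Mojo API: it takes the flattened extent of each warp-block layout dimension (what `product(...)` yields) plus caller-supplied, hypothetical values for `k_group_size_a` and `k_group_size_b`, since those depend on `simd_width_of[InType]()` and `num_matrix_reg` for the target. It assumes the K-group rule applies to both operands.

```python
def check_tile_requirements(a_shape, b_shape, mma_shape, k_group_size_a, k_group_size_b):
    """Return the list of violated requirements (empty if all hold).

    a_shape: (M_warp, K_warp), flattened sizes of warp_block_layout_a.shape[0/1]
    b_shape: (N_warp, K_warp), flattened sizes of warp_block_layout_b.shape[0/1]
    mma_shape: (M, N, K) of a single MMA operation
    k_group_size_a/b: hypothetical K group sizes supplied by the caller
    """
    mma_m, mma_n, mma_k = mma_shape
    checks = [
        (a_shape[0] % mma_m == 0, "warp_block_layout_a.shape[0] % mma_shape[0]"),
        (b_shape[0] % mma_n == 0, "warp_block_layout_b.shape[0] % mma_shape[1]"),
        (a_shape[1] % mma_k == 0, "warp_block_layout_a.shape[1] % mma_shape[2]"),
        (b_shape[1] % mma_k == 0, "warp_block_layout_b.shape[1] % mma_shape[2]"),
    ]
    # num_k_tiles is derived from A's K extent, mirroring the comptime member.
    num_k_tiles = a_shape[1] // mma_k
    checks.append((num_k_tiles % k_group_size_a == 0, "num_k_tiles % k_group_size_a"))
    checks.append((num_k_tiles % k_group_size_b == 0, "num_k_tiles % k_group_size_b"))
    # Report only the checks that failed.
    return [name for ok, name in checks if not ok]


# 64x64 warp tiles, 16x16x16 MMA, assumed K group size of 2 for both operands.
print(check_tile_requirements((64, 64), (64, 64), (16, 16, 16), 2, 2))  # []
# K = 48 gives 3 K tiles, which cannot be split into groups of 2.
print(check_tile_requirements((64, 48), (64, 48), (16, 16, 16), 2, 2))
```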

Parameters

  • InType (DType): Input data type.
  • OutType (DType): Output data type.
  • warp_block_layout_a (Layout): Layout for matrix A warp tiles.
  • warp_block_layout_b (Layout): Layout for matrix B warp tiles.
  • mma_shape (IndexList[Int(3)]): Shape of the MMA operation [M, N, K].
  • swizzle (Optional[Swizzle]): Optional swizzle pattern for memory access.
  • transpose_b (Bool): Whether matrix B is transposed.

Fields

  • out_reg_tile (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].OutRegTile): Register tile that holds the MMA accumulators; reset_accumulator clears it to zero.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

ARegTile

comptime ARegTile = LayoutTensor[InType, Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape[Int(2)]) // (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(2)]]())) * (product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)])), simd_width_of[InType]()), MutAnyOrigin, address_space=AddressSpace.LOCAL]

BRegTile

comptime BRegTile = LayoutTensor[InType, Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape[Int(2)]) // (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(1)], mma_shape[Int(2)]]())) * (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)])), simd_width_of[InType]()), MutAnyOrigin, address_space=AddressSpace.LOCAL]
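Both register tiles are row-major with `simd_width` columns and `(num_k_tiles // k_group_size) * num_mmas` rows, which suggests that each SIMD-width row packs the fragments of `k_group_size` K tiles. The Python sketch below reproduces that arithmetic for one operand. The example values (`simd_width = 8`, `num_matrix_reg = 4`) are assumptions standing in for `simd_width_of[InType]()` and `num_matrix_reg[...]()`, which depend on the data type and target; the per-lane sanity check additionally assumes a 64-lane wavefront and that `num_matrix_reg[rows, K]` equals `rows * K // warp_size`.

```python
def reg_tile_shape(warp_rows, wk, mma_rows, mma_k, simd_width, num_matrix_reg):
    """Model of ARegTile/BRegTile: returns (rows, cols) of the row-major tile.

    warp_rows: product(warp_block_layout_x.shape[0]) (M for A, N for B)
    wk: product(warp_block_layout_a.shape[1]), i.e. WK
    mma_rows: mma_shape[0] for A, mma_shape[1] for B
    """
    num_k_tiles = wk // mma_k
    k_group_size = simd_width // num_matrix_reg  # K tiles sharing one row
    num_mmas = warp_rows // mma_rows
    return ((num_k_tiles // k_group_size) * num_mmas, simd_width)


# Assumed: 64x64 warp tile, 16x16x16 MMA, simd_width 8, 4 registers per fragment.
rows, cols = reg_tile_shape(64, 64, 16, 16, simd_width=8, num_matrix_reg=4)
print(rows, cols)  # 8 8
# Sanity check: each of 64 lanes holds an equal share of the 64x64 A warp tile.
assert rows * cols == (64 * 64) // 64
```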

k_group_size_a

comptime k_group_size_a = (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(2)]]())

k_group_size_b

comptime k_group_size_b = (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(1)], mma_shape[Int(2)]]())

k_tile_fragment_index

comptime k_tile_fragment_index[k_tile_idx: Int] = (k_tile_idx % (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(2)]]()))

Parameters

  • k_tile_idx (Int): K-tile index (0 to total_k_tiles-1).

k_tile_group_index

comptime k_tile_group_index[k_tile_idx: Int] = (k_tile_idx // (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(2)]]()))

Parameters

  • k_tile_idx (Int): K-tile index (0 to total_k_tiles-1).
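k_tile_group_index and k_tile_fragment_index are the quotient and remainder of a K-tile index divided by k_group_size_a, so the K tiles within one group share a group index and differ in fragment index. A Python sketch of that mapping, assuming num_k_tiles = 4 and a group size of 2 (illustrative values only):

```python
def k_tile_coords(k_tile_idx, k_group_size):
    """Return (k_tile_group_index, k_tile_fragment_index) for one K tile."""
    group_index, fragment_index = divmod(k_tile_idx, k_group_size)
    return group_index, fragment_index


# Assumed: num_k_tiles = 4 and k_group_size_a = 2.
print([k_tile_coords(k, 2) for k in range(4)])
# [(0, 0), (0, 1), (1, 0), (1, 1)]
```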

num_k_tiles

comptime num_k_tiles = (product(warp_block_layout_a.shape[1]) // mma_shape[Int(2)])

num_m_mmas

comptime num_m_mmas = (product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)])

num_n_mmas

comptime num_n_mmas = (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)])
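num_m_mmas, num_n_mmas, and num_k_tiles count MMA tiles along M, N, and K of the warp tile, so, assuming one MMA operation per (m, n, k) tile, a full pass over K performs num_m_mmas * num_n_mmas * num_k_tiles of them. A Python sketch with illustrative sizes (64x64 A and B warp tiles, WK = 64, 16x16x16 MMA shape; these are not defaults):

```python
def mma_counts(warp_m, warp_n, wk, mma_shape):
    """Return (num_m_mmas, num_n_mmas, num_k_tiles) for one warp tile."""
    mma_m, mma_n, mma_k = mma_shape
    return warp_m // mma_m, warp_n // mma_n, wk // mma_k


m_mmas, n_mmas, k_tiles = mma_counts(64, 64, 64, (16, 16, 16))
print(m_mmas, n_mmas, k_tiles)    # 4 4 4
print(m_mmas * n_mmas * k_tiles)  # 64 MMA operations per warp tile
```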

out_frag_size

comptime out_frag_size = (Int(mma_shape[Int(0)] * mma_shape[Int(1)]) // _resolve_warp_size())
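out_frag_size spreads the M x N output of one MMA evenly across the lanes of a warp. As a worked example, assuming a 64-lane wavefront (the usual wavefront size on AMD CDNA GPUs; the real value comes from _resolve_warp_size()):

```python
def out_frag_size(mma_m, mma_n, warp_size):
    """Output elements each lane holds for one M x N MMA tile."""
    return (mma_m * mma_n) // warp_size


print(out_frag_size(16, 16, 64))  # 4
print(out_frag_size(32, 32, 64))  # 16
```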

OutRegTile

comptime OutRegTile = LayoutTensor[OutType, Layout.row_major(Int((mul (product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)]), (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)]))), SIMD(SIMDSize(num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(1)]]()))), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=Int((get_alignof SIMD[InType, SIMDSize(simd_width_of[InType]())], _current_target()))]
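OutRegTile is row-major with one row per MMA output tile (num_m_mmas * num_n_mmas rows) and num_matrix_reg[mma_shape[0], mma_shape[1]]() columns. The Python sketch below assumes num_matrix_reg[M, N] equals M * N // warp_size, which makes the column count match out_frag_size; treat that formula as an assumption, not the library's definition.

```python
def out_reg_tile_shape(warp_m, warp_n, mma_m, mma_n, warp_size):
    """Model of OutRegTile's (rows, cols) under the stated assumption."""
    num_m_mmas = warp_m // mma_m
    num_n_mmas = warp_n // mma_n
    # Assumption: num_matrix_reg[mma_m, mma_n] == mma_m * mma_n // warp_size.
    num_matrix_reg = (mma_m * mma_n) // warp_size
    return num_m_mmas * num_n_mmas, num_matrix_reg


print(out_reg_tile_shape(64, 64, 16, 16, warp_size=64))  # (16, 4)
```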

OutRegTileFragmentType

comptime OutRegTileFragmentType = LayoutTensor[OutType, LayoutTensor._compute_tile_layout[((product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)]) * (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)])), SIMD(SIMDSize(num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(1)]]()))]()[Int(0)], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major(Int((mul (product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)]), (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)]))), SIMD(SIMDSize(num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(1)]]()))), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major(Int((mul (product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)]), (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)]))), SIMD(SIMDSize(num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(1)]]()))), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major(Int((mul (product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)]), (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)]))), SIMD(SIMDSize(num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(1)]]()))), ((product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)]) * (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)])), SIMD(SIMDSize(num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(1)]]()))](), alignment=Int((get_alignof SIMD[InType, SIMDSize(simd_width_of[InType]())], _current_target()))]

simd_width

comptime simd_width = simd_width_of[InType]()

tensor_core

comptime tensor_core = TensorCore()

total_k_tiles

comptime total_k_tiles = AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles

WK

comptime WK = product(warp_block_layout_a.shape[1])
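Layout shapes can be nested, which is why WK, num_m_mmas, and num_n_mmas apply product(...) to collapse a possibly hierarchical dimension into one extent. A small Python model of that flattening product (nested_product is a hypothetical helper, not the Mojo layout API):

```python
def nested_product(shape):
    """Multiply every integer in a possibly nested tuple of dimensions."""
    if isinstance(shape, int):
        return shape
    result = 1
    for dim in shape:
        result *= nested_product(dim)
    return result


# A K dimension laid out hierarchically as (4, 16) still spans 64 elements.
print(nested_product((4, 16)))       # 64
print(nested_product(((2, 2), 16)))  # 64
print(nested_product(64))            # 64
```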

Methods

__init__

def __init__() -> Self

a_reg_tile

def a_reg_tile(self, k_tile_idx: Int) -> LayoutTensor[InType, LayoutTensor._compute_tile_layout[(product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)]), simd_width_of[InType]()]()[Int(0)], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape[Int(2)]) // (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(2)]]())) * (product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)])), simd_width_of[InType]()), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape[Int(2)]) // (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(2)]]())) * (product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)])), simd_width_of[InType]()), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape[Int(2)]) // (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(2)]]())) * (product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)])), simd_width_of[InType]()), (product(warp_block_layout_a.shape[0]) // mma_shape[Int(0)]), simd_width_of[InType]()]()]

Get the A register tile for a specific K tile.

Returns:

A LayoutTensor view of shape (num_m_mmas, simd_width) into the A register tile (ARegTile), with element type InType in AddressSpace.LOCAL. The full parameterized type appears in the signature.

b_reg_tile

def b_reg_tile(self, k_tile_idx: Int) -> LayoutTensor[InType, LayoutTensor._compute_tile_layout[(product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)]), simd_width_of[InType]()]()[Int(0)], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape[Int(2)]) // (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(1)], mma_shape[Int(2)]]())) * (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)])), simd_width_of[InType]()), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape[Int(2)]) // (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(1)], mma_shape[Int(2)]]())) * (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)])), simd_width_of[InType]()), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape[Int(2)]) // (simd_width_of[InType]() // num_matrix_reg[mma_shape[Int(1)], mma_shape[Int(2)]]())) * (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)])), simd_width_of[InType]()), (product(warp_block_layout_b.shape[0]) // mma_shape[Int(1)]), simd_width_of[InType]()]()]

Get the B register tile for a specific K tile.

Returns:

A LayoutTensor view of shape (num_n_mmas, simd_width) into the B register tile (BRegTile), with element type InType in AddressSpace.LOCAL. The full parameterized type appears in the signature.

reset_accumulator

def reset_accumulator(self)

Reset the accumulator to zero for a new tile computation.

load_tile_fragment

def load_tile_fragment[k_tile_idx: Int](self, smem_tile_a: LayoutTensor[address_space=smem_tile_a.address_space, element_layout=smem_tile_a.element_layout, layout_int_type=smem_tile_a.layout_int_type, linear_idx_type=smem_tile_a.linear_idx_type, masked=smem_tile_a.masked, alignment=smem_tile_a.alignment], smem_tile_b: LayoutTensor[address_space=smem_tile_b.address_space, element_layout=smem_tile_b.element_layout, layout_int_type=smem_tile_b.layout_int_type, linear_idx_type=smem_tile_b.linear_idx_type, masked=smem_tile_b.masked, alignment=smem_tile_b.alignment])

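Load fragments from shared memory to registers for a specific K tile.

Parameters:

  • k_tile_idx (Int): K-tile index (0 to total_k_tiles-1).

Args:

  • smem_tile_a (LayoutTensor): Shared-memory tile of matrix A to load fragments from.
  • smem_tile_b (LayoutTensor): Shared-memory tile of matrix B to load fragments from.

mma_compute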

def mma_compute[k_tile_idx: Int](self)

Perform matrix multiply-accumulate for a specific K tile.

This method assumes fragments are already loaded via load_tile_fragment.

Parameters:

  • k_tile_idx (Int): K-tile index (0 to total_k_tiles-1).
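The method descriptions imply a call order: reset_accumulator, then, for each K tile, load_tile_fragment followed by mma_compute. The Python model below mimics that sequence on small plain lists to show why it yields the full product C = A * B^T when transpose_b is True (B stored as N x K). MockTileOperator and warp_tile_matmul are hypothetical: the model ignores swizzling and register packing, passes k_tile_idx as an ordinary argument rather than a compile-time parameter, and multiplies each K slice with scalar loops instead of tensor-core instructions.

```python
class MockTileOperator:
    """Hypothetical scalar model of the reset / load / compute call order."""

    def __init__(self, warp_m, warp_n, wk, mma_k):
        self.warp_m, self.warp_n, self.mma_k = warp_m, warp_n, mma_k
        self.total_k_tiles = wk // mma_k
        self.acc = []
        self.frag_a, self.frag_b = [], []

    def reset_accumulator(self):
        # Zero the warp_m x warp_n accumulator before a new output tile.
        self.acc = [[0] * self.warp_n for _ in range(self.warp_m)]

    def load_tile_fragment(self, k_tile_idx, smem_a, smem_b):
        # Copy this K tile's slice out of the "shared memory" tiles.
        k0 = k_tile_idx * self.mma_k
        self.frag_a = [row[k0:k0 + self.mma_k] for row in smem_a]  # M x mma_k
        self.frag_b = [row[k0:k0 + self.mma_k] for row in smem_b]  # N x mma_k

    def mma_compute(self, k_tile_idx):
        # Accumulate frag_a @ frag_b^T; k_tile_idx is unused in this model.
        for i, a_row in enumerate(self.frag_a):
            for j, b_row in enumerate(self.frag_b):
                self.acc[i][j] += sum(x * y for x, y in zip(a_row, b_row))


def warp_tile_matmul(smem_a, smem_b, mma_k):
    op = MockTileOperator(len(smem_a), len(smem_b), len(smem_a[0]), mma_k)
    op.reset_accumulator()
    for k in range(op.total_k_tiles):
        op.load_tile_fragment(k, smem_a, smem_b)
        op.mma_compute(k)
    return op.acc


a = [[1, 2, 3, 4], [5, 6, 7, 8]]                # M x K
b = [[1, 0, 1, 0], [0, 1, 0, 1], [1, 1, 1, 1]]  # N x K (transposed B)
print(warp_tile_matmul(a, b, mma_k=2))          # [[4, 6, 10], [12, 14, 26]]
```

Splitting K into tiles changes only the order of the additions, so any mma_k that divides K gives the same accumulator as a single full-K pass.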