Mojo struct

AmdTileOperator

struct AmdTileOperator[InType: DType, OutType: DType, warp_block_layout_a: Layout, warp_block_layout_b: Layout, mma_shape: IndexList[3], swizzle: Optional[Swizzle] = None, transpose_b: Bool = True]

Manages tensor core operations for matrix multiplication on AMD GPUs.

This operator handles loading matrix fragments from shared memory to registers and performing matrix multiply-accumulate operations using tensor cores.

Requirements:

  • warp_block_layout_a.shape[0] must be divisible by mma_shape[0]
  • warp_block_layout_b.shape[0] must be divisible by mma_shape[1]
  • warp_block_layout_a.shape[1] must be divisible by mma_shape[2]
  • warp_block_layout_b.shape[1] must be divisible by mma_shape[2]
  • The K dimension must align such that num_k_tiles is divisible by k_group_size
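The divisibility requirements above are plain integer constraints on the warp-tile and MMA shapes. A minimal Python sketch of the checks (the tuple shapes and the helper name are illustrative, not part of the Mojo API):

```python
def check_tile_requirements(wa_shape, wb_shape, mma_shape, k_group_size):
    """Validate AmdTileOperator-style shape requirements.

    wa_shape, wb_shape: (rows, K) shapes of the A and B warp tiles.
    mma_shape: (M, N, K) shape of one MMA instruction.
    """
    M, N, K = mma_shape
    assert wa_shape[0] % M == 0, "A warp-tile rows must divide by mma M"
    assert wb_shape[0] % N == 0, "B warp-tile rows must divide by mma N"
    assert wa_shape[1] % K == 0, "A warp-tile K must divide by mma K"
    assert wb_shape[1] % K == 0, "B warp-tile K must divide by mma K"
    num_k_tiles = wa_shape[1] // K
    # The K dimension must align so K tiles split evenly into groups.
    assert num_k_tiles % k_group_size == 0, "num_k_tiles must divide by k_group_size"
    return num_k_tiles
```

For example, 64x32 warp tiles with a 16x16x16 MMA shape and a K-group size of 2 yield two K tiles and satisfy every constraint.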

Parameters​

  • ​InType (DType): Input data type.
  • ​OutType (DType): Output data type.
  • ​warp_block_layout_a (Layout): Layout for matrix A warp tiles.
  • ​warp_block_layout_b (Layout): Layout for matrix B warp tiles.
  • ​mma_shape (IndexList[3]): Shape of the MMA operation [M, N, K].
  • ​swizzle (Optional[Swizzle]): Optional swizzle pattern for memory access.
  • ​transpose_b (Bool): Whether matrix B is transposed.

Fields​

  • out_reg_tile (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].OutRegTile): Register tile holding the accumulated MMA output.

Implemented traits​

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members​

ARegTile​

comptime ARegTile = LayoutTensor[InType, Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]

BRegTile​

comptime BRegTile = LayoutTensor[InType, Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]

k_group_size_a​

comptime k_group_size_a = (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width // num_matrix_reg[mma_shape[0], mma_shape[2]]())

k_group_size_b​

comptime k_group_size_b = (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width // num_matrix_reg[mma_shape[1], mma_shape[2]]())

k_tile_fragment_index​

comptime k_tile_fragment_index[k_tile_idx: Int] = (k_tile_idx % AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a)

Parameters​

  • k_tile_idx (Int): K-tile index (0 to total_k_tiles-1).

k_tile_group_index​

comptime k_tile_group_index[k_tile_idx: Int] = (k_tile_idx // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a)

Parameters​

  • k_tile_idx (Int): K-tile index (0 to total_k_tiles-1).
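Together, k_tile_group_index and k_tile_fragment_index decompose a linear K-tile index into a (group, fragment) pair: the group selects which batch of co-loaded fragments the tile belongs to, and the fragment selects the position within that batch. A minimal Python model of this arithmetic, with k_group_size standing in for k_group_size_a:

```python
def k_tile_group_index(k_tile_idx, k_group_size):
    # Which group of k_group_size consecutive K tiles this index falls in.
    return k_tile_idx // k_group_size

def k_tile_fragment_index(k_tile_idx, k_group_size):
    # Position of this K tile within its group.
    return k_tile_idx % k_group_size

# Round-trip: group * k_group_size + fragment recovers the linear index.
def linear_index(group, fragment, k_group_size):
    return group * k_group_size + fragment
```

With k_group_size = 2, tile index 5 lands in group 2 at fragment 1, and the round-trip recovers 5.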

num_k_tiles​

comptime num_k_tiles = (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].WK // mma_shape[2])

num_m_mmas​

comptime num_m_mmas = (product(warp_block_layout_a.shape[0]) // mma_shape[0])

num_n_mmas​

comptime num_n_mmas = (product(warp_block_layout_b.shape[0]) // mma_shape[1])

out_frag_size​

comptime out_frag_size = ((mma_shape[0] * mma_shape[1]) // WARP_SIZE)

OutRegTile​

comptime OutRegTile = LayoutTensor[OutType, Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape[0], mma_shape[1]]()), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=align_of[SIMD[InType, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]]()]

OutRegTileFragmentType​

comptime OutRegTileFragmentType = LayoutTensor[OutType, LayoutTensor._compute_tile_layout[(AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape[0], mma_shape[1]]()]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape[0], mma_shape[1]]()), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape[0], mma_shape[1]]()), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape[0], mma_shape[1]]()), (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape[0], mma_shape[1]]()](), alignment=align_of[SIMD[InType, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]]()]

simd_width​

comptime simd_width = simd_width_of[InType]()

tensor_core​

comptime tensor_core = TensorCore()

total_k_tiles​

comptime total_k_tiles = AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles

WK​

comptime WK = product(warp_block_layout_a.shape[1])
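The comptime members above are simple integer arithmetic over the warp-tile and MMA shapes. The sketch below reproduces it in Python, assuming num_matrix_reg[m, k]() evaluates to (m * k) // WARP_SIZE with WARP_SIZE = 64; that formula is an assumption made for illustration, not the library's definition:

```python
WARP_SIZE = 64  # AMD wavefront size, assumed for this sketch

def derived_counts(wa_shape, wb_shape, mma_shape, simd_width):
    """Mirror the comptime arithmetic of AmdTileOperator (illustrative only)."""
    M, N, K = mma_shape
    num_matrix_reg = (M * K) // WARP_SIZE  # assumed model of num_matrix_reg[M, K]()
    return {
        "num_m_mmas": wa_shape[0] // M,        # MMA ops along M per warp tile
        "num_n_mmas": wb_shape[0] // N,        # MMA ops along N per warp tile
        "WK": wa_shape[1],                     # K extent of the A warp tile
        "num_k_tiles": wa_shape[1] // K,       # K tiles per warp tile
        "out_frag_size": (M * N) // WARP_SIZE, # output elements per lane
        "k_group_size_a": simd_width // num_matrix_reg,
    }
```

For bf16 inputs (simd_width = 8) with 64x32 warp tiles and a 16x16x16 MMA shape, this gives 4 MMAs along M and N, 2 K tiles, a 4-element output fragment per lane, and a K-group size of 2.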

Methods​

__init__​

__init__() -> Self

a_reg_tile​

a_reg_tile(self, k_tile_idx: Int) -> LayoutTensor[InType, LayoutTensor._compute_tile_layout[AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, 
mma_shape, swizzle, transpose_b].simd_width), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()]

Get A register tile for a specific K tile.

Returns:

LayoutTensor[InType, LayoutTensor._compute_tile_layout[AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, 
transpose_b].simd_width), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()]

b_reg_tile​

b_reg_tile(self, k_tile_idx: Int) -> LayoutTensor[InType, LayoutTensor._compute_tile_layout[AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, 
mma_shape, swizzle, transpose_b].simd_width), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()]

Get B register tile for a specific K tile.

Returns:

LayoutTensor[InType, LayoutTensor._compute_tile_layout[AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, 
transpose_b].simd_width), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()]

reset_accumulator​

reset_accumulator(self)

Reset the accumulator to zero for a new tile computation.

load_tile_fragment​

load_tile_fragment[k_tile_idx: Int](self, smem_tile_a: LayoutTensor[address_space=smem_tile_a.address_space, element_layout=smem_tile_a.element_layout, layout_int_type=smem_tile_a.layout_int_type, linear_idx_type=smem_tile_a.linear_idx_type, masked=smem_tile_a.masked, alignment=smem_tile_a.alignment], smem_tile_b: LayoutTensor[address_space=smem_tile_b.address_space, element_layout=smem_tile_b.element_layout, layout_int_type=smem_tile_b.layout_int_type, linear_idx_type=smem_tile_b.linear_idx_type, masked=smem_tile_b.masked, alignment=smem_tile_b.alignment])

Load fragments from shared memory to registers for a specific K tile.

Parameters:

  • ​k_tile_idx (Int): K-tile index (0 to total_k_tiles-1).

Args:

  • smem_tile_a (LayoutTensor): Shared-memory tile holding the A operand fragments for this K tile.
  • smem_tile_b (LayoutTensor): Shared-memory tile holding the B operand fragments for this K tile.

mma_compute​

mma_compute[k_tile_idx: Int](self)

Perform matrix multiply-accumulate for a specific K tile.

This method assumes fragments are already loaded via load_tile_fragment.

Parameters:

  • ​k_tile_idx (Int): K-tile index (0 to total_k_tiles-1).
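A typical kernel drives the operator as reset_accumulator once, then load_tile_fragment followed by mma_compute for each K tile. The contract that loop relies on, that per-K-tile partial products accumulate to the full matrix product, can be modeled in plain Python (the helper names are illustrative, not the Mojo API):

```python
def matmul(a, b):
    """Reference full matrix product C = A @ B on lists of lists."""
    n, kk, m = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(kk)) for j in range(m)]
            for i in range(n)]

def tiled_accumulate(a, b, k_tile):
    """Accumulate per-K-tile partial products, as the MMA loop does."""
    K = len(b)
    assert K % k_tile == 0  # mirrors the K-divisibility requirement
    n, m = len(a), len(b[0])
    c = [[0] * m for _ in range(n)]  # reset_accumulator: zero the output tile
    for k0 in range(0, K, k_tile):   # one iteration per K tile
        # load_tile_fragment + mma_compute, modeled as a naive partial matmul
        for i in range(n):
            for j in range(m):
                c[i][j] += sum(a[i][p] * b[p][j] for p in range(k0, k0 + k_tile))
    return c
```

Because the K tiles partition the reduction dimension, the accumulated result matches the untiled product exactly; in the real operator the same decomposition maps each K tile onto one tensor-core MMA instruction.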