Mojo struct
AmdTileOperator
struct AmdTileOperator[InType: DType, OutType: DType, warp_block_layout_a: Layout, warp_block_layout_b: Layout, mma_shape: IndexList[3], swizzle: Optional[Swizzle] = None, transpose_b: Bool = True]
Manages tensor core operations for matrix multiplication on AMD GPUs.
This operator handles loading matrix fragments from shared memory to registers and performing matrix multiply-accumulate operations using tensor cores.
Requirements:
- warp_block_layout_a.shape[0] must be divisible by mma_shape[0]
- warp_block_layout_b.shape[0] must be divisible by mma_shape[1]
- warp_block_layout_a.shape[1] must be divisible by mma_shape[2]
- warp_block_layout_b.shape[1] must be divisible by mma_shape[2]
- The K dimension must align such that num_k_tiles is divisible by k_group_size
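The requirements above guarantee that the warp block decomposes into a whole number of MMA tiles. A minimal sketch of that arithmetic, using assumed example values (a 64x64 warp block for each operand and a 16x16x16 MMA shape; none of these constants come from the source):

```python
# Illustrative check of the divisibility requirements with assumed values.
M_a, K_a = 64, 64          # warp_block_layout_a.shape (rows, cols)
N_b, K_b = 64, 64          # warp_block_layout_b.shape (rows, cols)
mma_shape = (16, 16, 16)   # [M, N, K]

assert M_a % mma_shape[0] == 0  # A rows divisible by MMA M
assert N_b % mma_shape[1] == 0  # B rows divisible by MMA N
assert K_a % mma_shape[2] == 0  # A cols divisible by MMA K
assert K_b % mma_shape[2] == 0  # B cols divisible by MMA K

# Derived tile counts, mirroring the comptime members below.
num_m_mmas = M_a // mma_shape[0]   # MMA tiles along M
num_n_mmas = N_b // mma_shape[1]   # MMA tiles along N
num_k_tiles = K_a // mma_shape[2]  # WK // mma_shape[2]
print(num_m_mmas, num_n_mmas, num_k_tiles)
```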
Parameters
- InType (DType): Input data type.
- OutType (DType): Output data type.
- warp_block_layout_a (Layout): Layout for matrix A warp tiles.
- warp_block_layout_b (Layout): Layout for matrix B warp tiles.
- mma_shape (IndexList[3]): Shape of the MMA operation [M, N, K].
- swizzle (Optional[Swizzle]): Optional swizzle pattern for memory access.
- transpose_b (Bool): Whether matrix B is transposed.
Fields
- out_reg_tile (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].OutRegTile): The register tile that accumulates the MMA output.
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
ARegTile
comptime ARegTile = LayoutTensor[InType, Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]
BRegTile
comptime BRegTile = LayoutTensor[InType, Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]
k_group_size_a
comptime k_group_size_a = (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width // num_matrix_reg[mma_shape[0], mma_shape[2]]())
k_group_size_b
comptime k_group_size_b = (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width // num_matrix_reg[mma_shape[1], mma_shape[2]]())
k_tile_fragment_index
comptime k_tile_fragment_index[k_tile_idx: Int] = (k_tile_idx % AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a)
Parameters:
- k_tile_idx (Int): K-tile index (0 to total_k_tiles-1).
k_tile_group_index
comptime k_tile_group_index[k_tile_idx: Int] = (k_tile_idx // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a)
Parameters:
- k_tile_idx (Int): K-tile index (0 to total_k_tiles-1).
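Together, k_tile_group_index and k_tile_fragment_index split a linear k-tile index into a group and a position within that group. A sketch of the arithmetic, with an assumed k_group_size_a of 2 (an example value, not taken from the source):

```python
# Sketch of the k-tile index decomposition used by the comptime members above.
k_group_size_a = 2  # assumed example value

def k_tile_group_index(k_tile_idx: int) -> int:
    # Which group of k tiles this index falls into.
    return k_tile_idx // k_group_size_a

def k_tile_fragment_index(k_tile_idx: int) -> int:
    # Position of this k tile within its group.
    return k_tile_idx % k_group_size_a

# k tiles 0..3 map to (group, fragment) pairs:
pairs = [(k_tile_group_index(k), k_tile_fragment_index(k)) for k in range(4)]
print(pairs)  # [(0, 0), (0, 1), (1, 0), (1, 1)]
```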
num_k_tiles
comptime num_k_tiles = (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].WK // mma_shape[2])
num_m_mmas
comptime num_m_mmas = (product(warp_block_layout_a.shape[0]) // mma_shape[0])
num_n_mmas
comptime num_n_mmas = (product(warp_block_layout_b.shape[0]) // mma_shape[1])
out_frag_size
comptime out_frag_size = ((mma_shape[0] * mma_shape[1]) // WARP_SIZE)
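out_frag_size is the number of output elements each lane holds for one M x N MMA tile: the tile's M * N outputs are distributed evenly across the warp. A sketch assuming a 64-lane wavefront (the CDNA wavefront width) and a 16x16x16 MMA shape (example values, not from the source):

```python
# Each lane's share of one MMA output tile, mirroring the comptime expression.
WARP_SIZE = 64             # assumed AMD wavefront width
mma_shape = (16, 16, 16)   # assumed example [M, N, K]

out_frag_size = (mma_shape[0] * mma_shape[1]) // WARP_SIZE
print(out_frag_size)  # 4 output values per lane
```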
OutRegTile
comptime OutRegTile = LayoutTensor[OutType, Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape[0], mma_shape[1]]()), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=align_of[SIMD[InType, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]]()]
OutRegTileFragmentType
comptime OutRegTileFragmentType = LayoutTensor[OutType, LayoutTensor._compute_tile_layout[(AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape[0], mma_shape[1]]()]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape[0], mma_shape[1]]()), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape[0], mma_shape[1]]()), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape[0], mma_shape[1]]()), (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape[0], mma_shape[1]]()](), alignment=align_of[SIMD[InType, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]]()]
simd_width
comptime simd_width = simd_width_of[InType]()
tensor_core
comptime tensor_core = TensorCore()
total_k_tiles
comptime total_k_tiles = AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles
WK
comptime WK = product(warp_block_layout_a.shape[1])
Methods
__init__
__init__() -> Self
a_reg_tile
a_reg_tile(self, k_tile_idx: Int) -> LayoutTensor[InType, LayoutTensor._compute_tile_layout[AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, 
mma_shape, swizzle, transpose_b].simd_width), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()]
Get A register tile for a specific K tile.
Returns:
The A register tile fragment for the given k_tile_idx.
b_reg_tile
b_reg_tile(self, k_tile_idx: Int) -> LayoutTensor[InType, LayoutTensor._compute_tile_layout[AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, 
mma_shape, swizzle, transpose_b].simd_width), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()]
Get B register tile for a specific K tile.
Returns:
The B register tile fragment for the given k_tile_idx.
reset_accumulator
reset_accumulator(self)
Reset the accumulator to zero for a new tile computation.
load_tile_fragment
load_tile_fragment[k_tile_idx: Int](self, smem_tile_a: LayoutTensor[address_space=smem_tile_a.address_space, element_layout=smem_tile_a.element_layout, layout_int_type=smem_tile_a.layout_int_type, linear_idx_type=smem_tile_a.linear_idx_type, masked=smem_tile_a.masked, alignment=smem_tile_a.alignment], smem_tile_b: LayoutTensor[address_space=smem_tile_b.address_space, element_layout=smem_tile_b.element_layout, layout_int_type=smem_tile_b.layout_int_type, linear_idx_type=smem_tile_b.linear_idx_type, masked=smem_tile_b.masked, alignment=smem_tile_b.alignment])
Load fragments from shared memory to registers for a specific K tile.
Parameters:
- k_tile_idx (Int): K-tile index (0 to total_k_tiles-1).
Args:
- smem_tile_a (LayoutTensor[address_space=smem_tile_a.address_space, element_layout=smem_tile_a.element_layout, layout_int_type=smem_tile_a.layout_int_type, linear_idx_type=smem_tile_a.linear_idx_type, masked=smem_tile_a.masked, alignment=smem_tile_a.alignment]): Shared memory tile for matrix A.
- smem_tile_b (LayoutTensor[address_space=smem_tile_b.address_space, element_layout=smem_tile_b.element_layout, layout_int_type=smem_tile_b.layout_int_type, linear_idx_type=smem_tile_b.linear_idx_type, masked=smem_tile_b.masked, alignment=smem_tile_b.alignment]): Shared memory tile for matrix B.
mma_compute
mma_compute[k_tile_idx: Int](self)
Perform matrix multiply-accumulate for a specific K tile.
This method assumes fragments are already loaded via load_tile_fragment.
Parameters:
- k_tile_idx (Int): K-tile index (0 to total_k_tiles-1).
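The methods above compose into a simple per-tile loop: reset the accumulator once, then for each K tile load fragments and issue the MMA. A call-order sketch in Python, where the class is a stand-in that only records the sequence (the real operator is the Mojo struct documented here):

```python
# Hypothetical stand-in that records the intended call order of AmdTileOperator.
class MockTileOperator:
    def __init__(self, total_k_tiles: int) -> None:
        self.total_k_tiles = total_k_tiles
        self.calls: list[str] = []

    def reset_accumulator(self) -> None:
        self.calls.append("reset")          # zero out_reg_tile

    def load_tile_fragment(self, k_tile_idx: int) -> None:
        self.calls.append(f"load[{k_tile_idx}]")  # shared memory -> registers

    def mma_compute(self, k_tile_idx: int) -> None:
        self.calls.append(f"mma[{k_tile_idx}]")   # accumulate into out_reg_tile

op = MockTileOperator(total_k_tiles=4)
op.reset_accumulator()                 # once per output tile
for k in range(op.total_k_tiles):      # walk the K dimension
    op.load_tile_fragment(k)           # fragments must be loaded first
    op.mma_compute(k)                  # then multiply-accumulate
print(op.calls[:3])  # ['reset', 'load[0]', 'mma[0]']
```

The ordering matters: mma_compute assumes load_tile_fragment has already populated the register tiles for that K tile, as noted in its description.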