Mojo struct AmdTileOperator
@register_passable(trivial)
struct AmdTileOperator[InType: DType, OutType: DType, warp_block_layout_a: Layout, warp_block_layout_b: Layout, mma_shape: IndexList[3], swizzle: OptionalReg[Swizzle] = None, transpose_b: Bool = True]
Manages tensor core operations for matrix multiplication on AMD GPUs.
This operator handles loading matrix fragments from shared memory to registers and performing matrix multiply-accumulate operations using tensor cores.
Requirements:
- warp_block_layout_a.shape[0] must be divisible by mma_shape[0]
- warp_block_layout_b.shape[0] must be divisible by mma_shape[1]
- warp_block_layout_a.shape[1] must be divisible by mma_shape[2]
- warp_block_layout_b.shape[1] must be divisible by mma_shape[2]
- The K dimension must align such that num_k_tiles is divisible by k_group_size
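A minimal instantiation sketch that satisfies the divisibility requirements above. The MMA shape, tile extents, and dtypes are illustrative assumptions, not values prescribed by this API:

```mojo
# Hypothetical specialization; 32x32x16 MFMA shape and bfloat16 inputs are
# assumed for illustration only.
alias mma_shape = IndexList[3](32, 32, 16)
# M extent 64 is divisible by mma_shape[0] = 32; K extent 32 by mma_shape[2] = 16.
alias layout_a = Layout.row_major(64, 32)
# N extent 64 is divisible by mma_shape[1] = 32; K extent matches layout_a.
alias layout_b = Layout.row_major(64, 32)

alias Operator = AmdTileOperator[
    DType.bfloat16,  # InType
    DType.float32,   # OutType (accumulator precision)
    layout_a,
    layout_b,
    mma_shape,
]
```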
Parameters
- InType (DType): Input data type.
- OutType (DType): Output data type.
- warp_block_layout_a (Layout): Layout for matrix A warp tiles.
- warp_block_layout_b (Layout): Layout for matrix B warp tiles.
- mma_shape (IndexList): Shape of the MMA operation [M, N, K].
- swizzle (OptionalReg): Optional swizzle pattern for memory access.
- transpose_b (Bool): Whether matrix B is transposed.
Fields
- out_reg_tile (LayoutTensor[OutType, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._out_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._type_alignment]): The output register tile that accumulates MMA results.
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility
Aliases
__copyinit__is_trivial
alias __copyinit__is_trivial = True
__del__is_trivial
alias __del__is_trivial = True
__moveinit__is_trivial
alias __moveinit__is_trivial = True
ARegTileType
alias ARegTileType = LayoutTensor[InType, Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._k_tiles_per_simd_a * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]
BRegTileType
alias BRegTileType = LayoutTensor[InType, Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._k_tiles_per_simd_b * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]
k_group_size_a
alias k_group_size_a = (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._registers_per_thread_a)
k_group_size_b
alias k_group_size_b = (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._registers_per_thread_b)
k_tile_fragment_index
alias k_tile_fragment_index[k_tile_idx: Int] = (k_tile_idx % AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a)
Parameters
- k_tile_idx (Int): The flat K-tile index to decompose.
k_tile_group_index
alias k_tile_group_index[k_tile_idx: Int] = (k_tile_idx // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a)
Parameters
- k_tile_idx (Int): The flat K-tile index to decompose.
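The two index aliases above split a flat K-tile index into a group (which SIMD load the tile belongs to) and a fragment (its position within that load). A minimal sketch, assuming k_group_size_a = 4 as an illustrative value:

```mojo
fn demo_k_index_decomposition():
    # Assumed value: k_group_size_a = simd_width // registers_per_thread_a = 4.
    alias k_group_size = 4
    var k_tile_idx = 6
    var group = k_tile_idx // k_group_size     # 1: matches k_tile_group_index
    var fragment = k_tile_idx % k_group_size   # 2: matches k_tile_fragment_index
    # group * k_group_size + fragment recovers the flat index (6).
```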
num_k_tiles
alias num_k_tiles = (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].WK // mma_shape.__getitem__[3, int64, Int](2))
num_m_mmas
alias num_m_mmas = (product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, int64, Int](0))
num_n_mmas
alias num_n_mmas = (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, int64, Int](1))
out_frag_size
alias out_frag_size = ((mma_shape.__getitem__[3, int64, Int](0) * mma_shape.__getitem__[3, int64, Int](1)) // WARP_SIZE)
OutRegTileFragmentType
alias OutRegTileFragmentType = LayoutTensor[OutType, LayoutTensor._compute_tile_layout[True, OutType, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._out_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._out_layout, AddressSpace.LOCAL), _get_index_type(AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._out_layout, AddressSpace.LOCAL), False, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._type_alignment, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._out_frag_rows, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._out_frag_cols]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._out_layout, AddressSpace.LOCAL), linear_idx_type=_get_index_type(AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._out_layout, AddressSpace.LOCAL), masked=_tile_is_masked[AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._out_layout, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._out_frag_rows, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._out_frag_cols](), alignment=AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._type_alignment]
OutRegTileType
alias OutRegTileType = LayoutTensor[OutType, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._out_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._type_alignment]
simd_width
alias simd_width = simd_width_of[InType]()
tensor_core
alias tensor_core = TensorCore[OutType, InType, mma_shape, transpose_b]()
total_k_tiles
alias total_k_tiles = AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles
WK
alias WK = product(warp_block_layout_a.shape[1])
Methods
__init__
__init__() -> Self
a_reg_tile
a_reg_tile(self, k_tile_idx: Int) -> LayoutTensor[InType, LayoutTensor._compute_tile_layout[True, InType, Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._k_tiles_per_simd_a * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._k_tiles_per_simd_a * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), _get_index_type(Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._k_tiles_per_simd_a * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), False, align_of[InType](), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._k_tiles_per_simd_a * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, 
mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._k_tiles_per_simd_a * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._k_tiles_per_simd_a * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()]
Get A register tile for a specific K tile.
Returns:
The A register tile for the given K tile index.
b_reg_tile
b_reg_tile(self, k_tile_idx: Int) -> LayoutTensor[InType, LayoutTensor._compute_tile_layout[True, InType, Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._k_tiles_per_simd_b * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._k_tiles_per_simd_b * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), _get_index_type(Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._k_tiles_per_simd_b * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), False, align_of[InType](), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._k_tiles_per_simd_b * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, 
mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._k_tiles_per_simd_b * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b]._k_tiles_per_simd_b * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()]
Get B register tile for a specific K tile.
Returns:
The B register tile for the given K tile index.
reset_accumulator
reset_accumulator(self)
Reset the accumulator to zero for a new tile computation.
load_tile_fragment
load_tile_fragment[k_tile_idx: Int](self, smem_tile_a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], smem_tile_b: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])
Load fragments from shared memory to registers for a specific K tile.
Parameters:
- k_tile_idx (Int): K-tile index (0 to total_k_tiles-1).
Args:
- smem_tile_a (LayoutTensor): Shared memory tile for matrix A.
- smem_tile_b (LayoutTensor): Shared memory tile for matrix B.
mma_compute
mma_compute[k_tile_idx: Int](self)
Perform matrix multiply-accumulate for a specific K tile.
This method assumes fragments are already loaded via load_tile_fragment.
Parameters:
- k_tile_idx (Int): K-tile index (0 to total_k_tiles-1).
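Putting the methods together, the intended call sequence is: reset the accumulator once per output tile, then alternate fragment loads and MMA issues across the K tiles. A hedged usage sketch, assuming `Op` is some concrete AmdTileOperator specialization and `smem_a` / `smem_b` are shared-memory tiles that a preceding copy stage has already filled (not shown):

```mojo
var op = Op()
op.reset_accumulator()  # zero out_reg_tile for a new output tile

@parameter
for k_tile in range(Op.total_k_tiles):
    # Stage A/B fragments from shared memory into registers for this K tile...
    op.load_tile_fragment[k_tile](smem_a, smem_b)
    # ...then issue the tensor-core multiply-accumulate over those fragments.
    op.mma_compute[k_tile]()
# After the loop, op.out_reg_tile holds the accumulated output fragments.
```

Issuing the load for tile k+1 before computing tile k (software pipelining) is a common variant; the split between load_tile_fragment and mma_compute exists to make such overlap possible.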