Mojo struct
MmaOpAMD
struct MmaOpAMD[out_type: DType, in_type: DType, shape: IndexList[3], transpose_b: Bool, k_group_size: Int, num_k_tiles: Int, num_m_mmas: Int, num_n_mmas: Int, BK: Int, WK: Int]
Fields
- a_reg_tile (
StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), 0 if (num_k_tiles == 0) else ((div_s Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) + -1) if ((((rem_s Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) == 0) ^ True) & ((Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() < 0) ^ (num_k_tiles < 0))) else (div_s Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles]
):
- b_reg_tile (
StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), 0 if (num_k_tiles == 0) else ((div_s Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) + -1) if ((((rem_s Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) == 0) ^ True) & ((Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() < 0) ^ (num_k_tiles < 0))) else (div_s Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles]
):
- out_reg_tile (
LayoutTensor[out_type, Layout.row_major((num_m_mmas * num_n_mmas), 4), MutableAnyOrigin, address_space=AddressSpace(5)]
):
Implemented traits
AnyType, UnknownDestructibility
Aliases
__del__is_trivial
alias __del__is_trivial = OutRegTileType.__del__is_trivial if (RegTileKType[num_n_mmas].__del__is_trivial if RegTileKType[num_m_mmas].__del__is_trivial else RegTileKType[num_m_mmas].__del__is_trivial) else (RegTileKType[num_n_mmas].__del__is_trivial if RegTileKType[num_m_mmas].__del__is_trivial else RegTileKType[num_m_mmas].__del__is_trivial)
Here RegTileKType[num_m_mmas], RegTileKType[num_n_mmas], and OutRegTileType are the types of the a_reg_tile, b_reg_tile, and out_reg_tile fields, respectively (see the RegTileKType and OutRegTileType aliases). The expression is equivalent to RegTileKType[num_m_mmas].__del__is_trivial and RegTileKType[num_n_mmas].__del__is_trivial and OutRegTileType.__del__is_trivial. It is True only when all three field types have trivial destructors.
alignment
alias alignment = align_of[SIMD[in_type, simd_width_of[in_type]()]]()
OutRegTileType
alias OutRegTileType = LayoutTensor[out_type, Layout.row_major((num_m_mmas * num_n_mmas), 4), MutableAnyOrigin, address_space=AddressSpace(5)]
RegTileKType
alias RegTileKType[num_mmas: Int] = StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), 0 if (num_k_tiles == 0) else ((div_s Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) + -1) if ((((rem_s Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) == 0) ^ True) & ((Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() < 0) ^ (num_k_tiles < 0))) else (div_s Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles]
Parameters
- num_mmas (Int):
RegTileType
alias RegTileType[num_mmas: Int] = LayoutTensor[in_type, Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()]
Parameters
- num_mmas (Int):
simd_width
alias simd_width = simd_width_of[in_type]()
SMemTileType
alias SMemTileType[smem_layout: Layout] = LayoutTensor[in_type, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()]
Parameters
- smem_layout (Layout):
SMemWarpTileType
alias SMemWarpTileType[warp_rows: Int, smem_layout: Layout] = LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, smem_layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(smem_layout, AddressSpace(3)), _get_index_type(smem_layout, AddressSpace(3)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), warp_rows, WK]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(smem_layout, AddressSpace(3)), linear_idx_type=_get_index_type(smem_layout, AddressSpace(3)), masked=_tile_is_masked[smem_layout, warp_rows, WK](), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()]
Parameters
- warp_rows (Int):
- smem_layout (Layout):
swizzle
alias swizzle = Swizzle(3, 0, 1)
tensor_core_mma
alias tensor_core_mma = TensorCoreKGroup[out_type, in_type, shape, k_group_size, transpose_b]()
Methods
__init__
__init__(out self)
get_smem_layout
mma
mma[k_tile_idx: Int](self)
load_tile_from_smem
load_tile_from_smem[k_tile_idx: Int](self, a_smem_tiles: LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, smem_layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(smem_layout, AddressSpace(3)), _get_index_type(smem_layout, AddressSpace(3)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), warp_rows, WK]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(smem_layout, AddressSpace(3)), linear_idx_type=_get_index_type(smem_layout, AddressSpace(3)), masked=_tile_is_masked[smem_layout, warp_rows, WK](), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], b_smem_tiles: LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, smem_layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(smem_layout, AddressSpace(3)), _get_index_type(smem_layout, AddressSpace(3)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), warp_rows, WK]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(smem_layout, AddressSpace(3)), linear_idx_type=_get_index_type(smem_layout, AddressSpace(3)), masked=_tile_is_masked[smem_layout, warp_rows, WK](), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()])
reset_accumulator
reset_accumulator(self)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!