Skip to main content

Mojo struct

MmaOpAMD

struct MmaOpAMD[out_type: DType, in_type: DType, shape: IndexList[3], transpose_b: Bool, k_group_size: Int, num_k_tiles: Int, num_m_mmas: Int, num_n_mmas: Int, BK: Int, WK: Int]

Fields

  • a_reg_tile (StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), 0 if (num_k_tiles == 0) else ((div_s Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) + -1) if ((((rem_s Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) == 0) ^ True) & ((Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() < 0) ^ (num_k_tiles < 0))) else (div_s Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles]):
  • b_reg_tile (StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), 0 if (num_k_tiles == 0) else ((div_s Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) + -1) if ((((rem_s Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) == 0) ^ True) & ((Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() < 0) ^ (num_k_tiles < 0))) else (div_s Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles]):
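  • out_reg_tile (LayoutTensor[out_type, Layout.row_major((num_m_mmas * num_n_mmas), 4), MutableAnyOrigin, address_space=AddressSpace(5)]): Output register tile of `out_type` values with layout `Layout.row_major(num_m_mmas * num_n_mmas, 4)`. Its type is identical to `OutRegTileType`.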

Implemented traits

AnyType, UnknownDestructibility

Aliases

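__del__is_trivial

True only when the destructors of all three field types (the types of `a_reg_tile`, `b_reg_tile`, and `out_reg_tile`) are trivial. The expansion below is an unparenthesized rendering of that conjunction: each `x and y` appears as `y if x else x`, so the whole expression reads `out if (b if a else a) else (b if a else a)`.
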
alias __del__is_trivial = LayoutTensor[out_type, Layout.row_major((num_m_mmas * num_n_mmas), 4), MutableAnyOrigin, address_space=AddressSpace(5)].__del__is_trivial if StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), 0 if (num_k_tiles == 0) else ((div_s Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) + -1) if ((((rem_s Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) == 0) ^ True) & ((Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() < 0) ^ (num_k_tiles < 0))) else (div_s Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles].__del__is_trivial if StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), 
_get_index_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), 0 if (num_k_tiles == 0) else ((div_s Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) + -1) if ((((rem_s Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) == 0) ^ True) & ((Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() < 0) ^ (num_k_tiles < 0))) else (div_s Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles].__del__is_trivial else StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), 0 if (num_k_tiles == 0) else ((div_s Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) + -1) if ((((rem_s Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) == 0) ^ 
True) & ((Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() < 0) ^ (num_k_tiles < 0))) else (div_s Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles].__del__is_trivial else StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), 0 if (num_k_tiles == 0) else ((div_s Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) + -1) if ((((rem_s Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) == 0) ^ True) & ((Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() < 0) ^ (num_k_tiles < 0))) else (div_s Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles].__del__is_trivial if StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, 
Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), 0 if (num_k_tiles == 0) else ((div_s Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) + -1) if ((((rem_s Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) == 0) ^ True) & ((Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() < 0) ^ (num_k_tiles < 0))) else (div_s Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles].__del__is_trivial else StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), 0 if (num_k_tiles == 0) else ((div_s Layout.row_major((num_k_tiles * num_m_mmas), 
simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) + -1) if ((((rem_s Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) == 0) ^ True) & ((Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() < 0) ^ (num_k_tiles < 0))) else (div_s Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles].__del__is_trivial

alignment

alias alignment = align_of[SIMD[in_type, simd_width_of[in_type]()]]()

OutRegTileType

alias OutRegTileType = LayoutTensor[out_type, Layout.row_major((num_m_mmas * num_n_mmas), 4), MutableAnyOrigin, address_space=AddressSpace(5)]

RegTileKType

alias RegTileKType[num_mmas: Int] = StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), 0 if (num_k_tiles == 0) else ((div_s Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) + -1) if ((((rem_s Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value) == 0) ^ True) & ((Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() < 0) ^ (num_k_tiles < 0))) else (div_s Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]()._mlir_value, 1 if (num_k_tiles == 0) else num_k_tiles._mlir_value), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles]

Parameters

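  • num_mmas (Int): Row count of each k-tile. The tuple holds `num_k_tiles` tiles taken from a `Layout.row_major(num_k_tiles * num_mmas, simd_width_of[in_type]())` tensor. `a_reg_tile` and `b_reg_tile` have this type with `num_mmas` equal to `num_m_mmas` and `num_n_mmas` respectively.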

RegTileType

alias RegTileType[num_mmas: Int] = LayoutTensor[in_type, Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()]

Parameters

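  • num_mmas (Int): Scales the row count. The tensor layout is `Layout.row_major(num_k_tiles * num_mmas, simd_width_of[in_type]())`.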

simd_width

alias simd_width = simd_width_of[in_type]()

SMemTileType

alias SMemTileType[smem_layout: Layout] = LayoutTensor[in_type, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()]

Parameters

  • smem_layout (Layout): Layout of the shared-memory tensor (`AddressSpace(3)`).

SMemWarpTileType

alias SMemWarpTileType[warp_rows: Int, smem_layout: Layout] = LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, smem_layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(smem_layout, AddressSpace(3)), _get_index_type(smem_layout, AddressSpace(3)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), warp_rows, WK]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(smem_layout, AddressSpace(3)), linear_idx_type=_get_index_type(smem_layout, AddressSpace(3)), masked=_tile_is_masked[smem_layout, warp_rows, WK](), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()]

Parameters

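  • warp_rows (Int): Row count of the tile taken from `smem_layout`. The tile shape is `warp_rows × WK`.
  • smem_layout (Layout): Layout of the shared-memory tensor (`AddressSpace(3)`) being tiled.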

swizzle

alias swizzle = Swizzle(3, 0, 1)

tensor_core_mma

alias tensor_core_mma = TensorCoreKGroup[out_type, in_type, shape, k_group_size, transpose_b]()

Methods

__init__

__init__(out self)

get_smem_layout

static get_smem_layout[block_rows: Int, k_tile_size: Int]() -> Layout

Returns:

Layout

mma

mma[k_tile_idx: Int](self)

load_tile_from_smem

load_tile_from_smem[k_tile_idx: Int](self, a_smem_tiles: LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, smem_layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(smem_layout, AddressSpace(3)), _get_index_type(smem_layout, AddressSpace(3)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), warp_rows, WK]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(smem_layout, AddressSpace(3)), linear_idx_type=_get_index_type(smem_layout, AddressSpace(3)), masked=_tile_is_masked[smem_layout, warp_rows, WK](), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], b_smem_tiles: LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, smem_layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(smem_layout, AddressSpace(3)), _get_index_type(smem_layout, AddressSpace(3)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), warp_rows, WK]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(smem_layout, AddressSpace(3)), linear_idx_type=_get_index_type(smem_layout, AddressSpace(3)), masked=_tile_is_masked[smem_layout, warp_rows, WK](), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()])

reset_accumulator

reset_accumulator(self)

Was this page helpful?