Skip to main content

Mojo struct

MmaOpAMD

struct MmaOpAMD[out_type: DType, in_type: DType, shape: IndexList[3], transpose_b: Bool, k_group_size: Int, num_k_tiles: Int, num_m_mmas: Int, num_n_mmas: Int, BK: Int, WK: Int]

Fields

  • a_reg_tile (StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), (Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() // num_k_tiles), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles]): This type is the expansion of the `RegTileFragType[num_m_mmas]` alias listed under Aliases, i.e. a `StaticTuple` of `num_k_tiles` `LayoutTensor`s with element type `in_type`.
  • b_reg_tile (StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), (Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() // num_k_tiles), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles]): This type is the expansion of the `RegTileFragType[num_n_mmas]` alias listed under Aliases, i.e. a `StaticTuple` of `num_k_tiles` `LayoutTensor`s with element type `in_type`.
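  • out_reg_tile (LayoutTensor[out_type, Layout.row_major((num_m_mmas * num_n_mmas), 4), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[out_type, simd_width_of[out_type]()]]()]): This type is the expansion of the `OutRegTileType` alias listed under Aliases: a `LayoutTensor` of `out_type` with a row-major layout of `num_m_mmas * num_n_mmas` rows by 4 columns.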

Implemented traits

AnyType, UnknownDestructibility

Aliases

__del__is_trivial

alias __del__is_trivial = LayoutTensor[out_type, Layout.row_major((num_m_mmas * num_n_mmas), 4), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[out_type, simd_width_of[out_type]()]]()].__del__is_trivial if StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), (Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() // num_k_tiles), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles].__del__is_trivial if StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), (Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() // num_k_tiles), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles].__del__is_trivial else StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), (Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() // num_k_tiles), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles].__del__is_trivial else StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), (Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() // num_k_tiles), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles].__del__is_trivial if StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), (Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() // num_k_tiles), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles].__del__is_trivial else StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), (Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() // num_k_tiles), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles].__del__is_trivial

alignment

alias alignment = align_of[SIMD[in_type, simd_width_of[in_type]()]]()

out_reg_layout

alias out_reg_layout = Layout.row_major((num_m_mmas * num_n_mmas), 4)

OutRegTileType

alias OutRegTileType = LayoutTensor[out_type, Layout.row_major((num_m_mmas * num_n_mmas), 4), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[out_type, simd_width_of[out_type]()]]()]

reg_tile_layout

alias reg_tile_layout[num_mmas: Int] = Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]())

Parameters

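  • num_mmas (Int): Multiplied by `num_k_tiles` to give the row count of the row-major layout; the column count is `simd_width_of[in_type]()`.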

RegTileFragType

alias RegTileFragType[num_mmas: Int] = StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), (Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() // num_k_tiles), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles]

Parameters

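  • num_mmas (Int): Multiplied by `num_k_tiles` to give the row count of the row-major layout from which each tuple element's tile layout is computed; the tile row size passed is that row count divided by `num_k_tiles`.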

RegTileType

alias RegTileType[num_mmas: Int] = LayoutTensor[in_type, Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()]

Parameters

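  • num_mmas (Int): Multiplied by `num_k_tiles` to give the row count of the tensor's row-major layout; the column count is `simd_width_of[in_type]()`.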

simd_width

alias simd_width = simd_width_of[in_type]()

swizzle

alias swizzle = Swizzle(3, 0, 1)

tensor_core_mma

alias tensor_core_mma = TensorCoreKGroup[out_type, in_type, shape, k_group_size, transpose_b]()

Methods

__init__

__init__(out self)

smem_tile_layout

static smem_tile_layout[block_rows: Int, k_tile_size: Int]() -> Layout

Returns:

Layout

mma

mma[k_tile_idx: Int](self)

load_tile_fragment

load_tile_fragment[k_tile_idx: Int](self, a_smem_tiles: LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(layout, AddressSpace(3)), _get_index_type(layout, AddressSpace(3)), False, align_of[SIMD[_dtype, simd_width_of[_dtype]()]](), warp_rows, warp_cols]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(layout, AddressSpace(3)), linear_idx_type=_get_index_type(layout, AddressSpace(3)), masked=_tile_is_masked[layout, warp_rows, warp_cols](), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()], b_smem_tiles: LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(layout, AddressSpace(3)), _get_index_type(layout, AddressSpace(3)), False, align_of[SIMD[_dtype, simd_width_of[_dtype]()]](), warp_rows, warp_cols]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(layout, AddressSpace(3)), linear_idx_type=_get_index_type(layout, AddressSpace(3)), masked=_tile_is_masked[layout, warp_rows, warp_cols](), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()])

reset_accumulator

reset_accumulator(self)

Was this page helpful?