Skip to main content

Mojo struct

MmaOpAMD

struct MmaOpAMD[out_type: DType, in_type: DType, shape: IndexList[3], transpose_b: Bool, k_group_size: Int, num_k_tiles: Int, num_m_mmas: Int, num_n_mmas: Int, out_frag_size: Int, swizzle: Swizzle]

Fields

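  • out_reg_tile (MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].OutRegTile): Output register tile of `out_type` elements. It uses `out_reg_layout` (row-major, `num_m_mmas * num_n_mmas` rows by `out_frag_size` columns) in the `AddressSpace.LOCAL` address space.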

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

alignment

comptime alignment = align_of[SIMD[in_type, MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width]]()

out_reg_layout

comptime out_reg_layout = Layout.row_major((num_m_mmas * num_n_mmas), out_frag_size)

OutRegTile

comptime OutRegTile = LayoutTensor[out_type, MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].out_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]

reg_tile_layout

comptime reg_tile_layout[num_mmas: Int] = Layout.row_major((num_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width)

Parameters

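  • num_mmas (Int): Number of MMA operations along one operand dimension. For example, `num_m_mmas` is used for the A tile and `num_n_mmas` for the B tile. The resulting row-major layout has `num_mmas * num_k_tiles` rows of `simd_width` elements.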

RegTile

comptime RegTile[num_mmas: Int] = LayoutTensor[in_type, Layout.row_major((num_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]

Parameters

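  • num_mmas (Int): Number of MMA operations along one operand dimension. The resulting `in_type` tensor in `AddressSpace.LOCAL` is row-major, with `num_mmas * num_k_tiles` rows of `simd_width` elements.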

simd_width

comptime simd_width = simd_width_of[in_type]()

tensor_core_mma

comptime tensor_core_mma = TiledTensorCore()

Methods

__init__

__init__(out self)

a_reg_tile

a_reg_tile(self, k_tile_idx: Int) -> LayoutTensor[in_type, LayoutTensor._compute_tile_layout[num_m_mmas, simd_width_of[in_type]()]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major((num_m_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major((num_m_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major((num_m_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), num_m_mmas, simd_width_of[in_type]()]()]

Returns:

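LayoutTensor: A `num_m_mmas` × `simd_width` tile of `in_type` elements in `AddressSpace.LOCAL`. It is taken from the A-operand register tile (`num_m_mmas * num_k_tiles` rows). Presumably this is the tile selected by `k_tile_idx`; confirm against the implementation.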

b_reg_tile

b_reg_tile(self, k_tile_idx: Int) -> LayoutTensor[in_type, LayoutTensor._compute_tile_layout[num_n_mmas, simd_width_of[in_type]()]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major((num_n_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major((num_n_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major((num_n_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), num_n_mmas, simd_width_of[in_type]()]()]

Returns:

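LayoutTensor: A `num_n_mmas` × `simd_width` tile of `in_type` elements in `AddressSpace.LOCAL`. It is taken from the B-operand register tile (`num_n_mmas * num_k_tiles` rows). Presumably this is the tile selected by `k_tile_idx`; confirm against the implementation.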

mma

mma[k_tile_idx: Int](self)

load_tile_fragment

load_tile_fragment[k_tile_idx: Int](self, a_smem_tiles: LayoutTensor[a_smem_tiles._dtype, LayoutTensor._compute_tile_layout[warp_rows, warp_cols]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(layout, AddressSpace.SHARED), linear_idx_type=_get_index_type(layout, AddressSpace.SHARED), masked=_tile_is_masked[layout, warp_rows, warp_cols]()], b_smem_tiles: LayoutTensor[b_smem_tiles._dtype, LayoutTensor._compute_tile_layout[warp_rows, warp_cols]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(layout, AddressSpace.SHARED), linear_idx_type=_get_index_type(layout, AddressSpace.SHARED), masked=_tile_is_masked[layout, warp_rows, warp_cols]()])

reset_accumulator

reset_accumulator(self)

Was this page helpful?