Mojo struct
MmaOpAMD
struct MmaOpAMD[out_type: DType, in_type: DType, shape: IndexList[3], transpose_b: Bool, k_group_size: Int, num_k_tiles: Int, num_m_mmas: Int, num_n_mmas: Int, out_frag_size: Int, swizzle: Swizzle]
Fields
- out_reg_tile (`MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].OutRegTile`)
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
alignment
comptime alignment = align_of[SIMD[in_type, MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width]]()
out_reg_layout
comptime out_reg_layout = Layout.row_major((num_m_mmas * num_n_mmas), out_frag_size)
OutRegTile
comptime OutRegTile = LayoutTensor[out_type, MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].out_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]
reg_tile_layout
comptime reg_tile_layout[num_mmas: Int] = Layout.row_major((num_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width)
Parameters
- num_mmas (`Int`)
RegTile
comptime RegTile[num_mmas: Int] = LayoutTensor[in_type, Layout.row_major((num_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]
Parameters
- num_mmas (`Int`)
simd_width
comptime simd_width = simd_width_of[in_type]()
tensor_core_mma
comptime tensor_core_mma = TiledTensorCore()
Methods
__init__
__init__(out self)
a_reg_tile
a_reg_tile(self, k_tile_idx: Int) -> LayoutTensor[in_type, LayoutTensor._compute_tile_layout[num_m_mmas, simd_width_of[in_type]()]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major((num_m_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major((num_m_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major((num_m_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), num_m_mmas, simd_width_of[in_type]()]()]
Returns: `LayoutTensor` (see the signature above for its full parameterization).
b_reg_tile
b_reg_tile(self, k_tile_idx: Int) -> LayoutTensor[in_type, LayoutTensor._compute_tile_layout[num_n_mmas, simd_width_of[in_type]()]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major((num_n_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major((num_n_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major((num_n_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), num_n_mmas, simd_width_of[in_type]()]()]
Returns: `LayoutTensor` (see the signature above for its full parameterization).
mma
mma[k_tile_idx: Int](self)
load_tile_fragment
load_tile_fragment[k_tile_idx: Int](self, a_smem_tiles: LayoutTensor[a_smem_tiles._dtype, LayoutTensor._compute_tile_layout[warp_rows, warp_cols]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(layout, AddressSpace.SHARED), linear_idx_type=_get_index_type(layout, AddressSpace.SHARED), masked=_tile_is_masked[layout, warp_rows, warp_cols]()], b_smem_tiles: LayoutTensor[b_smem_tiles._dtype, LayoutTensor._compute_tile_layout[warp_rows, warp_cols]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(layout, AddressSpace.SHARED), linear_idx_type=_get_index_type(layout, AddressSpace.SHARED), masked=_tile_is_masked[layout, warp_rows, warp_cols]()])
reset_accumulator
reset_accumulator(self)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!