Skip to main content

Mojo struct

AmdTileOperator

struct AmdTileOperator[InType: DType, OutType: DType, mma_shape: IndexList[3], transpose_b: Bool, //, mma_config: AnyStruct[MMAConfig[InType, OutType, mma_shape, transpose_b]], warp_block_layout_a: Layout, warp_block_layout_b: Layout, swizzle: OptionalReg[Swizzle] = None, tile_being_processed_per_warp: Int = 1]

Fields

  • full_c_reg_tile (LayoutTensor[OutType, pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width_of[InType]()]]()]): Same type as the `OutMmaFragmentType` alias: a tile of `OutType` elements laid out with `out_mma_fragment_layout`.
  • a_reg_tile (LayoutTensor[InType, Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width_of[InType]() // num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](2)]())) * (product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0))), simd_width_of[InType]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width_of[InType]()]]()]): Same type as the `InMmaFragmentTypeA` alias: a row-major tile of `InType` elements with `k_tiles_per_simd_a * num_m_mmas` rows and `simd_width_of[InType]()` columns.
  • b_reg_tile (LayoutTensor[InType, Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width_of[InType]() // num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](1), mma_shape.__getitem__[3, DType.int64, Int](2)]())) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), simd_width_of[InType]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width_of[InType]()]]()]): Same type as the `InMmaFragmentTypeB` alias: a row-major tile of `InType` elements with `k_tiles_per_simd_b * num_n_mmas` rows and `simd_width_of[InType]()` columns.

Implemented traits

AnyType, UnknownDestructibility

Aliases

__del__is_trivial

alias __del__is_trivial = True

in_layout

alias in_layout[num_mmas: Int, k_tiles_per_simd: Int] = Layout.row_major((k_tiles_per_simd * num_mmas), simd_width_of[InType]())

Parameters

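  • num_mmas (Int): Number of MMA tiles; multiplied by `k_tiles_per_simd` to give the layout's row count (`k_tiles_per_simd * num_mmas`).
  • k_tiles_per_simd (Int): Row multiplier applied to `num_mmas` (compare the `k_tiles_per_simd_a` / `k_tiles_per_simd_b` aliases). Each row holds `simd_width_of[InType]()` elements.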

InMmaFragmentTypeA

alias InMmaFragmentTypeA = LayoutTensor[InType, Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width_of[InType]() // num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](2)]())) * (product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0))), simd_width_of[InType]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width_of[InType]()]]()]

InMmaFragmentTypeB

alias InMmaFragmentTypeB = LayoutTensor[InType, Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width_of[InType]() // num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](1), mma_shape.__getitem__[3, DType.int64, Int](2)]())) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), simd_width_of[InType]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width_of[InType]()]]()]

k_tiles_per_simd_a

alias k_tiles_per_simd_a = ((product(warp_block_layout_a.shape[1]) // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width_of[InType]() // num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](2)]()))

k_tiles_per_simd_b

alias k_tiles_per_simd_b = ((product(warp_block_layout_a.shape[1]) // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width_of[InType]() // num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](1), mma_shape.__getitem__[3, DType.int64, Int](2)]()))

num_k_tiles

alias num_k_tiles = (product(warp_block_layout_a.shape[1]) // mma_shape.__getitem__[3, DType.int64, Int](2))

num_m_mmas

alias num_m_mmas = (product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0))

num_n_mmas

alias num_n_mmas = (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))

out_frag_cols

alias out_frag_cols = num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()

out_frag_rows

alias out_frag_rows = ((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1)))

out_mma_fragment_layout

alias out_mma_fragment_layout = pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp]()

OutMmaFragmentTileType

alias OutMmaFragmentTileType = LayoutTensor[OutType, LayoutTensor._compute_tile_layout[True, OutType, pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), MutableAnyOrigin, AddressSpace(5), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), AddressSpace(5)), _get_index_type(pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), AddressSpace(5)), False, align_of[SIMD[InType, simd_width_of[InType]()]](), ((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()]()[0], MutableAnyOrigin, address_space=AddressSpace(5), layout_int_type=_get_layout_type(pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), AddressSpace(5)), linear_idx_type=_get_index_type(pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), AddressSpace(5)), masked=_tile_is_masked[pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), ((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()](), alignment=align_of[SIMD[InType, simd_width_of[InType]()]]()]

OutMmaFragmentType

alias OutMmaFragmentType = LayoutTensor[OutType, pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width_of[InType]()]]()]

type_alignment

alias type_alignment = align_of[SIMD[InType, simd_width_of[InType]()]]()

WK

alias WK = product(warp_block_layout_a.shape[1])

Methods

__init__

__init__(out self)

get_c_reg_tile_slice

get_c_reg_tile_slice(self, tile_idx: Int) -> LayoutTensor[OutType, LayoutTensor._compute_tile_layout[True, OutType, pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), MutableAnyOrigin, AddressSpace(5), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), AddressSpace(5)), _get_index_type(pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), AddressSpace(5)), False, align_of[SIMD[InType, simd_width_of[InType]()]](), ((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()]()[0], MutableAnyOrigin, address_space=AddressSpace(5), layout_int_type=_get_layout_type(pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), AddressSpace(5)), linear_idx_type=_get_index_type(pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), AddressSpace(5)), masked=_tile_is_masked[pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), ((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()](), alignment=align_of[SIMD[InType, simd_width_of[InType]()]]()]

Returns:

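LayoutTensor: the same type as the `OutMmaFragmentTileType` alias, i.e. a tile of the `OutType` fragment layout (presumably the `tile_idx`-th tile of `full_c_reg_tile`; confirm against the implementation).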

mma

mma[swap_a_b: Bool = True](mut self, mut cache_manager: RingBuffer[SmemBufferTypeA, SmemBufferTypeB, consumer_warps], mut phase_a: Int, mut phase_b: Int, stage: Int, smem_warp_tile_idx_a: Int, smem_warp_tile_idx_b: Int, linear_warp_idx: Int, block_tile_num: Int)

Was this page helpful?