Mojo struct
AmdTileOperator
struct AmdTileOperator[InType: DType, OutType: DType, mma_shape: IndexList[3], transpose_b: Bool, //, mma_config: AnyStruct[MMAConfig[InType, OutType, mma_shape, transpose_b]], warp_block_layout_a: Layout, warp_block_layout_b: Layout, swizzle: OptionalReg[Swizzle] = None, tile_being_processed_per_warp: Int = 1]
Fields
- full_c_reg_tile (
LayoutTensor[OutType, pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width_of[InType]()]]()]
):
- a_reg_tile (
LayoutTensor[InType, Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width_of[InType]() // num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](2)]())) * (product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0))), simd_width_of[InType]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width_of[InType]()]]()]
):
- b_reg_tile (
LayoutTensor[InType, Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width_of[InType]() // num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](1), mma_shape.__getitem__[3, DType.int64, Int](2)]())) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), simd_width_of[InType]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width_of[InType]()]]()]
):
Implemented traits
AnyType, UnknownDestructibility
Aliases
__del__is_trivial
alias __del__is_trivial = True
in_layout
alias in_layout[num_mmas: Int, k_tiles_per_simd: Int] = Layout.row_major((k_tiles_per_simd * num_mmas), simd_width_of[InType]())
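Parameters
- num_mmas (`Int`): multiplied by `k_tiles_per_simd` to give the number of rows of the row-major layout.
- k_tiles_per_simd (`Int`): multiplied by `num_mmas` to give the number of rows; the number of columns is `simd_width_of[InType]()`.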
InMmaFragmentTypeA
alias InMmaFragmentTypeA = LayoutTensor[InType, Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width_of[InType]() // num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](2)]())) * (product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0))), simd_width_of[InType]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width_of[InType]()]]()]
InMmaFragmentTypeB
alias InMmaFragmentTypeB = LayoutTensor[InType, Layout.row_major((((product(warp_block_layout_a.shape[1]) // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width_of[InType]() // num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](1), mma_shape.__getitem__[3, DType.int64, Int](2)]())) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), simd_width_of[InType]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width_of[InType]()]]()]
k_tiles_per_simd_a
alias k_tiles_per_simd_a = ((product(warp_block_layout_a.shape[1]) // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width_of[InType]() // num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](2)]()))
k_tiles_per_simd_b
alias k_tiles_per_simd_b = ((product(warp_block_layout_a.shape[1]) // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width_of[InType]() // num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](1), mma_shape.__getitem__[3, DType.int64, Int](2)]()))
num_k_tiles
alias num_k_tiles = (product(warp_block_layout_a.shape[1]) // mma_shape.__getitem__[3, DType.int64, Int](2))
num_m_mmas
alias num_m_mmas = (product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0))
num_n_mmas
alias num_n_mmas = (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))
out_frag_cols
alias out_frag_cols = num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()
out_frag_rows
alias out_frag_rows = ((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1)))
out_mma_fragment_layout
alias out_mma_fragment_layout = pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp]()
OutMmaFragmentTileType
alias OutMmaFragmentTileType = LayoutTensor[OutType, LayoutTensor._compute_tile_layout[True, OutType, pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), MutableAnyOrigin, AddressSpace(5), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), AddressSpace(5)), _get_index_type(pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), AddressSpace(5)), False, align_of[SIMD[InType, simd_width_of[InType]()]](), ((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()]()[0], MutableAnyOrigin, address_space=AddressSpace(5), layout_int_type=_get_layout_type(pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), 
num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), AddressSpace(5)), linear_idx_type=_get_index_type(pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), AddressSpace(5)), masked=_tile_is_masked[pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), ((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()](), alignment=align_of[SIMD[InType, simd_width_of[InType]()]]()]
OutMmaFragmentType
alias OutMmaFragmentType = LayoutTensor[OutType, pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width_of[InType]()]]()]
type_alignment
alias type_alignment = align_of[SIMD[InType, simd_width_of[InType]()]]()
WK
alias WK = product(warp_block_layout_a.shape[1])
Methods
__init__
__init__(out self)
get_c_reg_tile_slice
get_c_reg_tile_slice(self, tile_idx: Int) -> LayoutTensor[OutType, LayoutTensor._compute_tile_layout[True, OutType, pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), MutableAnyOrigin, AddressSpace(5), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), AddressSpace(5)), _get_index_type(pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), AddressSpace(5)), False, align_of[SIMD[InType, simd_width_of[InType]()]](), ((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()]()[0], MutableAnyOrigin, address_space=AddressSpace(5), layout_int_type=_get_layout_type(pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), 
num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), AddressSpace(5)), linear_idx_type=_get_index_type(pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), AddressSpace(5)), masked=_tile_is_masked[pipeline_layout[Layout.row_major(((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), tile_being_processed_per_warp](), ((product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0)) * (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()](), alignment=align_of[SIMD[InType, simd_width_of[InType]()]]()]
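Returns:
A `LayoutTensor` of `OutType` in `AddressSpace(5)`; its type is the same as the `OutMmaFragmentTileType` alias.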
mma
mma[swap_a_b: Bool = True](mut self, mut cache_manager: RingBuffer[SmemBufferTypeA, SmemBufferTypeB, consumer_warps], mut phase_a: Int, mut phase_b: Int, stage: Int, smem_warp_tile_idx_a: Int, smem_warp_tile_idx_b: Int, linear_warp_idx: Int, block_tile_num: Int)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!