Mojo struct
AmdTileOperator
struct AmdTileOperator[InType: DType, OutType: DType, warp_block_layout_a: Layout, warp_block_layout_b: Layout, mma_shape: IndexList[3], transpose_b: Bool, swizzle: OptionalReg[Swizzle] = OptionalReg[Swizzle](None), simd_width: Int = 1, warps_being_processed: Int = 1]
Fields
- full_c_reg_tile (LayoutTensor[OutType, Layout.row_major(warps_being_processed, ((warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0)) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()])
- a_reg_tile (LayoutTensor[InType, Layout.row_major((((warp_block_layout_a.shape[1].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width // ((mma_shape.__getitem__[3, DType.int64, Int](0) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size()))) * (warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0))), simd_width), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()])
- b_reg_tile (LayoutTensor[InType, Layout.row_major((((warp_block_layout_a.shape[1].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width // ((mma_shape.__getitem__[3, DType.int64, Int](1) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size()))) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), simd_width), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()])
Implemented traits
AnyType, UnknownDestructibility
Aliases
__del__is_trivial
alias __del__is_trivial = LayoutTensor[InType, Layout.row_major((((warp_block_layout_a.shape[1].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width // ((mma_shape.__getitem__[3, DType.int64, Int](1) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size()))) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), simd_width), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()].__del__is_trivial if LayoutTensor[InType, Layout.row_major((((warp_block_layout_a.shape[1].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width // ((mma_shape.__getitem__[3, DType.int64, Int](0) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size()))) * (warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0))), simd_width), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()].__del__is_trivial if LayoutTensor[OutType, Layout.row_major(warps_being_processed, ((warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0)) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()].__del__is_trivial else LayoutTensor[OutType, Layout.row_major(warps_being_processed, ((warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0)) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), MutableAnyOrigin, 
address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()].__del__is_trivial else LayoutTensor[InType, Layout.row_major((((warp_block_layout_a.shape[1].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width // ((mma_shape.__getitem__[3, DType.int64, Int](0) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size()))) * (warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0))), simd_width), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()].__del__is_trivial if LayoutTensor[OutType, Layout.row_major(warps_being_processed, ((warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0)) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()].__del__is_trivial else LayoutTensor[OutType, Layout.row_major(warps_being_processed, ((warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0)) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()].__del__is_trivial
a_matrix_size
alias a_matrix_size = (mma_shape.__getitem__[3, DType.int64, Int](0) * mma_shape.__getitem__[3, DType.int64, Int](2))
b_matrix_size
alias b_matrix_size = (mma_shape.__getitem__[3, DType.int64, Int](1) * mma_shape.__getitem__[3, DType.int64, Int](2))
fragments_per_simd_a
alias fragments_per_simd_a = (simd_width // ((mma_shape.__getitem__[3, DType.int64, Int](0) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size()))
fragments_per_simd_b
alias fragments_per_simd_b = (simd_width // ((mma_shape.__getitem__[3, DType.int64, Int](1) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size()))
full_out_mma_fragment_layout
alias full_out_mma_fragment_layout = Layout.row_major(warps_being_processed, ((warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0)) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]())
FullOutMmaFragmentType
alias FullOutMmaFragmentType = LayoutTensor[OutType, Layout.row_major(warps_being_processed, ((warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0)) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()]
in_layout
alias in_layout[num_mmas: Int, k_tiles_per_simd: Int] = Layout.row_major((k_tiles_per_simd * num_mmas), simd_width)
Parameters
InMmaFragmentTypeA
alias InMmaFragmentTypeA = LayoutTensor[InType, Layout.row_major((((warp_block_layout_a.shape[1].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width // ((mma_shape.__getitem__[3, DType.int64, Int](0) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size()))) * (warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0))), simd_width), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()]
InMmaFragmentTypeB
alias InMmaFragmentTypeB = LayoutTensor[InType, Layout.row_major((((warp_block_layout_a.shape[1].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width // ((mma_shape.__getitem__[3, DType.int64, Int](1) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size()))) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), simd_width), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()]
k_tiles_per_simd_a
alias k_tiles_per_simd_a = ((warp_block_layout_a.shape[1].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width // ((mma_shape.__getitem__[3, DType.int64, Int](0) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size())))
k_tiles_per_simd_b
alias k_tiles_per_simd_b = ((warp_block_layout_a.shape[1].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width // ((mma_shape.__getitem__[3, DType.int64, Int](1) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size())))
num_k_tiles
alias num_k_tiles = (warp_block_layout_a.shape[1].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](2))
num_m_mmas
alias num_m_mmas = (warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0))
num_n_mmas
alias num_n_mmas = (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))
out_mma_fragment_layout
alias out_mma_fragment_layout = Layout.row_major(((warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0)) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]())
OutMmaFragmentType
alias OutMmaFragmentType = LayoutTensor[OutType, Layout.row_major(((warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0)) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()]
register_count_a
alias register_count_a = ((mma_shape.__getitem__[3, DType.int64, Int](0) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size())
register_count_b
alias register_count_b = ((mma_shape.__getitem__[3, DType.int64, Int](1) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size())
tensor_core_mma
alias tensor_core_mma = TensorCore[OutType, InType, mma_shape, transpose_b]()
type_alignment
alias type_alignment = align_of[SIMD[InType, simd_width]]()
WK
alias WK = warp_block_layout_a.shape[1].value[ComptimeOrigin]()
Methods
__init__
__init__(out self)
get_c_reg_tile_slice
get_c_reg_tile_slice(self, warp_idx: Int) -> LayoutTensor[OutType, Layout.row_major(((warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0)) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()]
Returns: The C register tile slice for the given warp_idx; the return type matches the OutMmaFragmentType alias.
mma
mma[swap_a_b: Bool = True](mut self, mut cache_manager: RingBuffer[pipeline_stages, a_tile_layout, b_tile_layout, TileTypeA, TileTypeB, WM, WN, WK, warps_per_block_m, warps_per_block_n], mut phase_a: Int, mut phase_b: Int, stage: Int, tile_idx_a: Int, tile_idx_b: Int, linear_warp_idx: Int)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!