Skip to main content

Mojo struct

AmdTileOperator

struct AmdTileOperator[InType: DType, OutType: DType, warp_block_layout_a: Layout, warp_block_layout_b: Layout, mma_shape: IndexList[3], transpose_b: Bool, swizzle: OptionalReg[Swizzle] = OptionalReg[Swizzle](None), simd_width: Int = 1, warps_being_processed: Int = 1]

Fields

  • full_c_reg_tile (LayoutTensor[OutType, Layout.row_major(warps_being_processed, ((warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0)) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()]):
  • a_reg_tile (LayoutTensor[InType, Layout.row_major((((warp_block_layout_a.shape[1].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width // ((mma_shape.__getitem__[3, DType.int64, Int](0) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size()))) * (warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0))), simd_width), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()]):
  • b_reg_tile (LayoutTensor[InType, Layout.row_major((((warp_block_layout_a.shape[1].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width // ((mma_shape.__getitem__[3, DType.int64, Int](1) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size()))) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), simd_width), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()]):

Implemented traits

AnyType, UnknownDestructibility

Aliases

__del__is_trivial

alias __del__is_trivial = LayoutTensor[InType, Layout.row_major((((warp_block_layout_a.shape[1].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width // ((mma_shape.__getitem__[3, DType.int64, Int](1) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size()))) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), simd_width), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()].__del__is_trivial if LayoutTensor[InType, Layout.row_major((((warp_block_layout_a.shape[1].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width // ((mma_shape.__getitem__[3, DType.int64, Int](0) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size()))) * (warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0))), simd_width), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()].__del__is_trivial if LayoutTensor[OutType, Layout.row_major(warps_being_processed, ((warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0)) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()].__del__is_trivial else LayoutTensor[OutType, Layout.row_major(warps_being_processed, ((warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0)) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()].__del__is_trivial else LayoutTensor[InType, Layout.row_major((((warp_block_layout_a.shape[1].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width // ((mma_shape.__getitem__[3, DType.int64, Int](0) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size()))) * (warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0))), simd_width), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()].__del__is_trivial if LayoutTensor[OutType, Layout.row_major(warps_being_processed, ((warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0)) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()].__del__is_trivial else LayoutTensor[OutType, Layout.row_major(warps_being_processed, ((warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0)) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()].__del__is_trivial

a_matrix_size

alias a_matrix_size = (mma_shape.__getitem__[3, DType.int64, Int](0) * mma_shape.__getitem__[3, DType.int64, Int](2))

b_matrix_size

alias b_matrix_size = (mma_shape.__getitem__[3, DType.int64, Int](1) * mma_shape.__getitem__[3, DType.int64, Int](2))

fragments_per_simd_a

alias fragments_per_simd_a = (simd_width // ((mma_shape.__getitem__[3, DType.int64, Int](0) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size()))

fragments_per_simd_b

alias fragments_per_simd_b = (simd_width // ((mma_shape.__getitem__[3, DType.int64, Int](1) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size()))

full_out_mma_fragment_layout

alias full_out_mma_fragment_layout = Layout.row_major(warps_being_processed, ((warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0)) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]())

FullOutMmaFragmentType

alias FullOutMmaFragmentType = LayoutTensor[OutType, Layout.row_major(warps_being_processed, ((warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0)) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()]

in_layout

alias in_layout[num_mmas: Int, k_tiles_per_simd: Int] = Layout.row_major((k_tiles_per_simd * num_mmas), simd_width)

Parameters

  • num_mmas (Int):
  • k_tiles_per_simd (Int):

InMmaFragmentTypeA

alias InMmaFragmentTypeA = LayoutTensor[InType, Layout.row_major((((warp_block_layout_a.shape[1].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width // ((mma_shape.__getitem__[3, DType.int64, Int](0) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size()))) * (warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0))), simd_width), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()]

InMmaFragmentTypeB

alias InMmaFragmentTypeB = LayoutTensor[InType, Layout.row_major((((warp_block_layout_a.shape[1].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width // ((mma_shape.__getitem__[3, DType.int64, Int](1) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size()))) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), simd_width), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()]

k_tiles_per_simd_a

alias k_tiles_per_simd_a = ((warp_block_layout_a.shape[1].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width // ((mma_shape.__getitem__[3, DType.int64, Int](0) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size())))

k_tiles_per_simd_b

alias k_tiles_per_simd_b = ((warp_block_layout_a.shape[1].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](2)) // (simd_width // ((mma_shape.__getitem__[3, DType.int64, Int](1) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size())))

num_k_tiles

alias num_k_tiles = (warp_block_layout_a.shape[1].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](2))

num_m_mmas

alias num_m_mmas = (warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0))

num_n_mmas

alias num_n_mmas = (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))

out_mma_fragment_layout

alias out_mma_fragment_layout = Layout.row_major(((warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0)) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]())

OutMmaFragmentType

alias OutMmaFragmentType = LayoutTensor[OutType, Layout.row_major(((warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0)) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()]

register_count_a

alias register_count_a = ((mma_shape.__getitem__[3, DType.int64, Int](0) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size())

register_count_b

alias register_count_b = ((mma_shape.__getitem__[3, DType.int64, Int](1) * mma_shape.__getitem__[3, DType.int64, Int](2)) // _resolve_warp_size())

tensor_core_mma

alias tensor_core_mma = TensorCore[OutType, InType, mma_shape, transpose_b]()

type_alignment

alias type_alignment = align_of[SIMD[InType, simd_width]]()

WK

alias WK = warp_block_layout_a.shape[1].value[ComptimeOrigin]()

Methods

__init__

__init__(out self)

get_c_reg_tile_slice

get_c_reg_tile_slice(self, warp_idx: Int) -> LayoutTensor[OutType, Layout.row_major(((warp_block_layout_a.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](0)) * (warp_block_layout_b.shape[0].value[ComptimeOrigin]() // mma_shape.__getitem__[3, DType.int64, Int](1))), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[InType, simd_width]]()]

Returns:

LayoutTensor: the register tile for the warp selected by `warp_idx` — a single-warp view of `full_c_reg_tile` with shape `out_mma_fragment_layout` (i.e. the full tile without its leading `warps_being_processed` dimension).

mma

mma[swap_a_b: Bool = True](mut self, mut cache_manager: RingBuffer[pipeline_stages, a_tile_layout, b_tile_layout, TileTypeA, TileTypeB, WM, WN, WK, warps_per_block_m, warps_per_block_n], mut phase_a: Int, mut phase_b: Int, stage: Int, tile_idx_a: Int, tile_idx_b: Int, linear_warp_idx: Int)

Was this page helpful?