Skip to main content
Log in

Mojo struct

TensorCore

struct TensorCore[out_type: DType, in_type: DType, shape: IndexList[3], transpose_b: Bool = False]

Layout reference: https://github.com/NVIDIA/cutlass/blob/main/include/cute/atom/mma_traits_sm80.hpp#L44

Aliases

  • supported_fp32 = shape.__eq__(IndexList(16, 8, 8)) if is_nvidia_gpu() else shape.__eq__(IndexList(16, 16, 4)) if in_type.__eq__(float32) else in_type.__eq__(float32)
  • supported_half = shape.__eq__(IndexList(16, 8, 16)) if is_nvidia_gpu() else shape.__eq__(IndexList(16, 16, 16)) if in_type.is_half_float() else in_type.is_half_float()
  • supported_fp8 = shape.__eq__(IndexList(16, 8, 32)) if Tuple(VariadicPack(<store_to_mem({:dtype f8e4m3}), store_to_mem({:dtype f8e5m2})>, True)).__contains__[::EqualityComparableCollectionElement](in_type) else Tuple(VariadicPack(<store_to_mem({:dtype f8e4m3}), store_to_mem({:dtype f8e5m2})>, True)).__contains__[::EqualityComparableCollectionElement](in_type)
  • a_reg_type = SIMD[in_type, num_matrix_reg[::Int,::Int]()]
  • a_reg_tile_type = LayoutTensor[in_type, col_major(1, num_matrix_reg[::Int,::Int]()), col_major(1, num_matrix_reg[::Int,::Int]()).rank(), address_space=5]
  • b_reg_type = SIMD[in_type, num_matrix_reg[::Int,::Int]()]
  • b_reg_tile_type = LayoutTensor[in_type, row_major(num_matrix_reg[::Int,::Int](), 1), row_major(num_matrix_reg[::Int,::Int](), 1).rank(), address_space=5]
  • c_reg_type = SIMD[out_type, num_matrix_reg[::Int,::Int]()]
  • c_reg_tile_type = LayoutTensor[out_type, col_major(1, num_matrix_reg[::Int,::Int]()), col_major(1, num_matrix_reg[::Int,::Int]()).rank(), address_space=5]

Implemented traits

AnyType, UnknownDestructibility

Methods

__init__

__init__(out self)

get_shapes

static get_shapes[out_type: DType, in_type: DType]() -> List[IndexList[3]]

load_a

load_a(self, a: LayoutTensor[dtype, layout, rank, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]) -> LayoutTensor[in_type, col_major(1, num_matrix_reg[::Int,::Int]()), col_major(1, num_matrix_reg[::Int,::Int]()).rank(), address_space=5]

load_a[swizzle: Bool = True, *, type0: DType, layout0: Layout, element_layout0: Layout, type1: DType, layout1: Layout, element_layout1: Layout](self, warp_tile: LayoutTensor[type0, layout0, layout0.rank(), address_space=3, element_layout=element_layout0, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], fragments: LayoutTensor[type1, layout1, layout1.rank(), address_space=5, element_layout=element_layout1, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], mma_tile_coord_k: UInt = 0)

load_b

load_b(self, b: LayoutTensor[dtype, layout, rank, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]) -> LayoutTensor[in_type, row_major(num_matrix_reg[::Int,::Int](), 1), row_major(num_matrix_reg[::Int,::Int](), 1).rank(), address_space=5]

load_b[*, type0: DType, layout0: Layout, element_layout0: Layout, layout1: Layout, element_layout1: Layout](self, warp_tile: LayoutTensor[type0, layout0, layout0.rank(), address_space=3, element_layout=element_layout0, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], fragments: LayoutTensor[type0, layout1, layout1.rank(), address_space=5, element_layout=element_layout1, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], mma_tile_coord_k: UInt = 0, warp_tile_coord_n: UInt = 0)

load_b[*, type_b: DType, type0: DType, type_scales: DType, layout0: Layout, element_layout0: Layout, layout1: Layout, element_layout1: Layout, layout2: Layout, element_layout2: Layout](self, warp_tile: LayoutTensor[type_b, layout0, layout0.rank(), address_space=3, element_layout=element_layout0, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], fragments: LayoutTensor[type0, layout1, layout1.rank(), address_space=5, element_layout=element_layout1, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], scales: LayoutTensor[type_scales, layout2, layout2.rank(), address_space=5, element_layout=element_layout2, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], mma_tile_coord_k: UInt = 0)

load_c

load_c(self, c: LayoutTensor[dtype, layout, rank, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]) -> LayoutTensor[out_type, col_major(1, num_matrix_reg[::Int,::Int]()), col_major(1, num_matrix_reg[::Int,::Int]()).rank(), address_space=5]

store_d

store_d(self, d_dst: LayoutTensor[dtype, layout, rank, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], d_src: LayoutTensor[dtype, layout, rank, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment])

mma_op

mma_op(self, a: LayoutTensor[dtype, layout, rank, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], b: LayoutTensor[dtype, layout, rank, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], c: LayoutTensor[dtype, layout, rank, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]) -> LayoutTensor[out_type, col_major(1, num_matrix_reg[::Int,::Int]()), col_major(1, num_matrix_reg[::Int,::Int]()).rank(), address_space=5]

mma

mma(self, a_frag: LayoutTensor[dtype, layout, rank, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], b_frag: LayoutTensor[dtype, layout, rank, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], c_frag: LayoutTensor[dtype, layout, rank, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment])