Skip to main content
Log in

Mojo struct

TensorCore

struct TensorCore[out_type: DType, in_type: DType, shape: Index[3], transpose_b: Bool = False]

Layout reference (CUTLASS SM80 MMA traits): https://github.com/NVIDIA/cutlass/blob/main/include/cute/atom/mma_traits_sm80.hpp#L44

Aliases

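  • supported_fp32 = (shape == IndexList(16, 8, 8)) if is_nvidia_gpu() else (shape == IndexList(16, 16, 4)) if (in_type == float32) else (in_type == float32): True only when in_type is float32 and shape is the supported MMA shape, (16, 8, 8) on NVIDIA GPUs and (16, 16, 4) otherwise.
  • supported_half = (shape == IndexList(16, 8, 16)) if is_nvidia_gpu() else (shape == IndexList(16, 16, 16)) if in_type.is_half_float() else in_type.is_half_float(): True only when in_type is a half-precision float type and shape is the supported MMA shape, (16, 8, 16) on NVIDIA GPUs and (16, 16, 16) otherwise.
  • supported_fp8 = (shape == IndexList(16, 8, 32)) if Tuple(VariadicPack(float8_e4m3fn, float8_e5m2)).__contains__[::EqualityComparableCollectionElement](in_type) else Tuple(VariadicPack(float8_e4m3fn, float8_e5m2)).__contains__[::EqualityComparableCollectionElement](in_type): True only when in_type is float8_e4m3fn or float8_e5m2 and shape is (16, 8, 32).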
  • a_reg_type = SIMD[in_type, num_matrix_reg[::Int,::Int]()]
  • b_reg_type = SIMD[in_type, num_matrix_reg[::Int,::Int]()]
  • c_reg_type = SIMD[out_type, num_matrix_reg[::Int,::Int]()]
  • c_reg_tile_type = LayoutTensor[out_type, col_major(1, num_matrix_reg[::Int,::Int]()), address_space=AddressSpace(5)]

Implemented traits

AnyType, UnknownDestructibility

Methods

__init__

__init__(out self)

get_shapes

static get_shapes[out_type: DType, in_type: DType]() -> List[Index[3]]

load_a

load_a[swizzle: OptionalReg[Swizzle] = OptionalReg[Swizzle]({:i1 0, 1})](self, a: LayoutTensor[dtype, layout, mut=mut, origin=origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]) -> LayoutTensor[in_type, _get_a_reg_tile_layout[layout::layout::Layout,stdlib::utils::index::IndexList[::Int]](), address_space=AddressSpace(5)]

load_a[swizzle: Bool = is_nvidia_gpu(), *, type0: DType, layout0: Layout, element_layout0: Layout, type1: DType, layout1: Layout, element_layout1: Layout](self, warp_tile: LayoutTensor[type0, layout0, mut=mut, origin=origin, address_space=AddressSpace(3), element_layout=element_layout0, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], fragments: LayoutTensor[type1, layout1, mut=mut, origin=origin, address_space=AddressSpace(5), element_layout=element_layout1, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], mma_tile_coord_k: UInt = UInt(0))

load_b

load_b[swizzle: OptionalReg[Swizzle] = OptionalReg[Swizzle]({:i1 0, 1})](self, b: LayoutTensor[dtype, layout, mut=mut, origin=origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]) -> LayoutTensor[in_type, _get_b_reg_tile_layout[layout::layout::Layout,stdlib::utils::index::IndexList[::Int]](), address_space=AddressSpace(5)]

load_b[swizzle: OptionalReg[Swizzle] = OptionalReg[Swizzle]({:i1 0, 1}), *, type0: DType, layout0: Layout, element_layout0: Layout, layout1: Layout, element_layout1: Layout](self, warp_tile: LayoutTensor[type0, layout0, mut=mut, origin=origin, address_space=AddressSpace(3), element_layout=element_layout0, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], fragments: LayoutTensor[type0, layout1, mut=mut, origin=origin, address_space=AddressSpace(5), element_layout=element_layout1, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], mma_tile_coord_k: UInt = UInt(0), warp_tile_coord_n: UInt = UInt(0))

load_b[*, type_b: DType, type0: DType, type_scales: DType, layout0: Layout, element_layout0: Layout, layout1: Layout, element_layout1: Layout, layout2: Layout, element_layout2: Layout](self, warp_tile: LayoutTensor[type_b, layout0, mut=mut, origin=origin, address_space=AddressSpace(3), element_layout=element_layout0, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], fragments: LayoutTensor[type0, layout1, mut=mut, origin=origin, address_space=AddressSpace(5), element_layout=element_layout1, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], scales: LayoutTensor[type_scales, layout2, mut=mut, origin=origin, address_space=AddressSpace(5), element_layout=element_layout2, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], mma_tile_coord_k: UInt = UInt(0))

load_c

load_c(self, c: LayoutTensor[dtype, layout, mut=mut, origin=origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]) -> LayoutTensor[out_type, col_major(1, num_matrix_reg[::Int,::Int]()), address_space=AddressSpace(5)]

store_d

store_d(self, d_dst: LayoutTensor[dtype, layout, mut=mut, origin=origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], d_src: LayoutTensor[dtype, layout, mut=mut, origin=origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment])

mma_op

mma_op(self, a: LayoutTensor[dtype, layout, mut=mut, origin=origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], b: LayoutTensor[dtype, layout, mut=mut, origin=origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], c: LayoutTensor[dtype, layout, mut=mut, origin=origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]) -> LayoutTensor[out_type, col_major(1, num_matrix_reg[::Int,::Int]()), address_space=AddressSpace(5)]

mma

mma(self, a_frag: LayoutTensor[dtype, layout, mut=mut, origin=origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], b_frag: LayoutTensor[dtype, layout, mut=mut, origin=origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], c_frag: LayoutTensor[dtype, layout, mut=mut, origin=origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment])