Mojo struct
TensorCore
struct TensorCore[out_type: DType, in_type: DType, shape: Index[3], transpose_b: Bool = False]
Layout reference: see the CUTLASS SM80 MMA traits at https://github.com/NVIDIA/cutlass/blob/main/include/cute/atom/mma_traits_sm80.hpp#L44
Aliases
supported_fp32 = (in_type == float32) and ((shape == IndexList(16, 8, 8)) if is_nvidia_gpu() else (shape == IndexList(16, 16, 4)))
supported_half = in_type.is_half_float() and ((shape == IndexList(16, 8, 16)) if is_nvidia_gpu() else (shape == IndexList(16, 16, 16)))
supported_fp8 = (in_type in (float8_e4m3fn, float8_e5m2)) and (shape == IndexList(16, 8, 32))
a_reg_type = SIMD[in_type, num_matrix_reg[::Int,::Int]()]
b_reg_type = SIMD[in_type, num_matrix_reg[::Int,::Int]()]
c_reg_type = SIMD[out_type, num_matrix_reg[::Int,::Int]()]
c_reg_tile_type = LayoutTensor[out_type, col_major(1, num_matrix_reg[::Int,::Int]()), address_space=AddressSpace(5)]
Implemented traits
AnyType, UnknownDestructibility
Methods
__init__
__init__(out self)
get_shapes
static get_shapes[out_type: DType, in_type: DType]() -> List[Index[3]]
load_a
load_a[swizzle: OptionalReg[Swizzle] = OptionalReg[Swizzle](None)](self, a: LayoutTensor[dtype, layout, mut=mut, origin=origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]) -> LayoutTensor[in_type, _get_a_reg_tile_layout[layout, shape](), address_space=AddressSpace(5)]
load_a[swizzle: Bool = is_nvidia_gpu(), *, type0: DType, layout0: Layout, element_layout0: Layout, type1: DType, layout1: Layout, element_layout1: Layout](self, warp_tile: LayoutTensor[type0, layout0, mut=mut, origin=origin, address_space=AddressSpace(3), element_layout=element_layout0, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], fragments: LayoutTensor[type1, layout1, mut=mut, origin=origin, address_space=AddressSpace(5), element_layout=element_layout1, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], mma_tile_coord_k: UInt = UInt(0))
load_b
load_b[swizzle: OptionalReg[Swizzle] = OptionalReg[Swizzle](None)](self, b: LayoutTensor[dtype, layout, mut=mut, origin=origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]) -> LayoutTensor[in_type, _get_b_reg_tile_layout[layout, shape, transpose_b](), address_space=AddressSpace(5)]
load_b[swizzle: OptionalReg[Swizzle] = OptionalReg[Swizzle](None), *, type0: DType, layout0: Layout, element_layout0: Layout, layout1: Layout, element_layout1: Layout](self, warp_tile: LayoutTensor[type0, layout0, mut=mut, origin=origin, address_space=AddressSpace(3), element_layout=element_layout0, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], fragments: LayoutTensor[type0, layout1, mut=mut, origin=origin, address_space=AddressSpace(5), element_layout=element_layout1, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], mma_tile_coord_k: UInt = UInt(0), warp_tile_coord_n: UInt = UInt(0))
load_b[*, type_b: DType, type0: DType, type_scales: DType, layout0: Layout, element_layout0: Layout, layout1: Layout, element_layout1: Layout, layout2: Layout, element_layout2: Layout](self, warp_tile: LayoutTensor[type_b, layout0, mut=mut, origin=origin, address_space=AddressSpace(3), element_layout=element_layout0, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], fragments: LayoutTensor[type0, layout1, mut=mut, origin=origin, address_space=AddressSpace(5), element_layout=element_layout1, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], scales: LayoutTensor[type_scales, layout2, mut=mut, origin=origin, address_space=AddressSpace(5), element_layout=element_layout2, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], mma_tile_coord_k: UInt = UInt(0))
load_c
load_c(self, c: LayoutTensor[dtype, layout, mut=mut, origin=origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]) -> LayoutTensor[out_type, col_major(1, num_matrix_reg[::Int,::Int]()), address_space=AddressSpace(5)]
store_d
store_d(self, d_dst: LayoutTensor[dtype, layout, mut=mut, origin=origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], d_src: LayoutTensor[dtype, layout, mut=mut, origin=origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment])
mma_op
mma_op(self, a: LayoutTensor[dtype, layout, mut=mut, origin=origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], b: LayoutTensor[dtype, layout, mut=mut, origin=origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], c: LayoutTensor[dtype, layout, mut=mut, origin=origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment]) -> LayoutTensor[out_type, col_major(1, num_matrix_reg[::Int,::Int]()), address_space=AddressSpace(5)]
mma
mma(self, a_frag: LayoutTensor[dtype, layout, mut=mut, origin=origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], b_frag: LayoutTensor[dtype, layout, mut=mut, origin=origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment], c_frag: LayoutTensor[dtype, layout, mut=mut, origin=origin, address_space=address_space, element_layout=element_layout, layout_bitwidth=layout_bitwidth, masked=masked, alignment=alignment])
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!