Mojo module
tile_types
Native TileTensor types for SM100 structured kernels.
This module provides TileTensor-based tile types for SM100 structured kernels. All SMEM storage uses TileTensor natively. TileTensors are passed directly to TMA and MMA via TileTensor overloads.
Usage:
from linalg.matmul.gpu.sm100_structured.structured_kernels.tile_types import (
    SMemTile,
    SMemTileArray2D,
    SMemTileArrayWithLayout,
)
# Create tile with a layout
comptime my_layout = row_major[64, 32]()
comptime MyTile = SMemTile[DType.float16, my_layout]
# TileTensors are passed directly to TMA/MMA
tma_op.async_copy(tile, barrier, coords)
comptime values
GMEMLayout1D
comptime GMEMLayout1D = Layout[RuntimeInt[DType.int64], ComptimeInt[1]]
1D layout for flat global memory arrays.
Shape is dynamic (RuntimeInt), stride is 1 (ComptimeInt[1]). Rank is provably 1 at compile time.
GMEMTile
comptime GMEMTile[dtype: DType, lt_layout: Layout] = TileTensor[dtype, Layout[<shape and stride types derived from lt_layout>], MutAnyOrigin]
The generated signature expands into compiler-internal IR at this point. In summary, each of the two entries of lt_layout.shape and lt_layout.stride becomes a ComptimeInt when the value is statically known and a RuntimeInt[DType.int64] otherwise.
Global memory 2D TileTensor derived from a legacy Layout.
Used for kernel parameter types, replacing LayoutTensor parameters.
Parameters
internal_k_major
comptime internal_k_major[dtype: DType, BM: Int, BK: Int, swizzle_bytes: Int] = Layout(Coord(VariadicPack(Coord(VariadicPack(Idx[(BM // 8)](), Idx[8]())), Coord(VariadicPack(Idx[(swizzle_bytes // size_of[dtype]())](), Idx[((BK * size_of[dtype]()) // swizzle_bytes)]())))), Coord(VariadicPack(Coord(VariadicPack(Idx[(swizzle_bytes // size_of[dtype]())](), Idx[((BM // 8) * (swizzle_bytes // size_of[dtype]()))]())), Coord(VariadicPack(Idx[1](), Idx[0]())))))
Parameters
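The shape and stride expressions in internal_k_major are easier to read with concrete numbers. The following is a minimal plain-Python sketch that evaluates the same integer expressions; it is a model of the arithmetic, not the Mojo API. The function name k_major_layout and the elem_bytes argument (standing in for size_of[dtype]()) are assumptions made for illustration.

```python
def k_major_layout(elem_bytes, BM, BK, swizzle_bytes):
    """Evaluate the internal_k_major shape/stride expressions for concrete sizes.

    elem_bytes is a hypothetical stand-in for size_of[dtype]().
    """
    swizzle_elems = swizzle_bytes // elem_bytes
    shape = ((BM // 8, 8), (swizzle_elems, (BK * elem_bytes) // swizzle_bytes))
    stride = ((swizzle_elems, (BM // 8) * swizzle_elems), (1, 0))
    return shape, stride


# A 2-byte element type (such as float16), a 64 x 64 tile, 128-byte swizzle.
shape, stride = k_major_layout(elem_bytes=2, BM=64, BK=64, swizzle_bytes=128)
print(shape)   # ((8, 8), (64, 1))
print(stride)  # ((64, 512), (1, 0))
```

The internal_k_major_128B, internal_k_major_64B, and internal_k_major_32B aliases are the same expressions with swizzle_bytes fixed to 128, 64, and 32.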
internal_k_major_128B
comptime internal_k_major_128B[dtype: DType, BM: Int, BK: Int] = Layout(Coord(VariadicPack(Coord(VariadicPack(Idx[(BM // 8)](), Idx[8]())), Coord(VariadicPack(Idx[(128 // size_of[dtype]())](), Idx[((BK * size_of[dtype]()) // 128)]())))), Coord(VariadicPack(Coord(VariadicPack(Idx[(128 // size_of[dtype]())](), Idx[((BM // 8) * (128 // size_of[dtype]()))]())), Coord(VariadicPack(Idx[1](), Idx[0]())))))
Parameters
internal_k_major_32B
comptime internal_k_major_32B[dtype: DType, BM: Int, BK: Int] = Layout(Coord(VariadicPack(Coord(VariadicPack(Idx[(BM // 8)](), Idx[8]())), Coord(VariadicPack(Idx[(32 // size_of[dtype]())](), Idx[((BK * size_of[dtype]()) // 32)]())))), Coord(VariadicPack(Coord(VariadicPack(Idx[(32 // size_of[dtype]())](), Idx[((BM // 8) * (32 // size_of[dtype]()))]())), Coord(VariadicPack(Idx[1](), Idx[0]())))))
Parameters
internal_k_major_64B
comptime internal_k_major_64B[dtype: DType, BM: Int, BK: Int] = Layout(Coord(VariadicPack(Coord(VariadicPack(Idx[(BM // 8)](), Idx[8]())), Coord(VariadicPack(Idx[(64 // size_of[dtype]())](), Idx[((BK * size_of[dtype]()) // 64)]())))), Coord(VariadicPack(Coord(VariadicPack(Idx[(64 // size_of[dtype]())](), Idx[((BM // 8) * (64 // size_of[dtype]()))]())), Coord(VariadicPack(Idx[1](), Idx[0]())))))
Parameters
internal_k_major_none
comptime internal_k_major_none[dtype: DType, BM: Int, BK: Int] = row_major[BM, BK]()
Parameters
internal_sf_k_major
comptime internal_sf_k_major[dim0: Int, dim1: Int] = Layout(Coord(VariadicPack(Coord(VariadicPack(Idx[32](), Idx[(dim0 // 32)]())), Coord(VariadicPack(Coord(VariadicPack(Idx[4](), Idx[4]())), Idx[(dim1 // 16)]())))), Coord(VariadicPack(Coord(VariadicPack(Idx[16](), Idx[(dim1 * 32)]())), Coord(VariadicPack(Coord(VariadicPack(Idx[1](), Idx[4]())), Idx[512]())))))
Parameters
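One way to sanity-check the nested shape and stride in internal_sf_k_major is to enumerate every coordinate and collect the resulting linear offsets. The plain-Python sketch below does that for one example size; it models only the integer arithmetic, and the helper names (flatten, sf_k_major, all_offsets) are hypothetical.

```python
from itertools import product


def flatten(nested):
    """Flatten a nested tuple of ints into a flat list."""
    out = []
    for item in nested:
        if isinstance(item, tuple):
            out.extend(flatten(item))
        else:
            out.append(item)
    return out


def sf_k_major(dim0, dim1):
    """Evaluate the internal_sf_k_major shape/stride expressions."""
    shape = ((32, dim0 // 32), ((4, 4), dim1 // 16))
    stride = ((16, dim1 * 32), ((1, 4), 512))
    return shape, stride


def all_offsets(shape, stride):
    """Return the sorted linear offset of every coordinate in the layout."""
    dims, steps = flatten(shape), flatten(stride)
    return sorted(
        sum(i * s for i, s in zip(idx, steps))
        for idx in product(*(range(d) for d in dims))
    )


shape, stride = sf_k_major(128, 32)
offsets = all_offsets(shape, stride)
# For this size every offset in [0, dim0 * dim1) is hit exactly once.
print(offsets == list(range(128 * 32)))  # True
```

This check covers only the sizes it is run with; it says nothing about sizes where dim0 is not a multiple of 32 or dim1 is not a multiple of 16.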
SMemTile
comptime SMemTile[shape_types: Variadic[CoordLike], stride_types: Variadic[CoordLike], //, dtype: DType, layout: Layout[shape_types, stride_types], *, alignment: Int = 128] = TileTensor[dtype, Layout[shape_types, stride_types], MutAnyOrigin, address_space=AddressSpace.SHARED]
Shared memory tile using TileTensor with a Layout.
The Layout parameter preserves swizzle information, enabling .to_layout_tensor() to produce correctly swizzled LayoutTensors.
Parameters
SMemTile2D
comptime SMemTile2D[dtype: DType, dim0: Int, dim1: Int, *, alignment: Int = 128] = TileTensor[dtype, Layout[ComptimeInt[dim0], ComptimeInt[dim1], ComptimeInt[dim1], ComptimeInt[1]], MutAnyOrigin, address_space=AddressSpace.SHARED]
Backward-compatible alias for SMemTile with explicit 2D dimensions.
Parameters
SMemTileShape
comptime SMemTileShape[mut: Bool, dtype: DType, LayoutType: TensorLayout, origin: Origin[mut=mut], address_space: AddressSpace, linear_idx_type: DType, element_shape_types: Variadic[CoordLike], //, idx: Int, Tile: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types]] = LayoutType.static_shape[idx]
Get compile-time shape value at index from a TileTensor type.
Returns: The static shape value, or -1 if runtime-determined.
Parameters
- mut (Bool):
- dtype (DType):
- LayoutType (TensorLayout):
- origin (Origin):
- address_space (AddressSpace):
- linear_idx_type (DType):
- element_shape_types (Variadic):
- idx (Int): The dimension index.
- Tile (TileTensor): The TileTensor type (use type_of(tile)).
SMemTileStride
comptime SMemTileStride[mut: Bool, dtype: DType, LayoutType: TensorLayout, origin: Origin[mut=mut], address_space: AddressSpace, linear_idx_type: DType, element_shape_types: Variadic[CoordLike], //, idx: Int, Tile: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types]] = LayoutType.static_stride[idx]
Get compile-time stride value at index from a TileTensor type.
Returns: The static stride value, or -1 if runtime-determined.
Parameters
- mut (Bool):
- dtype (DType):
- LayoutType (TensorLayout):
- origin (Origin):
- address_space (AddressSpace):
- linear_idx_type (DType):
- element_shape_types (Variadic):
- idx (Int): The dimension index.
- Tile (TileTensor): The TileTensor type (use type_of(tile)).
static_row_major
comptime static_row_major[dim0: Int, dim1: Int] = Layout[ComptimeInt[dim0], ComptimeInt[dim1], ComptimeInt[dim1], ComptimeInt[1]]
2D row-major layout with fully static dimensions.
Equivalent to LegacyLayout.row_major(dim0, dim1), but built from the new Layout types so that rank 2 is provable at compile time.
Parameters
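The static_row_major alias has shape (dim0, dim1) and strides (dim1, 1), so element (i, j) sits at linear offset i * dim1 + j. A minimal plain-Python sketch of that indexing rule follows; the function name row_major_offset is hypothetical and the code models the arithmetic only.

```python
def row_major_offset(i, j, dim0, dim1):
    """Linear offset of element (i, j) for shape (dim0, dim1), strides (dim1, 1)."""
    if not (0 <= i < dim0 and 0 <= j < dim1):
        raise IndexError("coordinate out of bounds")
    return i * dim1 + j * 1


# Row 2, column 5 of a 64 x 32 tile.
print(row_major_offset(2, 5, dim0=64, dim1=32))  # 69
```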
swizzle_mode_to_bytes
comptime swizzle_mode_to_bytes[swizzle_mode: TensorMapSwizzle] = 128 if (swizzle_mode == TensorMapSwizzle.SWIZZLE_128B) else 64 if (swizzle_mode == TensorMapSwizzle.SWIZZLE_64B) else 32 if (swizzle_mode == TensorMapSwizzle.SWIZZLE_32B) else 0
Convert TensorMapSwizzle enum to swizzle size in bytes.
Returns: The swizzle size in bytes (128, 64, 32, or 0 for no swizzle).
Parameters
- swizzle_mode (TensorMapSwizzle): The TensorMapSwizzle enum value.
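The conditional chain in swizzle_mode_to_bytes can be mirrored in plain Python as follows. This is a sketch of the mapping only: SwizzleMode is a hypothetical stand-in for TensorMapSwizzle, and the SWIZZLE_NONE member is an assumption used to exercise the fall-through branch.

```python
from enum import Enum


class SwizzleMode(Enum):
    """Hypothetical stand-in for TensorMapSwizzle."""
    SWIZZLE_NONE = 0  # assumed member, used for the fall-through case
    SWIZZLE_32B = 1
    SWIZZLE_64B = 2
    SWIZZLE_128B = 3


def swizzle_mode_to_bytes(mode):
    """Same branch order as the alias: 128B, then 64B, then 32B, else 0."""
    if mode == SwizzleMode.SWIZZLE_128B:
        return 128
    if mode == SwizzleMode.SWIZZLE_64B:
        return 64
    if mode == SwizzleMode.SWIZZLE_32B:
        return 32
    return 0


print([swizzle_mode_to_bytes(m) for m in SwizzleMode])  # [0, 32, 64, 128]
```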
tma_desc_layout_2d
comptime tma_desc_layout_2d[dtype: DType, tile_dim0: Int, swizzle: TensorMapSwizzle] = Layout[ComptimeInt[tile_dim0], ComptimeInt[(swizzle.bytes() // size_of[dtype]())], ComptimeInt[1], ComptimeInt[1]]
2D TMA descriptor layout: [dim0, swizzle_elems], strides [1, 1].
Parameters
- dtype (DType):
- tile_dim0 (Int):
- swizzle (TensorMapSwizzle):
tma_desc_layout_3d
comptime tma_desc_layout_3d[dtype: DType, tile_dim0: Int, tile_dim1: Int, swizzle: TensorMapSwizzle] = Layout[ComptimeInt[tile_dim0], ComptimeInt[tile_dim1], ComptimeInt[(swizzle.bytes() // size_of[dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]
3D TMA descriptor layout: [dim0, dim1, swizzle_elems], strides [1,1,1].
Parameters
- dtype (DType):
- tile_dim0 (Int):
- tile_dim1 (Int):
- swizzle (TensorMapSwizzle):
tma_desc_layout_4d
comptime tma_desc_layout_4d[dtype: DType, tile_dim0: Int, tile_dim1: Int, tile_dim2: Int, swizzle: TensorMapSwizzle] = Layout[ComptimeInt[tile_dim0], ComptimeInt[tile_dim1], ComptimeInt[tile_dim2], ComptimeInt[(swizzle.bytes() // size_of[dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]
4D TMA descriptor layout: [d0,d1,d2,swizzle_elems], strides all 1.
Parameters
- dtype (DType):
- tile_dim0 (Int):
- tile_dim1 (Int):
- tile_dim2 (Int):
- swizzle (TensorMapSwizzle):
tma_desc_layout_5d
comptime tma_desc_layout_5d[dtype: DType, tile_dim0: Int, tile_dim1: Int, tile_dim2: Int, tile_dim3: Int, swizzle: TensorMapSwizzle] = Layout[ComptimeInt[tile_dim0], ComptimeInt[tile_dim1], ComptimeInt[tile_dim2], ComptimeInt[tile_dim3], ComptimeInt[(swizzle.bytes() // size_of[dtype]())], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1], ComptimeInt[1]]
5D TMA descriptor layout: [d0,d1,d2,d3,swizzle_elems], strides all 1.
Parameters
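The four tma_desc_layout_* aliases above share one pattern: the leading tile dimensions are kept, the last dimension is the swizzle width in elements (swizzle.bytes() // size_of[dtype]()), and every stride is 1. The plain-Python sketch below models that pattern; tma_desc_layout, elem_bytes, and swizzle_bytes are hypothetical names standing in for the Mojo parameters.

```python
def tma_desc_layout(elem_bytes, tile_dims, swizzle_bytes):
    """Return (shape, stride) following the tma_desc_layout_2d..5d pattern.

    tile_dims holds the 1 to 4 leading tile dimensions; the swizzle width
    in elements is appended as the last dimension.
    """
    if not 1 <= len(tile_dims) <= 4:
        raise ValueError("expected 1 to 4 leading tile dimensions")
    swizzle_elems = swizzle_bytes // elem_bytes
    shape = tuple(tile_dims) + (swizzle_elems,)
    stride = (1,) * len(shape)
    return shape, stride


# 2D case: 2-byte elements, tile_dim0 = 64, 128-byte swizzle.
print(tma_desc_layout(2, (64,), 128))       # ((64, 64), (1, 1))
# 4D case: 4-byte elements, 64-byte swizzle.
print(tma_desc_layout(4, (2, 8, 16), 64))   # ((2, 8, 16, 16), (1, 1, 1, 1))
```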
TmaOpType
comptime TmaOpType[dtype: DType, tile_layout: TensorLayout, desc_layout: TensorLayout] = TMATensorTile[dtype, _to_legacy_layout[tile_layout](), _to_legacy_layout[desc_layout]()]
TMATensorTile type derived from new Layout types.
Single source of truth: new Layout types determine the TMATensorTile type parameters via _to_legacy_layout.
Parameters
- dtype (DType):
- tile_layout (TensorLayout):
- desc_layout (TensorLayout):
TmaOpTypeIm2col
comptime TmaOpTypeIm2col[dtype: DType, tile_layout: TensorLayout, desc_layout: TensorLayout] = TMATensorTileIm2col[dtype, _to_legacy_layout[tile_layout](), _to_legacy_layout[desc_layout]()]
TMATensorTileIm2col type derived from new Layout types.
Same as TmaOpType but for im2col TMA (used by conv2d activation loads).
Parameters
- dtype (DType):
- tile_layout (TensorLayout):
- desc_layout (TensorLayout):
UnsafePointer
comptime UnsafePointer = LegacyUnsafePointer[?, address_space=?, origin=?]
Structs
- BlockwiseFP8TilePayload: TileTensor-based tile payload for blockwise FP8 matmul.
- SMemTileArray: Array of TileTensor tiles with variadic shape/stride type parameters.
- SMemTileArray2D: Array of TileTensor tiles in shared memory with swizzled K-major layout.
- SMemTileArray2DRowMajor: Array of TileTensor tiles in shared memory with row_major layout.
- SMemTileArrayWithLayout: Array of TileTensor tiles with explicit Layout (preserves swizzle info).
- TMATile: TMA tile descriptor parameterized on new Layout types.
Functions
- create_tma_tile: Create a TMATensorTile using new Layout types.
- lt_to_tt: Convert a 2D LayoutTensor to a TileTensor.
- lt_to_tt_1d: Convert a 1D LayoutTensor to a TileTensor with GMEMLayout1D.