
Mojo module

tile_types

Native TileTensor types for SM100 structured kernels.

This module provides TileTensor-based tile types for SM100 structured kernels. All shared memory (SMEM) storage uses TileTensor natively, and TileTensors are passed directly to TMA and MMA operations through TileTensor overloads.

Usage:

# Create tile with a layout
comptime my_layout = row_major[64, 32]()
comptime MyTile = SMemTile[DType.float16, my_layout]

# TileTensors are passed directly to TMA/MMA
tma_op.async_copy(tile, barrier, coords)

comptime values

GMEMLayout1D

comptime GMEMLayout1D = Layout[*?, *?]

1D layout for flat global memory arrays.

Shape is dynamic (Scalar), stride is 1 (ComptimeInt[1]). Rank is provably 1 at compile time.

GMEMTile

comptime GMEMTile[dtype: DType, lt_layout: Layout] = TileTensor[dtype, Layout[*?, *?], MutAnyOrigin]

Global memory TileTensor derived from a legacy Layout.

Used for kernel parameter types, replacing LayoutTensor parameters.

Parameters

  • dtype (DType):
  • lt_layout (Layout):

internal_k_major

comptime internal_k_major[dtype: DType, BM: Int, BK: Int, swizzle_bytes: Int] = Layout(Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())), Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())))

Parameters

  • dtype (DType):
  • BM (Int):
  • BK (Int):
  • swizzle_bytes (Int):

internal_k_major_128B

comptime internal_k_major_128B[dtype: DType, BM: Int, BK: Int] = Layout(Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())), Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())))

Parameters

  • dtype (DType):
  • BM (Int):
  • BK (Int):

internal_k_major_32B

comptime internal_k_major_32B[dtype: DType, BM: Int, BK: Int] = Layout(Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())), Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())))

Parameters

  • dtype (DType):
  • BM (Int):
  • BK (Int):

internal_k_major_64B

comptime internal_k_major_64B[dtype: DType, BM: Int, BK: Int] = Layout(Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())), Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt())))

Parameters

  • dtype (DType):
  • BM (Int):
  • BK (Int):

internal_k_major_none

comptime internal_k_major_none[dtype: DType, BM: Int, BK: Int] = row_major[BM, BK]()

Parameters

  • dtype (DType):
  • BM (Int):
  • BK (Int):

internal_sf_k_major

comptime internal_sf_k_major[dim0: Int, dim1: Int] = Layout(Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(Coord(ComptimeInt(), ComptimeInt()), ComptimeInt())), Coord(Coord(ComptimeInt(), ComptimeInt()), Coord(Coord(ComptimeInt(), ComptimeInt()), ComptimeInt())))

Parameters

  • dim0 (Int):
  • dim1 (Int):

sf_tile_dim0

comptime sf_tile_dim0[BM: Int] = ((BM // 128) * 32)

Parameters

  • BM (Int):
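
The floor division in this expression is easy to misread, so the following is a small Python model of it. It is written in Python rather than Mojo purely to show the arithmetic, and the sample BM values are illustrative assumptions, not values taken from the module.

```python
def sf_tile_dim0(bm: int) -> int:
    """Python model of the documented expression ((BM // 128) * 32)."""
    return (bm // 128) * 32


# Illustrative BM values (assumed, not taken from the module).
for bm in (128, 256, 192, 64):
    print(bm, sf_tile_dim0(bm))
# prints:
# 128 32
# 256 64
# 192 32
# 64 0
```

Because `//` rounds down, a BM that is not a multiple of 128 is truncated to the previous multiple, and any BM below 128 yields 0.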

sf_tile_dim1

comptime sf_tile_dim1[sf_bk: Int, vec_sf_size: Int] = ((sf_bk // (4 * vec_sf_size)) * 16)

Parameters

  • sf_bk (Int):
  • vec_sf_size (Int):
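
As a quick check of the expression above, here is a Python model of the same arithmetic. The argument values are hypothetical and chosen only to exercise the floor division; the page does not state which combinations the kernels actually use.

```python
def sf_tile_dim1(sf_bk: int, vec_sf_size: int) -> int:
    """Python model of ((sf_bk // (4 * vec_sf_size)) * 16)."""
    return (sf_bk // (4 * vec_sf_size)) * 16


# Hypothetical inputs, for illustration only.
print(sf_tile_dim1(128, 16))  # 128 // 64 = 2, so 2 * 16 = 32
print(sf_tile_dim1(128, 32))  # 128 // 128 = 1, so 1 * 16 = 16
print(sf_tile_dim1(32, 16))   # 32 // 64 = 0, so the result is 0
```

The result only grows in steps of 16, one step per full group of `4 * vec_sf_size` along `sf_bk`.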

SMemTile

comptime SMemTile[dtype: DType, layout: Layout[shape_types, stride_types], *, alignment: Int = 128] = TileTensor[dtype, Layout[shape_types, stride_types], MutAnyOrigin, address_space=AddressSpace.SHARED]

Shared memory tile using TileTensor with a Layout.

The Layout parameter preserves swizzle information, enabling .to_layout_tensor() to produce correctly swizzled LayoutTensors.

Parameters

  • dtype (DType): The data type of tile elements.
  • layout (Layout[shape_types, stride_types]): The full layout including swizzle information.
  • alignment (Int): Memory alignment (default 128 for shared memory).

SMemTile2D

comptime SMemTile2D[dtype: DType, dim0: Int, dim1: Int, *, alignment: Int = 128] = TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]

Backward-compatible alias for SMemTile with explicit 2D dimensions.

Parameters

  • dtype (DType):
  • dim0 (Int):
  • dim1 (Int):
  • alignment (Int):

SMemTileShape

comptime SMemTileShape[idx: Int, Tile: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_size=element_size]] = LayoutType.static_shape[idx]

Get compile-time shape value at index from a TileTensor type.

Returns: The static shape value, or -1 if runtime-determined.

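Parameters

  • idx (Int):
  • Tile (TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_size=element_size]):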

SMemTileStride

comptime SMemTileStride[idx: Int, Tile: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_size=element_size]] = LayoutType.static_stride[idx]

Get compile-time stride value at index from a TileTensor type.

Returns: The static stride value, or -1 if runtime-determined.

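Parameters

  • idx (Int):
  • Tile (TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_size=element_size]):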

static_row_major

comptime static_row_major[dim0: Int, dim1: Int] = Layout[*?, *?]

2D row-major layout with fully static dimensions.

Equivalent to LegacyLayout.row_major(dim0, dim1) but using new Layout types with rank=2 provable at compile time.

Parameters

  • dim0 (Int):
  • dim1 (Int):
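
For readers less familiar with the term, the sketch below models what a 2D row-major layout means in general: strides of (dim1, 1), so elements of a row are contiguous. It is a plain Python illustration of the standard definition, not the module's implementation, and the helper names are hypothetical.

```python
def row_major_strides(dim0: int, dim1: int) -> tuple[int, int]:
    """Strides of a dim0 x dim1 row-major layout, in elements."""
    return (dim1, 1)


def row_major_offset(i: int, j: int, dim0: int, dim1: int) -> int:
    """Linear element offset of coordinate (i, j)."""
    stride0, stride1 = row_major_strides(dim0, dim1)
    return i * stride0 + j * stride1


# A 64 x 32 tile, matching the shape used in the Usage snippet.
print(row_major_strides(64, 32))         # (32, 1)
print(row_major_offset(1, 0, 64, 32))    # 32: the next row starts 32 elements later
print(row_major_offset(63, 31, 64, 32))  # 2047: the last element of 64 * 32 = 2048
```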

swizzle_mode_to_bytes

comptime swizzle_mode_to_bytes[swizzle_mode: TensorMapSwizzle] = 128 if (swizzle_mode == TensorMapSwizzle.SWIZZLE_128B) else 64 if (swizzle_mode == TensorMapSwizzle.SWIZZLE_64B) else 32 if (swizzle_mode == TensorMapSwizzle.SWIZZLE_32B) else 16

Convert TensorMapSwizzle enum to swizzle size in bytes.

SWIZZLE_NONE returns 16, matching TensorMapSwizzle.bytes() (formula: (2**value) * 16, value=0 gives 16).

Returns: The swizzle size in bytes (16, 32, 64, or 128).

Parameters

  • swizzle_mode (TensorMapSwizzle):
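
The agreement between the conditional chain and the `(2**value) * 16` formula can be checked with a small Python model. The `SwizzleMode` enum below is a hypothetical stand-in for TensorMapSwizzle; its integer values 1, 2, and 3 are an assumption inferred from the formula (only value=0 for SWIZZLE_NONE is stated on this page).

```python
from enum import IntEnum


class SwizzleMode(IntEnum):
    """Hypothetical stand-in for TensorMapSwizzle (values 1..3 are assumed)."""
    SWIZZLE_NONE = 0
    SWIZZLE_32B = 1
    SWIZZLE_64B = 2
    SWIZZLE_128B = 3


def swizzle_mode_to_bytes(mode: SwizzleMode) -> int:
    """Mirrors the documented conditional chain."""
    if mode == SwizzleMode.SWIZZLE_128B:
        return 128
    if mode == SwizzleMode.SWIZZLE_64B:
        return 64
    if mode == SwizzleMode.SWIZZLE_32B:
        return 32
    return 16  # SWIZZLE_NONE


def bytes_from_formula(mode: SwizzleMode) -> int:
    """The formula quoted above: (2**value) * 16."""
    return (2 ** int(mode)) * 16


for mode in SwizzleMode:
    print(mode.name, swizzle_mode_to_bytes(mode), bytes_from_formula(mode))
```

Under the assumed enum values, both functions return the same number for every mode.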

tma_desc_layout_2d

comptime tma_desc_layout_2d[dtype: DType, tile_dim0: Int, swizzle: TensorMapSwizzle] = Layout[*?, *?]

2D TMA descriptor layout: shape [tile_dim0, swizzle_elems], strides [1, 1].

Parameters

  • dtype (DType):
  • tile_dim0 (Int):
  • swizzle (TensorMapSwizzle):
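
This page does not define swizzle_elems. The Python sketch below assumes it is the swizzle width in bytes divided by the element size in bytes, which would explain why dtype and swizzle are both parameters; treat that as an unconfirmed assumption. The function name and the `elem_bytes` argument are hypothetical.

```python
def tma_desc_layout_2d(elem_bytes: int, tile_dim0: int, swizzle_bytes: int):
    """Model of the documented 2D descriptor layout.

    ASSUMPTION: swizzle_elems = swizzle_bytes // elem_bytes. The page only
    states the shape [dim0, swizzle_elems] and strides [1, 1].
    """
    swizzle_elems = swizzle_bytes // elem_bytes
    shape = (tile_dim0, swizzle_elems)
    strides = (1, 1)  # as stated in the description above
    return shape, strides


# 2-byte elements (for example float16) with a 128-byte swizzle.
print(tma_desc_layout_2d(2, 64, 128))  # ((64, 64), (1, 1))
# 1-byte elements with a 32-byte swizzle.
print(tma_desc_layout_2d(1, 64, 32))   # ((64, 32), (1, 1))
```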

tma_desc_layout_3d

comptime tma_desc_layout_3d[dtype: DType, tile_dim0: Int, tile_dim1: Int, swizzle: TensorMapSwizzle] = Layout[*?, *?]

3D TMA descriptor layout: shape [tile_dim0, tile_dim1, swizzle_elems], strides [1, 1, 1].

Parameters

  • dtype (DType):
  • tile_dim0 (Int):
  • tile_dim1 (Int):
  • swizzle (TensorMapSwizzle):

tma_desc_layout_4d

comptime tma_desc_layout_4d[dtype: DType, tile_dim0: Int, tile_dim1: Int, tile_dim2: Int, swizzle: TensorMapSwizzle] = Layout[*?, *?]

4D TMA descriptor layout: shape [tile_dim0, tile_dim1, tile_dim2, swizzle_elems], strides [1, 1, 1, 1].

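Parameters

  • dtype (DType):
  • tile_dim0 (Int):
  • tile_dim1 (Int):
  • tile_dim2 (Int):
  • swizzle (TensorMapSwizzle):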

tma_desc_layout_5d

comptime tma_desc_layout_5d[dtype: DType, tile_dim0: Int, tile_dim1: Int, tile_dim2: Int, tile_dim3: Int, swizzle: TensorMapSwizzle] = Layout[*?, *?]

5D TMA descriptor layout: shape [tile_dim0, tile_dim1, tile_dim2, tile_dim3, swizzle_elems], strides [1, 1, 1, 1, 1].

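Parameters

  • dtype (DType):
  • tile_dim0 (Int):
  • tile_dim1 (Int):
  • tile_dim2 (Int):
  • tile_dim3 (Int):
  • swizzle (TensorMapSwizzle):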

TmaOpType

comptime TmaOpType[dtype: DType, tile_layout: TensorLayout, desc_layout: TensorLayout] = TMATensorTile[dtype, tile_layout.rank, _to_index_list[tile_layout](), _to_index_list[tile_layout.rank, desc_layout]()]

TMATensorTile type derived from new Layout types.

Single source of truth: new Layout types determine the TMATensorTile type parameters via _to_index_list.

Parameters

  • dtype (DType):
  • tile_layout (TensorLayout):
  • desc_layout (TensorLayout):

TmaOpTypeIm2col

comptime TmaOpTypeIm2col[dtype: DType, tile_layout: TensorLayout, desc_layout: TensorLayout] = TMATensorTileIm2col[dtype, tile_layout.rank, _to_index_list[tile_layout](), _to_index_list[tile_layout.rank, desc_layout]()]

TMATensorTileIm2col type derived from new Layout types.

Same as TmaOpType but for im2col TMA (used by conv2d activation loads).

Parameters

  • dtype (DType):
  • tile_layout (TensorLayout):
  • desc_layout (TensorLayout):

Structs

Traits

  • TilePayload: Trait for tile payload types. Must extend TrivialRegisterPassable.

Functions