
Mojo struct

SMemTileArray2D

@register_passable(trivial) struct SMemTileArray2D[dtype: DType, dim0: Int, dim1: Int, num_tiles: Int, swizzle_bytes: Int = 128, alignment: Int = 128]

Array of TileTensor tiles in shared memory with swizzled K-major layout.

The tiles use internal_k_major layout with configurable swizzle, matching the SM100 TMA swizzle pattern. This preserves swizzle information in the TileTensor type while using simple dimension-based parameters.

Note: For tiles without swizzle, use SMemTileArrayWithLayout with row_major.

Example:

    comptime MyArray = SMemTileArray2D[DType.float16, 64, 32, 4, 128, 128]

    var array = MyArray.stack_allocation()
    var tile = array[0]  # Returns TileTensor with swizzled layout

Parameters

  • dtype (DType): Tile element data type.
  • dim0 (Int): First dimension (rows, e.g., BM or BN).
  • dim1 (Int): Second dimension (columns, e.g., BK).
  • num_tiles (Int): Number of tiles in the array.
  • swizzle_bytes (Int): Swizzle size in bytes (128, 64, or 32). Must be > 0.
  • alignment (Int): Memory alignment (default 128 for shared memory).

Fields

  • ptr (LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED]): Pointer to the shared memory storage that backs the tiles.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

num_elements

comptime num_elements = (SMemTileArray2D[dtype, dim0, dim1, num_tiles, swizzle_bytes, alignment].tile_size * num_tiles)

Storage

comptime Storage = InlineArray[Scalar[dtype], SMemTileArray2D[dtype, dim0, dim1, num_tiles, swizzle_bytes, alignment].num_elements]

storage_size

comptime storage_size = (SMemTileArray2D[dtype, dim0, dim1, num_tiles, swizzle_bytes, alignment].num_elements * size_of[dtype]())

Tile

comptime Tile = TileTensor[dtype, Layout[Coord[ComptimeInt[(dim0 // 8)], ComptimeInt[8]], Coord[ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((dim1 * size_of[dtype]()) // swizzle_bytes)]], Coord[ComptimeInt[(swizzle_bytes // size_of[dtype]())], ComptimeInt[((dim0 // 8) * (swizzle_bytes // size_of[dtype]()))]], Coord[ComptimeInt[1], ComptimeInt[0]]], MutAnyOrigin, address_space=AddressSpace.SHARED]

tile_layout

comptime tile_layout = Layout(Coord(VariadicPack(Coord(VariadicPack(Idx[(dim0 // 8)](), Idx[8]())), Coord(VariadicPack(Idx[(swizzle_bytes // size_of[dtype]())](), Idx[((dim1 * size_of[dtype]()) // swizzle_bytes)]())))), Coord(VariadicPack(Coord(VariadicPack(Idx[(swizzle_bytes // size_of[dtype]())](), Idx[((dim0 // 8) * (swizzle_bytes // size_of[dtype]()))]())), Coord(VariadicPack(Idx[1](), Idx[0]())))))
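
The nested shape and stride expressions above are easier to read once evaluated for concrete parameters. The following is a minimal Python sketch that mirrors the arithmetic as printed; it is not Mojo and does not call the library. The `DTYPE_SIZE` table, the function name `tile_layout`, and the chosen parameters (bfloat16, 128 x 64, 128-byte swizzle) are assumptions made for illustration.

```python
# Python model of the comptime `tile_layout` expression; not the Mojo API.
# Bytes per element, standing in for size_of[dtype]() (assumed lookup table).
DTYPE_SIZE = {"float16": 2, "bfloat16": 2, "float32": 4}


def tile_layout(dtype, dim0, dim1, swizzle_bytes=128):
    """Return (shape, stride) with the same integer arithmetic as the docs."""
    size = DTYPE_SIZE[dtype]
    swizzle_elems = swizzle_bytes // size  # elements covered by one swizzle span
    shape = (
        (dim0 // 8, 8),
        (swizzle_elems, (dim1 * size) // swizzle_bytes),
    )
    stride = (
        (swizzle_elems, (dim0 // 8) * swizzle_elems),
        (1, 0),
    )
    return shape, stride


shape, stride = tile_layout("bfloat16", 128, 64, 128)
print(shape)   # ((16, 8), (64, 1))
print(stride)  # ((64, 1024), (1, 0))
```

Because the expressions use integer division, the last shape extent evaluates to 0 whenever a row (`dim1 * size_of[dtype]()` bytes) is narrower than `swizzle_bytes`, for example float16 with `dim1 = 32` and a 128-byte swizzle. This page does not say how the library treats that case.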

tile_size

comptime tile_size = (dim0 * dim1)
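
As a quick check of how `tile_size`, `num_elements`, and `storage_size` relate, the sketch below evaluates them in Python for the parameters used in the example near the top of this page (float16, 64 x 32, 4 tiles). The `DTYPE_SIZE` table and the helper name `sizes` are assumptions for illustration; only the three formulas come from the definitions above.

```python
# Python model of the size-related comptime members; not the Mojo API.
# Bytes per element, standing in for size_of[dtype]() (assumed lookup table).
DTYPE_SIZE = {"float16": 2, "bfloat16": 2, "float32": 4}


def sizes(dtype, dim0, dim1, num_tiles):
    """Return (tile_size, num_elements, storage_size) as defined in the docs."""
    tile_size = dim0 * dim1                          # elements per tile
    num_elements = tile_size * num_tiles             # elements in the whole array
    storage_size = num_elements * DTYPE_SIZE[dtype]  # bytes in the whole array
    return tile_size, num_elements, storage_size


print(sizes("float16", 64, 32, 4))  # (2048, 8192, 16384)
```

For these parameters the array occupies 16 KiB of shared memory, which is the figure to budget against the per-block shared memory limit of the target GPU.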

Methods

__init__

__init__(ref[AddressSpace._value._mlir_value] storage: InlineArray[Scalar[dtype], SMemTileArray2D[dtype, dim0, dim1, num_tiles, swizzle_bytes, alignment].num_elements]) -> Self

Initialize from inline storage.

Args:

  • storage (InlineArray): The inline storage the array points to.

Returns:

Self: A new SMemTileArray2D pointing to the storage.

__init__[mut: Bool, //, origin: Origin[mut=mut]](unsafe_ptr: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED, _mlir_origin=origin._mlir_origin, origin=origin]) -> Self

Initialize with a shared memory pointer.

Args:

  • unsafe_ptr (LegacyUnsafePointer): Shared memory pointer to the array's storage.

__getitem__

__getitem__[T: Intable](self, index: T) -> SMemTileArray2D[dtype, dim0, dim1, num_tiles, swizzle_bytes, alignment].Tile

Get tile at the given index.

Args:

  • index (T): The tile index.

Returns:

Tile: A TileTensor-based tile at the given index with swizzled layout.
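
This page does not spell out how an index maps to an address. A plausible reading, given that `num_elements` is `tile_size * num_tiles`, is that tiles are stored back to back, so tile `i` starts `i * tile_size` elements past `ptr`. The Python sketch below models only that assumption; the function name and the bounds check are illustrative additions, not documented library behavior.

```python
# Python model of tile indexing, ASSUMING tiles are stored contiguously.
# This is not the Mojo implementation.
def tile_element_offset(index, dim0, dim1, num_tiles):
    """Element offset of tile `index` from the start of the array."""
    index = int(index)  # mirrors accepting any Intable index
    if not 0 <= index < num_tiles:
        # Bounds check added for the sketch; the docs do not mention one.
        raise IndexError(f"tile index {index} out of range [0, {num_tiles})")
    return index * (dim0 * dim1)


print(tile_element_offset(0, 64, 32, 4))  # 0
print(tile_element_offset(2, 64, 32, 4))  # 4096
```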

get_with_layout

get_with_layout[tile_layout: Layout[tile_layout.shape_types, tile_layout.stride_types], T: Intable](self, index: T) -> TileTensor[dtype, Layout[tile_layout.shape_types, tile_layout.stride_types], MutAnyOrigin, address_space=AddressSpace.SHARED]

Get tile at the given index with a specified layout.

Use this method to get a tile with a caller-specified layout, such as a swizzled layout for MMA operations, where the layout information is needed to compute correct K-iteration offsets.

Parameters:

  • tile_layout (Layout): The layout to use (e.g., swizzled layout for MMA).
  • T (Intable): Index type (must be Intable).

Args:

  • index (T): The tile index.

Returns:

TileTensor: A TileTensor with the specified layout at the given index.

slice

slice[length: Int](self, start: Int) -> SMemTileArray2D[dtype, dim0, dim1, length, alignment]

Get a slice of the array.

Parameters:

  • length (Int): The length of the slice.

Args:

  • start (Int): The starting index.

Returns:

SMemTileArray2D: A new SMemTileArray2D representing the slice.
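
To illustrate what a slice represents, the Python sketch below models an array view as a base element offset plus a tile count. It assumes contiguous tiles and that `slice[length](start)` yields a view of `length` tiles beginning at tile `start` over the same memory; the class `TileArrayView`, its fields, and the range check are hypothetical and exist only for this illustration.

```python
# Python model of slicing a tile array, ASSUMING contiguous tiles and
# view semantics (no copy). This is not the Mojo implementation.
from dataclasses import dataclass


@dataclass(frozen=True)
class TileArrayView:
    base: int       # element offset of this view's tile 0 from the original pointer
    dim0: int
    dim1: int
    num_tiles: int

    def slice(self, length, start):
        """View of `length` tiles starting at tile `start`."""
        if start < 0 or length < 0 or start + length > self.num_tiles:
            # Range check added for the sketch; the docs do not mention one.
            raise IndexError("slice exceeds the array")
        tile_size = self.dim0 * self.dim1
        return TileArrayView(self.base + start * tile_size,
                             self.dim0, self.dim1, length)


full = TileArrayView(base=0, dim0=64, dim1=32, num_tiles=4)
print(full.slice(length=2, start=1))
# TileArrayView(base=2048, dim0=64, dim1=32, num_tiles=2)
```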

stack_allocation

static stack_allocation() -> Self

Allocate the array on the stack (in shared memory).

Returns:

Self: A new SMemTileArray2D backed by stack-allocated shared memory.
