
Mojo struct

TileLoaderLDS

struct TileLoaderLDS[dtype: DType, src_layout: Layout, src_tile_layout: Layout, num_loading_warps: Int, swizzle: Optional[Swizzle] = Optional(), load_width: Int = simd_width_of[dtype](), use_full_tile_width: Bool = False]

Cooperative global→LDS tile loader with swizzle support.

Loads tiles from global memory to LDS through an AMDBufferResource, which automatically clamps out-of-bounds accesses to zero. This is critical for partial block support.
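The zero-clamping behavior can be illustrated with a small plain-Python model. This is only a sketch of the semantics described above, not the AMDBufferResource API; `clamped_load` and its parameters are hypothetical names introduced for illustration.

```python
def clamped_load(buffer, offset, width):
    """Model a load of `width` elements starting at `offset`.

    Elements outside `buffer` read as zero instead of faulting,
    mimicking the out-of-bounds clamping described above.
    """
    return [
        buffer[i] if 0 <= i < len(buffer) else 0
        for i in range(offset, offset + width)
    ]


# A source row with 6 valid elements, loaded in 4-element chunks.
row = [1, 2, 3, 4, 5, 6]
full = clamped_load(row, 0, 4)     # entirely in bounds
partial = clamped_load(row, 4, 4)  # last two elements are out of bounds
```

In this model the tail of a partial block is padded with zeros, which add nothing to a dot-product accumulation, so the loader needs no special-case branch for edge tiles.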

Loading Modes (controlled by use_full_tile_width):

  • False (default): Interleaved layout. Each warp handles a 32-column subtile. Used for BF16, where MMA_K (32) < BK (64).
  • True: Row-major layout. Each source row maps 1:1 to an LDS row. Used for FP8, where MMA_K == BK, which enables correct partial block handling.

Fields

  • buffer (AMDBufferResource):
  • thread_row (Int):
  • thread_col (Int):
  • warp_id (Int):
  • lane_id (Int):

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

elements_per_warp

comptime elements_per_warp = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].threads_per_warp * load_width)

lane_load_bytes

comptime lane_load_bytes = (load_width * size_of[dtype]())

loading_threads

comptime loading_threads = (num_loading_warps * TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].threads_per_warp)

loads_per_row

comptime loads_per_row = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_cols // load_width)

num_iterations

comptime num_iterations = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_rows // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].rows_per_iteration)

num_warp_cols

comptime num_warp_cols = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_cols // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].subtile_cols)

num_warp_rows

comptime num_warp_rows = (num_loading_warps // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].num_warp_cols)

row_bytes

comptime row_bytes = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_cols * size_of[dtype]())

rows_per_iteration

comptime rows_per_iteration = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].loading_threads // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].loads_per_row)

rows_per_warp

comptime rows_per_warp = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].elements_per_warp // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_cols)

stride

comptime stride = src_layout.shape[1].value()

subtile_cols

comptime subtile_cols = TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_cols if use_full_tile_width else 32

thread_rows

comptime thread_rows = (WARP_SIZE // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].threads_per_row)

threads_per_row

comptime threads_per_row = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].subtile_cols // load_width)

threads_per_warp

comptime threads_per_warp = WARP_SIZE

tile_cols

comptime tile_cols = src_tile_layout.shape[1].value()

tile_rows

comptime tile_rows = src_tile_layout.shape[0].value()

warp_subtile_bytes

comptime warp_subtile_bytes = ((TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].rows_per_warp * TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_cols) * size_of[dtype]())
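To make the relationships between these comptime members concrete, the sketch below re-evaluates the same formulas with plain Python integers. The example configurations are assumptions, not values taken from this page: a warp size of 64, a 128x64 BF16 tile with load_width 8, and a 128x128 FP8 tile with load_width 16. `derive_constants` is a hypothetical helper, and `stride` is omitted because it depends only on src_layout.

```python
def derive_constants(tile_rows, tile_cols, num_loading_warps, load_width,
                     dtype_bytes, use_full_tile_width=False, warp_size=64):
    """Mirror the comptime member formulas using plain integers."""
    threads_per_warp = warp_size  # assumed WARP_SIZE
    subtile_cols = tile_cols if use_full_tile_width else 32
    threads_per_row = subtile_cols // load_width
    elements_per_warp = threads_per_warp * load_width
    rows_per_warp = elements_per_warp // tile_cols
    loading_threads = num_loading_warps * threads_per_warp
    loads_per_row = tile_cols // load_width
    rows_per_iteration = loading_threads // loads_per_row
    num_warp_cols = tile_cols // subtile_cols
    return {
        "subtile_cols": subtile_cols,
        "threads_per_row": threads_per_row,
        "thread_rows": warp_size // threads_per_row,
        "elements_per_warp": elements_per_warp,
        "rows_per_warp": rows_per_warp,
        "lane_load_bytes": load_width * dtype_bytes,
        "loading_threads": loading_threads,
        "loads_per_row": loads_per_row,
        "rows_per_iteration": rows_per_iteration,
        "num_iterations": tile_rows // rows_per_iteration,
        "num_warp_cols": num_warp_cols,
        "num_warp_rows": num_loading_warps // num_warp_cols,
        "row_bytes": tile_cols * dtype_bytes,
        "warp_subtile_bytes": rows_per_warp * tile_cols * dtype_bytes,
    }


# Assumed BF16 configuration: interleaved mode, 2 bytes per element.
bf16 = derive_constants(128, 64, num_loading_warps=8, load_width=8,
                        dtype_bytes=2)

# Assumed FP8 configuration: row-major mode, 1 byte per element.
fp8 = derive_constants(128, 128, num_loading_warps=8, load_width=16,
                       dtype_bytes=1, use_full_tile_width=True)
```

Under these assumed shapes both configurations give a warp_subtile_bytes of 1024, which matches the 1024-byte warp subtile mentioned in the load_tile description.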

Methods

__init__

__init__(src: LayoutTensor[src.dtype, src.layout, src.origin, address_space=src.address_space, element_layout=src.element_layout, layout_int_type=src.layout_int_type, linear_idx_type=src.linear_idx_type, masked=src.masked, alignment=src.alignment], warp_id: Int, lane_id: Int) -> Self

Pre-compute thread position with swizzle inversion.

For BF16 (interleaved layout), the per-warp swizzle inversion computed here is exact because the swizzle pattern is translation-invariant over warp subtile boundaries. For FP8 (row-major layout), per-iteration computation is used instead (see load_tile).

load_tile

load_tile[dst_layout: Layout, //](self, dst: LayoutTensor[dtype, dst_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=dst.element_layout, layout_int_type=dst.layout_int_type, linear_idx_type=dst.linear_idx_type, masked=dst.masked, alignment=dst.alignment], src_row: Int, src_col: Int)

Load a tile from source coordinates to LDS.

Two paths, depending on whether the swizzle is translation-invariant across warp subtiles:

  1. Pre-computed path (BF16 / no swizzle): Uses the thread_row/thread_col values computed once in __init__. The BF16 Swizzle(1,5,4) is invariant over 1024-byte warp subtiles, so the per-warp approximation is exact.

  2. Per-iteration path (FP8 with swizzle): Computes the full byte offset within the half-tile each iteration. Required because FP8 Swizzle(3,4,4) and Swizzle(2,5,4) have their top source bit at the subtile boundary (bit 10 = log2(1024)), breaking invariance.
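The invariance argument above can be checked numerically with a short Python model. It assumes the CuTe-style convention for Swizzle(bits, base, shift), in which source bits [base+shift, base+shift+bits) are XORed into bits [base, base+bits), and it treats offsets as byte offsets. The function names are hypothetical and this is not the loader's own code.

```python
def swizzle(offset, bits, base, shift):
    """XOR-swizzle an offset (assumed CuTe-style convention, shift >= 0)."""
    src_mask = ((1 << bits) - 1) << (base + shift)
    return offset ^ ((offset & src_mask) >> shift)


def top_source_bit(bits, base, shift):
    """Index of the highest bit the swizzle reads from."""
    return base + shift + bits - 1


def is_translation_invariant(bits, base, shift, subtile_bytes=1024,
                             subtiles=4):
    """True if swizzling commutes with stepping by whole subtiles."""
    for k in range(subtiles):
        delta = k * subtile_bytes
        for off in range(subtile_bytes):
            moved = swizzle(off + delta, bits, base, shift)
            if moved != swizzle(off, bits, base, shift) + delta:
                return False
    return True


bf16_ok = is_translation_invariant(1, 5, 4)   # top source bit is 9
fp8_a_ok = is_translation_invariant(3, 4, 4)  # top source bit is 10
fp8_b_ok = is_translation_invariant(2, 5, 4)  # top source bit is 10
```

In this model Swizzle(1,5,4) reads only bit 9, which a step of 1024 bytes never changes, so the swizzled position can be computed once per warp. The two FP8 patterns read bit 10, which flips between adjacent subtiles, so the offset has to be recomputed each iteration.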
