Mojo struct
TileLoaderLDS
struct TileLoaderLDS[dtype: DType, src_layout: Layout, src_tile_layout: Layout, num_loading_warps: Int, swizzle: Optional[Swizzle] = Optional(), load_width: Int = simd_width_of[dtype](), use_full_tile_width: Bool = False]
Cooperative global→LDS tile loader with swizzle support.
Loads tiles from global memory into LDS through an AMDBufferResource, which automatically clamps out-of-bounds reads to zero. This is critical for supporting partial blocks, meaning tiles that extend past the edge of the matrix.
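The zero-fill behavior can be modeled in a few lines of Python. This is a sketch, not the AMDBufferResource API. `buffer_load` is a hypothetical stand-in for one lane's vector load, and the source is assumed to be a row-major matrix whose buffer covers exactly its `M * N` elements. Bounds are checked per element here, which simplifies whatever granularity the hardware range check actually uses.

```python
def buffer_load(flat, num_records, offset, width):
    """Hypothetical model of one lane's vector buffer load.

    Elements at index >= num_records read as 0 instead of faulting
    (per-element bounds checking is a simplification)."""
    return [flat[i] if i < num_records else 0 for i in range(offset, offset + width)]


def load_tile(flat, num_records, stride, src_row, src_col, tile_rows, tile_cols, width):
    """Load a tile_rows x tile_cols tile whose top-left corner is (src_row, src_col)."""
    tile = []
    for r in range(tile_rows):
        row = []
        for c in range(0, tile_cols, width):
            offset = (src_row + r) * stride + src_col + c
            row.extend(buffer_load(flat, num_records, offset, width))
        tile.append(row)
    return tile


# A 6 x 8 row-major matrix holding 1..48.
M, N = 6, 8
matrix = [i + 1 for i in range(M * N)]

# The last 4-row tile starts at row 4, so tile rows 2 and 3 lie past row M - 1.
tile = load_tile(matrix, len(matrix), stride=N, src_row=4, src_col=0,
                 tile_rows=4, tile_cols=8, width=4)
for row in tile:
    print(row)
```

The rows beyond the end of the matrix come back as zeros, so those elements contribute nothing to any dot product computed from the tile.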
Loading Modes (controlled by use_full_tile_width):
- False (default): Interleaved layout. Each warp handles a 32-column subtile. Used for BF16, where MMA_K (32) < BK (64).
- True: Row-major layout. Each source row maps 1:1 to an LDS row. Used for FP8, where MMA_K == BK; this layout enables correct partial-block handling.
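The two modes can be compared with a small Python model. The lane-to-element mapping below is hypothetical: it was chosen only to be consistent with the comptime members (subtile_cols, threads_per_row, thread_rows, num_warp_cols, rows_per_iteration), and the real kernel's mapping may differ. It assumes WARP_SIZE = 64, 16-byte lane loads, 4 loading warps, and illustrative tile shapes, and it checks that every tile element is loaded exactly once in both modes.

```python
from collections import Counter

WARP_SIZE = 64  # assumed AMD CDNA wavefront size


def load_starts(tile_rows, tile_cols, load_width, num_loading_warps, use_full_tile_width):
    """Return the (row, col) where each (iteration, warp, lane) load begins.

    Hypothetical mapping derived from the comptime members; not the kernel's code."""
    subtile_cols = tile_cols if use_full_tile_width else 32
    threads_per_row = subtile_cols // load_width
    thread_rows = WARP_SIZE // threads_per_row
    num_warp_cols = tile_cols // subtile_cols
    loads_per_row = tile_cols // load_width
    rows_per_iteration = (num_loading_warps * WARP_SIZE) // loads_per_row
    num_iterations = tile_rows // rows_per_iteration

    starts = []
    for it in range(num_iterations):
        for warp in range(num_loading_warps):
            # Warps tile each iteration's rows as a num_warp_rows x num_warp_cols grid.
            warp_row, warp_col = divmod(warp, num_warp_cols)
            for lane in range(WARP_SIZE):
                # A warp covers thread_rows x subtile_cols, threads_per_row lanes per row.
                r, c = divmod(lane, threads_per_row)
                row = it * rows_per_iteration + warp_row * thread_rows + r
                col = warp_col * subtile_cols + c * load_width
                starts.append((row, col))
    return starts


def covers_tile_once(starts, tile_rows, tile_cols, load_width):
    """True if the loads touch every tile element exactly once and nothing outside."""
    hits = Counter((row, col + k) for row, col in starts for k in range(load_width))
    in_bounds = all(r < tile_rows and c < tile_cols for r, c in hits)
    return in_bounds and len(hits) == tile_rows * tile_cols and set(hits.values()) == {1}


# BF16: 128 x 64 tile, 8 elements (16 bytes) per lane load, interleaved.
bf16 = load_starts(128, 64, 8, 4, use_full_tile_width=False)
# FP8: 128 x 128 tile, 16 elements (16 bytes) per lane load, row-major.
fp8 = load_starts(128, 128, 16, 4, use_full_tile_width=True)
print(covers_tile_once(bf16, 128, 64, 8), covers_tile_once(fp8, 128, 128, 16))  # True True
```

In this model each BF16 warp covers a 16 x 32 block (half the tile width), while each FP8 warp covers 8 complete rows, which matches the 1:1 row mapping described for the row-major mode.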
Fields
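- buffer (AMDBufferResource): Buffer resource used for the global-memory loads; out-of-bounds reads through it return zero.
- thread_row (Int): This thread's row position in the tile, precomputed in __init__.
- thread_col (Int): This thread's column position in the tile, precomputed in __init__.
- warp_id (Int): Warp index passed to __init__.
- lane_id (Int): Lane index within the warp, passed to __init__.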
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
elements_per_warp
comptime elements_per_warp = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].threads_per_warp * load_width)
lane_load_bytes
comptime lane_load_bytes = (load_width * size_of[dtype]())
loading_threads
comptime loading_threads = (num_loading_warps * TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].threads_per_warp)
loads_per_row
comptime loads_per_row = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_cols // load_width)
num_iterations
comptime num_iterations = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_rows // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].rows_per_iteration)
num_warp_cols
comptime num_warp_cols = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_cols // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].subtile_cols)
num_warp_rows
comptime num_warp_rows = (num_loading_warps // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].num_warp_cols)
row_bytes
comptime row_bytes = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_cols * size_of[dtype]())
rows_per_iteration
comptime rows_per_iteration = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].loading_threads // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].loads_per_row)
rows_per_warp
comptime rows_per_warp = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].elements_per_warp // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_cols)
stride
comptime stride = src_layout.shape[1].value()
subtile_cols
comptime subtile_cols = TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_cols if use_full_tile_width else 32
thread_rows
comptime thread_rows = (WARP_SIZE // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].threads_per_row)
threads_per_row
comptime threads_per_row = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].subtile_cols // load_width)
threads_per_warp
comptime threads_per_warp = WARP_SIZE
tile_cols
comptime tile_cols = src_tile_layout.shape[1].value()
tile_rows
comptime tile_rows = src_tile_layout.shape[0].value()
warp_subtile_bytes
comptime warp_subtile_bytes = ((TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].rows_per_warp * TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_cols) * size_of[dtype]())
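To make the comptime members concrete, the Python sketch below evaluates the same formulas for two configurations. The values are assumptions chosen for illustration, not taken from a specific kernel: WARP_SIZE = 64, 16-byte lane loads (load_width of 8 for BF16 and 16 for FP8), 4 loading warps, a 128 x 64 BF16 tile, and a 128 x 128 FP8 tile.

```python
WARP_SIZE = 64  # assumed AMD CDNA wavefront size


def comptime_members(dtype_bytes, tile_rows, tile_cols, num_loading_warps,
                     load_width, use_full_tile_width):
    """Evaluate the TileLoaderLDS comptime member formulas for one configuration."""
    m = {}
    m["threads_per_warp"] = WARP_SIZE
    m["elements_per_warp"] = m["threads_per_warp"] * load_width
    m["lane_load_bytes"] = load_width * dtype_bytes
    m["loading_threads"] = num_loading_warps * m["threads_per_warp"]
    m["loads_per_row"] = tile_cols // load_width
    m["rows_per_iteration"] = m["loading_threads"] // m["loads_per_row"]
    m["num_iterations"] = tile_rows // m["rows_per_iteration"]
    m["subtile_cols"] = tile_cols if use_full_tile_width else 32
    m["threads_per_row"] = m["subtile_cols"] // load_width
    m["thread_rows"] = WARP_SIZE // m["threads_per_row"]
    m["num_warp_cols"] = tile_cols // m["subtile_cols"]
    m["num_warp_rows"] = num_loading_warps // m["num_warp_cols"]
    m["row_bytes"] = tile_cols * dtype_bytes
    m["rows_per_warp"] = m["elements_per_warp"] // tile_cols
    m["warp_subtile_bytes"] = m["rows_per_warp"] * tile_cols * dtype_bytes
    return m


bf16 = comptime_members(dtype_bytes=2, tile_rows=128, tile_cols=64,
                        num_loading_warps=4, load_width=8, use_full_tile_width=False)
fp8 = comptime_members(dtype_bytes=1, tile_rows=128, tile_cols=128,
                       num_loading_warps=4, load_width=16, use_full_tile_width=True)

for name, cfg in (("BF16", bf16), ("FP8", fp8)):
    print(name)
    for key, value in cfg.items():
        print(f"  {key:>20} = {value}")
```

For the BF16 configuration this gives rows_per_iteration = 32, num_iterations = 4, thread_rows = 16 and num_warp_cols = 2; for FP8, thread_rows = 8 and num_warp_cols = 1. In both cases warp_subtile_bytes = 64 × 16 = 1024, the warp-subtile size that the translation-invariance discussion under load_tile refers to.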
Methods
__init__
__init__(src: LayoutTensor[src.dtype, src.layout, src.origin, address_space=src.address_space, element_layout=src.element_layout, layout_int_type=src.layout_int_type, linear_idx_type=src.linear_idx_type, masked=src.masked, alignment=src.alignment], warp_id: Int, lane_id: Int) -> Self
Pre-compute thread position with swizzle inversion.
For BF16 (interleaved layout), the per-warp swizzle inversion computed here is exact because the swizzle pattern is translation-invariant over warp subtile boundaries. For FP8 (row-major layout), per-iteration computation is used instead (see load_tile).
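Swizzle inversion is cheap for XOR swizzles. The sketch below models Swizzle(bits, base, shift) with the CuTe-style convention, which is assumed to match the Mojo layout library: bits [base + shift, base + shift + bits) of an offset are XORed into bits [base, base + bits). When shift >= bits, the two ranges do not overlap, so applying the swizzle twice returns the original offset and the swizzle is its own inverse. The check uses BF16's Swizzle(1,5,4) on byte offsets and then shifts by whole 1024-byte subtiles.

```python
class Swizzle:
    """Model of an XOR swizzle: bits [base+shift, base+shift+bits) of an
    offset are XORed into bits [base, base+bits)."""

    def __init__(self, bits, base, shift):
        self.shift = shift
        self.src_mask = ((1 << bits) - 1) << (base + shift)

    def __call__(self, offset):
        return offset ^ ((offset & self.src_mask) >> self.shift)


bf16_swizzle = Swizzle(1, 5, 4)  # source bit 9 is XORed into bit 5

# Self-inverse: source and target bits do not overlap, so swizzling twice is a no-op.
assert all(bf16_swizzle(bf16_swizzle(x)) == x for x in range(4096))

# Translation invariance: the pattern for subtile 0 is reused for subtile k
# simply by adding k * 1024.
SUBTILE_BYTES = 1024
invariant = all(
    bf16_swizzle(x + k * SUBTILE_BYTES) == bf16_swizzle(x) + k * SUBTILE_BYTES
    for x in range(SUBTILE_BYTES)
    for k in range(8)
)
print(invariant)  # True
```

Because the swizzle is self-inverse, inverting it costs the same as applying it, and invariance means a position computed once for the first subtile stays exact in every later subtile.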
load_tile
load_tile[dst_layout: Layout, //](self, dst: LayoutTensor[dtype, dst_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=dst.element_layout, layout_int_type=dst.layout_int_type, linear_idx_type=dst.linear_idx_type, masked=dst.masked, alignment=dst.alignment], src_row: Int, src_col: Int)
Load a tile from source coordinates to LDS.
There are two paths, depending on whether the swizzle is translation-invariant:
- Pre-computed path (BF16 / no swizzle): Uses thread_row/thread_col, computed once in __init__. The BF16 Swizzle(1,5,4) is invariant over 1024-byte warp subtiles, so the per-warp approximation is exact.
- Per-iteration path (FP8 with swizzle): Computes the full byte offset within the half-tile on each iteration. This is required because FP8 Swizzle(3,4,4) and Swizzle(2,5,4) have their top source bit at the subtile boundary (bit 10 = log2(1024)), which breaks invariance.
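The invariance condition can be checked directly. Reusing the same assumed XOR-swizzle model (bits [base + shift, base + shift + bits) XORed into [base, base + bits)), this sketch computes each swizzle's top source bit. It then tests whether shifting a byte offset by whole 1024-byte subtiles shifts the swizzled offset by the same amount, which is the property the pre-computed path relies on.

```python
class Swizzle:
    """Model of an XOR swizzle: bits [base+shift, base+shift+bits) of an
    offset are XORed into bits [base, base+bits)."""

    def __init__(self, bits, base, shift):
        self.shift = shift
        self.src_mask = ((1 << bits) - 1) << (base + shift)

    def __call__(self, offset):
        return offset ^ ((offset & self.src_mask) >> self.shift)


def top_source_bit(bits, base, shift):
    """Highest offset bit that the swizzle reads."""
    return base + shift + bits - 1


def is_invariant(swz, period, num_periods=8):
    """True if swz(x + k*period) == swz(x) + k*period for every x in one period."""
    return all(
        swz(x + k * period) == swz(x) + k * period
        for x in range(period)
        for k in range(num_periods)
    )


results = {}
for name, params in (("BF16 Swizzle(1,5,4)", (1, 5, 4)),
                     ("FP8 Swizzle(3,4,4)", (3, 4, 4)),
                     ("FP8 Swizzle(2,5,4)", (2, 5, 4))):
    top = top_source_bit(*params)
    ok = is_invariant(Swizzle(*params), 1024)
    results[name] = (top, ok)
    print(f"{name}: top source bit {top}, invariant over 1024 B: {ok}")
# BF16 Swizzle(1,5,4): top source bit 9, invariant over 1024 B: True
# FP8 Swizzle(3,4,4): top source bit 10, invariant over 1024 B: False
# FP8 Swizzle(2,5,4): top source bit 10, invariant over 1024 B: False
```

Any swizzle whose source bits all lie below bit 10 passes. Both FP8 swizzles read bit 10, so moving to the next 1024-byte subtile changes the XOR pattern itself; an offset precomputed for one subtile cannot simply be shifted and must be swizzled again on each iteration.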