
Mojo struct

TileLoaderLDS

struct TileLoaderLDS[dtype_: DType, tile_rows_: Int, tile_cols_: Int, stride: Int, num_loading_warps: Int, swizzle: Optional[Swizzle] = Optional(), load_width: Int = simd_width_of[dtype_](), use_full_tile_width: Bool = False]

DRAM→LDS DMA expert for warp-group cooperative coord-indexed loads.

Sibling of SubTileLoaderLDS, which handles a single sub-tile with TileTensor indexing. This loader coordinates a warp group (typically 8 warps) to cooperatively fill a half-tile via coord-indexed iteration: load_tile(dst, m_offset, k_offset) steps through num_iterations BK-wide rows, optionally applying a per-iteration byte-space swizzle to avoid LDS bank conflicts. This is the DRAM→LDS pattern used by matmul kernels (ping-pong, etc.).

Uses stdlib AMDBufferResource.load_to_lds directly, with no alias scope attached. Matmul's scheduling uses s_sched_group_barrier hints, which don't qualify as the runtime fence required by the SIInsertWaitcnts vmcnt-relaxation contract; attaching the scope would miscompile (see the async_copies docstring on load_to_lds). For attention patterns that DO satisfy the contract via explicit s_waitcnt vmcnt(0) + s_barrier fences, use SubTileLoaderLDS instead.

Parameters​

  • ​dtype_ (DType): Element data type. Re-bound to dtype at body scope to match the TileLoader trait alias.
  • ​tile_rows_ (Int): Height of each half-tile to load. Re-bound to tile_rows at body scope.
  • ​tile_cols_ (Int): Width (K dimension) of each half-tile. Re-bound to tile_cols at body scope.
  • ​stride (Int): Row stride of the source GMEM tensor.
  • ​num_loading_warps (Int): Warps cooperating on each load (typically 8).
  • swizzle (Optional[Swizzle]): Optional byte-space swizzle for avoiding LDS bank conflicts.
  • ​load_width (Int): Elements per load (SIMD width).
  • use_full_tile_width (Bool): FP8 row-major mode. When True, subtile_cols equals tile_cols; otherwise subtile_cols is 32.

Fields​

  • ​buffer (AMDBufferResource):
  • ​thread_row (Int):
  • ​thread_col (Int):
  • ​warp_id (Int):
  • ​lane_id (Int):
  • ​m_anchor (Int):
  • ​k_anchor (Int):

Implemented traits​

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TileLoader, TrivialRegisterPassable

comptime members​

dtype​

comptime dtype = dtype_

elements_per_warp​

comptime elements_per_warp = (WARP_SIZE * load_width)

lane_load_bytes​

comptime lane_load_bytes = (load_width * size_of[TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].dtype]())

loading_threads​

comptime loading_threads = (num_loading_warps * WARP_SIZE)

num_iterations​

comptime num_iterations = ceildiv(TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_rows, TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].rows_per_iteration)

num_warp_cols​

comptime num_warp_cols = (TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_cols // TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].subtile_cols)

num_warp_rows​

comptime num_warp_rows = (num_loading_warps // TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].num_warp_cols)

row_bytes​

comptime row_bytes = (TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_cols * size_of[TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].dtype]())

rows_per_iteration​

comptime rows_per_iteration = (TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].loading_threads // (TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_cols // load_width))

rows_per_warp​

comptime rows_per_warp = (TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].elements_per_warp // TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_cols)

subtile_cols​

comptime subtile_cols = TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_cols if use_full_tile_width else 32

thread_rows​

comptime thread_rows = (WARP_SIZE // TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].threads_per_row)

threads_per_row​

comptime threads_per_row = (TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].subtile_cols // load_width)

tile_cols​

comptime tile_cols = tile_cols_

tile_rows​

comptime tile_rows = tile_rows_

total_warp_rows​

comptime total_warp_rows = ceildiv(TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_rows, TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].rows_per_warp)

warp_subtile_bytes​

comptime warp_subtile_bytes = ((TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].rows_per_warp * TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].tile_cols) * size_of[TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].dtype]())
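
The comptime members above are all plain integer arithmetic over the struct parameters, so they can be checked by hand. The following is a minimal Python sketch of that arithmetic, not the Mojo implementation. `LoaderGeometry` is a hypothetical helper name, `WARP_SIZE = 64` is an assumption (the page does not state its value), and the example parameters (2-byte elements, `load_width = 8`) are illustrative rather than taken from any real kernel.

```python
from dataclasses import dataclass

WARP_SIZE = 64  # assumption: wavefront size; not stated on this page


def ceildiv(a: int, b: int) -> int:
    """Integer division rounding up, for positive operands."""
    return -(-a // b)


@dataclass(frozen=True)
class LoaderGeometry:
    """Hypothetical model of the TileLoaderLDS comptime members."""
    tile_rows: int
    tile_cols: int
    num_loading_warps: int
    load_width: int           # elements per load
    dtype_bytes: int          # stands in for size_of[dtype]()
    use_full_tile_width: bool = False

    def derived(self) -> dict:
        # Each entry mirrors one comptime expression listed above.
        loading_threads = self.num_loading_warps * WARP_SIZE
        rows_per_iteration = loading_threads // (self.tile_cols // self.load_width)
        subtile_cols = self.tile_cols if self.use_full_tile_width else 32
        threads_per_row = subtile_cols // self.load_width
        num_warp_cols = self.tile_cols // subtile_cols
        elements_per_warp = WARP_SIZE * self.load_width
        rows_per_warp = elements_per_warp // self.tile_cols
        return {
            "loading_threads": loading_threads,
            "rows_per_iteration": rows_per_iteration,
            "num_iterations": ceildiv(self.tile_rows, rows_per_iteration),
            "subtile_cols": subtile_cols,
            "threads_per_row": threads_per_row,
            "thread_rows": WARP_SIZE // threads_per_row,
            "num_warp_cols": num_warp_cols,
            "num_warp_rows": self.num_loading_warps // num_warp_cols,
            "elements_per_warp": elements_per_warp,
            "rows_per_warp": rows_per_warp,
            "total_warp_rows": ceildiv(self.tile_rows, rows_per_warp),
            "row_bytes": self.tile_cols * self.dtype_bytes,
            "lane_load_bytes": self.load_width * self.dtype_bytes,
            "warp_subtile_bytes": rows_per_warp * self.tile_cols * self.dtype_bytes,
        }


# Illustrative configuration: 128x64 half-tile, 8 loading warps, 2-byte elements.
geometry = LoaderGeometry(tile_rows=128, tile_cols=64, num_loading_warps=8,
                          load_width=8, dtype_bytes=2)
values = geometry.derived()
for name, value in values.items():
    print(f"{name} = {value}")
```

With these example parameters, 512 loading threads cover 64 rows per iteration, so the 128-row half-tile takes 2 iterations. Setting `use_full_tile_width=True` changes only `subtile_cols` and the values derived from it (`threads_per_row`, `thread_rows`, `num_warp_cols`, `num_warp_rows`).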

Methods​

__init__​

__init__(src: TileTensor[TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].dtype], warp_id: Int, lane_id: Int, *, m_anchor: Int = 0, k_anchor: Int = 0) -> Self

Builds the loader.

Args:

  • src (TileTensor[TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].dtype]): GMEM tile to source from. Pass the full A/B tensor and set m_anchor/k_anchor to the per-block origin, or pass a pre-sliced block tile with zero anchors (legacy behavior). The full-tensor form lets the SRD's num_records bound the actual allocation rather than the block view, which is required for split-K kernels and for parity with TileLoaderLDSIm2col.
  • ​warp_id (Int): Warp identifier within the loading warp group.
  • ​lane_id (Int): Lane identifier within the warp.
  • ​m_anchor (Int): M-coordinate (row dim) of the block origin in the loader's SRD coordinate system. Added to m_offset at load time. Defaults to 0.
  • ​k_anchor (Int): K-coordinate (column dim) of the block origin. Added to k_offset at load time. Defaults to 0.

load_tile​

load_tile(self, dst: TileTensor[TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].dtype, address_space=AddressSpace.SHARED], m_offset: Int, k_offset: Int)

Loads a half-tile from GMEM into SMEM dst via load_to_lds.

The effective GEMM-space coordinate is (m_anchor + m_offset, k_anchor + k_offset), so callers using the legacy pre-sliced-block form (anchors=0) keep their address math unchanged.
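
The equivalence between the two calling forms can be checked with simple address arithmetic. The sketch below is a Python model, not the loader's code: it assumes a row-major source whose row stride in elements is `stride`, and the helper names `effective_coord` and `linear_offset` are hypothetical.

```python
def effective_coord(m_anchor: int, k_anchor: int,
                    m_offset: int, k_offset: int) -> tuple[int, int]:
    """GEMM-space coordinate as described above: anchor plus offset per axis."""
    return (m_anchor + m_offset, k_anchor + k_offset)


def linear_offset(row: int, col: int, stride: int) -> int:
    """Element offset in a row-major tensor (assumed layout)."""
    return row * stride + col


stride = 4096                    # illustrative row stride, in elements
block_origin = (256, 128)        # illustrative per-block origin (m, k)
m_offset, k_offset = 64, 32      # offsets passed to load_tile

# Full-tensor form: base is the tensor start, anchors carry the block origin.
row, col = effective_coord(*block_origin, m_offset, k_offset)
full_tensor_element = linear_offset(row, col, stride)

# Legacy form: the base is pre-shifted to the block origin, anchors are zero.
sliced_base = linear_offset(*block_origin, stride)
row0, col0 = effective_coord(0, 0, m_offset, k_offset)
legacy_element = sliced_base + linear_offset(row0, col0, stride)

print(full_tensor_element, legacy_element)
```

Both forms resolve to the same element. Under this model, the difference is only in which base the offsets are measured from, not in which element gets loaded.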

Args: