Mojo struct
TileLoaderLDS
struct TileLoaderLDS[dtype_: DType, tile_rows_: Int, tile_cols_: Int, stride: Int, num_loading_warps: Int, swizzle: Optional[Swizzle] = Optional(), load_width: Int = simd_width_of[dtype_](), use_full_tile_width: Bool = False]
DRAM→LDS DMA expert for warp-group cooperative, coordinate-indexed loads.
Sibling of SubTileLoaderLDS, which loads a single sub-tile indexed by TileTensor.
This loader coordinates a warp group (typically 8 warps) to cooperatively
fill a half-tile via coordinate-indexed iteration: load_tile(dst, m_offset, k_offset) steps through num_iterations bands of BK-wide rows,
optionally applying a per-iteration byte-space swizzle to avoid LDS
bank conflicts. This is the DRAM→LDS pattern used by matmul (ping-pong, etc.).
It calls stdlib AMDBufferResource.load_to_lds directly, with no alias scope
attached. Matmul scheduling relies on s_sched_group_barrier hints,
which don't qualify as the runtime fence required by the
SIInsertWaitcnts vmcnt-relaxation contract, so attaching the scope
would miscompile (see the async_copies docstring on
load_to_lds). For attention patterns that DO satisfy the contract
via explicit s_waitcnt vmcnt(0) + s_barrier fences, use
SubTileLoaderLDS instead.
Parameters
- dtype_ (DType): Element data type. Re-bound to dtype at body scope to match the TileLoader trait alias.
- tile_rows_ (Int): Height of each half-tile to load. Re-bound to tile_rows at body scope.
- tile_cols_ (Int): Width (K dimension) of each half-tile. Re-bound to tile_cols at body scope.
- stride (Int): Row stride of the source GMEM tensor.
- num_loading_warps (Int): Warps cooperating on each load (typically 8).
- swizzle (Optional[Swizzle]): Optional byte-space swizzle for avoiding LDS bank conflicts.
- load_width (Int): Elements per load (SIMD width). Defaults to simd_width_of[dtype_]().
- use_full_tile_width (Bool): FP8 row-major mode. When True, each warp sub-tile spans the full tile width (subtile_cols = tile_cols) instead of 32 columns. Defaults to False.
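The swizzle parameter is described only as a byte-space swizzle for LDS bank conflicts. As an illustration of the general idea, the Python sketch below models a CuTe-style XOR swizzle; it is not Mojo's Swizzle type, and the bits/base/shift values are assumptions chosen for the example. With 128-byte rows, the same 16-byte chunk in eight consecutive rows sits at the same offset modulo 128, so on an LDS whose banks repeat every 128 bytes (an assumption; check your target) those accesses collide. XORing the row index into the chunk index spreads them over eight distinct chunk positions.

```python
def make_swizzle(bits: int, base: int, shift: int):
    """Return x -> x ^ ((x & mask) >> shift), a CuTe-style XOR swizzle."""
    # Non-overlapping source/target bit-fields make the map its own inverse.
    assert shift >= bits
    mask = ((1 << bits) - 1) << (base + shift)  # source bit-field

    def swizzle(offset: int) -> int:
        # XOR the source field (higher bits) into the target field at `base`;
        # the low `base` bits (here, bytes within a 16-byte chunk) are untouched.
        return offset ^ ((offset & mask) >> shift)

    return swizzle


# Example (assumed layout): 128-byte LDS rows split into eight 16-byte chunks.
# Bits 7-9 (row index mod 8) are XORed into bits 4-6 (chunk index in the row).
sw = make_swizzle(bits=3, base=4, shift=3)
row_bytes, chunk_bytes = 128, 16

plain = [(r * row_bytes) % row_bytes // chunk_bytes for r in range(8)]
swizzled = [sw(r * row_bytes) % row_bytes // chunk_bytes for r in range(8)]
print(plain)     # [0, 0, 0, 0, 0, 0, 0, 0]
print(swizzled)  # [0, 1, 2, 3, 4, 5, 6, 7]
```

Keeping base at log2 of the load size means every vector load still lands in one contiguous, aligned chunk; only the chunk's position within the row moves.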
Fields
- buffer (AMDBufferResource)
- thread_row (Int)
- thread_col (Int)
- warp_id (Int)
- lane_id (Int)
- m_anchor (Int)
- k_anchor (Int)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TileLoader,
TrivialRegisterPassable
comptime members
In the expressions below, Self denotes the fully parameterized TileLoaderLDS[dtype_, tile_rows_, tile_cols_, stride, num_loading_warps, swizzle, load_width, use_full_tile_width] type.
dtype
comptime dtype = dtype_
elements_per_warp
comptime elements_per_warp = (WARP_SIZE * load_width)
lane_load_bytes
comptime lane_load_bytes = (load_width * size_of[Self.dtype]())
loading_threads
comptime loading_threads = (num_loading_warps * WARP_SIZE)
num_iterations
comptime num_iterations = ceildiv(Self.tile_rows, Self.rows_per_iteration)
num_warp_cols
comptime num_warp_cols = (Self.tile_cols // Self.subtile_cols)
num_warp_rows
comptime num_warp_rows = (num_loading_warps // Self.num_warp_cols)
row_bytes
comptime row_bytes = (Self.tile_cols * size_of[Self.dtype]())
rows_per_iteration
comptime rows_per_iteration = (Self.loading_threads // (Self.tile_cols // load_width))
rows_per_warp
comptime rows_per_warp = (Self.elements_per_warp // Self.tile_cols)
subtile_cols
comptime subtile_cols = Self.tile_cols if use_full_tile_width else 32
thread_rows
comptime thread_rows = (WARP_SIZE // Self.threads_per_row)
threads_per_row
comptime threads_per_row = (Self.subtile_cols // load_width)
tile_cols
comptime tile_cols = tile_cols_
tile_rows
comptime tile_rows = tile_rows_
total_warp_rows
comptime total_warp_rows = ceildiv(Self.tile_rows, Self.rows_per_warp)
warp_subtile_bytes
comptime warp_subtile_bytes = ((Self.rows_per_warp * Self.tile_cols) * size_of[Self.dtype]())
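The comptime members form a small dependency chain, so it can help to evaluate them for one concrete configuration. The Python sketch below mirrors the formulas above. The example values are assumptions, not taken from this page: WARP_SIZE = 64 (the AMD CDNA wavefront size), bf16 elements (2 bytes), load_width = 8 (16-byte loads), a 128 x 64 half-tile, and 8 loading warps. Check simd_width_of and WARP_SIZE for your actual target.

```python
def ceildiv(a, b):
    return -(-a // b)


def derived(tile_rows, tile_cols, num_loading_warps, load_width, elem_bytes,
            use_full_tile_width=False, warp_size=64):
    """Mirror of TileLoaderLDS's comptime members (warp_size=64 is assumed)."""
    d = {}
    d["elements_per_warp"] = warp_size * load_width
    d["lane_load_bytes"] = load_width * elem_bytes
    d["loading_threads"] = num_loading_warps * warp_size
    d["subtile_cols"] = tile_cols if use_full_tile_width else 32
    d["rows_per_iteration"] = d["loading_threads"] // (tile_cols // load_width)
    d["num_iterations"] = ceildiv(tile_rows, d["rows_per_iteration"])
    d["num_warp_cols"] = tile_cols // d["subtile_cols"]
    d["num_warp_rows"] = num_loading_warps // d["num_warp_cols"]
    d["row_bytes"] = tile_cols * elem_bytes
    d["rows_per_warp"] = d["elements_per_warp"] // tile_cols
    d["threads_per_row"] = d["subtile_cols"] // load_width
    d["thread_rows"] = warp_size // d["threads_per_row"]
    d["total_warp_rows"] = ceildiv(tile_rows, d["rows_per_warp"])
    d["warp_subtile_bytes"] = d["rows_per_warp"] * tile_cols * elem_bytes
    return d


# bf16 half-tile: 128 rows x 64 columns, 8 warps, 8-element loads.
for name, value in derived(128, 64, 8, 8, 2).items():
    print(f"{name} = {value}")
```

For this configuration the 512 loading threads cover 64 rows per iteration, so the 128-row half-tile takes num_iterations = 2, and each warp sub-tile is 32 columns wide (num_warp_cols = 2, num_warp_rows = 4).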
Methods
__init__
__init__(src: TileTensor[Self.dtype], warp_id: Int, lane_id: Int, *, m_anchor: Int = 0, k_anchor: Int = 0) -> Self
Builds the loader.
Args:
- src (TileTensor[Self.dtype]): GMEM tile to source from. Pass the full A/B tensor and set m_anchor/k_anchor to the per-block origin, or pass a pre-sliced block tile with zero anchors (legacy behavior). The full-tensor form lets the SRD's num_records bound the actual allocation rather than the block view, which is required for split-K kernels and for parity with TileLoaderLDSIm2col.
- warp_id (Int): Warp identifier within the loading warp group.
- lane_id (Int): Lane identifier within the warp.
- m_anchor (Int): M-coordinate (row dimension) of the block origin in the loader's SRD coordinate system. Added to m_offset at load time. Defaults to 0.
- k_anchor (Int): K-coordinate (column dimension) of the block origin. Added to k_offset at load time. Defaults to 0.
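The src argument's note about num_records can be illustrated with a toy model of buffer addressing. The Python sketch below is not the Mojo API. Its assumptions: reads whose offset falls outside num_records return 0 (modeled on AMD buffer loads with range checking), a pre-sliced block tile gets an SRD sized to the block view (following the docstring's "rather than the block view"), and the hypothetical helpers buffer_load and element_offset stand in for the hardware addressing. A block tile that overhangs the end of a 6 x 8 matrix then reads zeros in the full-tensor form but reads a neighbouring buffer's data in the pre-sliced form.

```python
def buffer_load(memory, base, num_records, offset):
    """Return memory[base + offset], or 0 when offset is outside num_records."""
    return memory[base + offset] if 0 <= offset < num_records else 0


def element_offset(stride, m_anchor, k_anchor, m_offset, k_offset, r, c):
    # Effective GEMM-space coordinate: (m_anchor + m_offset, k_anchor + k_offset).
    return (m_anchor + m_offset + r) * stride + (k_anchor + k_offset + c)


M, K = 6, 8                           # actual allocation: 6 x 8, row-major
memory = [10 * r + c for r in range(M) for c in range(K)]
memory += [99] * 16                   # 99 marks an unrelated neighbouring buffer
tile_rows, tile_cols = 4, 8           # block at row 4 overhangs by two rows


def read_tile(base, num_records, m_anchor):
    return [[buffer_load(memory, base, num_records,
                         element_offset(K, m_anchor, 0, 0, 0, r, c))
             for c in range(tile_cols)]
            for r in range(tile_rows)]


full = read_tile(base=0, num_records=M * K, m_anchor=4)        # full tensor + anchor
sliced = read_tile(base=4 * K, num_records=tile_rows * tile_cols, m_anchor=0)
print(full[2][:3], sliced[2][:3])     # [0, 0, 0] [99, 99, 99]
```

Both forms agree on the rows that exist; they differ only past the end of the allocation, which is where a bound on the real allocation matters.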
load_tile
load_tile(self, dst: TileTensor[Self.dtype, address_space=AddressSpace.SHARED], m_offset: Int, k_offset: Int)
Loads a half-tile from GMEM into SMEM dst via load_to_lds.
The effective GEMM-space coordinate is (m_anchor + m_offset, k_anchor + k_offset), so callers using the legacy pre-sliced-block form (anchors = 0) keep their address math unchanged.
Args:
- dst (TileTensor[Self.dtype, address_space=AddressSpace.SHARED]): Destination TileTensor in SHARED memory (half-tile sized).
- m_offset (Int): Row (M dimension) offset within the block.
- k_offset (Int): Column (K dimension) offset within the block.
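To make the load_tile address math concrete, here is a toy Python model of one call. It is a sketch, not the Mojo implementation, which issues AMDBufferResource.load_to_lds. Assumptions: Python lists stand in for GMEM and LDS, element offsets stand in for byte offsets, threads are mapped row-major with each one moving load_width contiguous elements per iteration, and sw is a hypothetical XOR swizzle applied to destination offsets. It checks that the anchored form and the legacy pre-sliced form fill the destination identically, and that a swizzled tile can be read back through the same swizzle.

```python
def ceildiv(a, b):
    return -(-a // b)


def load_tile(gmem, stride, m_anchor, k_anchor, m_offset, k_offset,
              tile_rows, tile_cols, load_width, loading_threads, swizzle=None):
    """Copy one half-tile from gmem into a fresh list standing in for LDS."""
    lds = [None] * (tile_rows * tile_cols)
    per_row = tile_cols // load_width              # threads covering one row
    rows_per_iteration = loading_threads // per_row
    for it in range(ceildiv(tile_rows, rows_per_iteration)):
        for tid in range(loading_threads):
            r = it * rows_per_iteration + tid // per_row
            c = (tid % per_row) * load_width
            if r >= tile_rows:                     # last band may be partial
                continue
            # Effective GEMM coordinate: (m_anchor + m_offset, k_anchor + k_offset).
            g = (m_anchor + m_offset + r) * stride + (k_anchor + k_offset + c)
            dst = r * tile_cols + c
            if swizzle is not None:
                dst = swizzle(dst)
            lds[dst:dst + load_width] = gmem[g:g + load_width]
    return lds


def sw(x):
    # Hypothetical swizzle: XOR bits 5-6 (row mod 4) into bits 3-4 (chunk index).
    return x ^ ((x & (0b11 << 5)) >> 2)


stride = 64
gmem = list(range(16 * stride))                    # 16 x 64 source matrix
args = dict(tile_rows=8, tile_cols=32, load_width=8, loading_threads=16)
lds = load_tile(gmem, stride, 8, 0, 0, 32, **args)             # anchored form
legacy = load_tile(gmem[8 * stride:], stride, 0, 0, 0, 32, **args)  # pre-sliced
lds_sw = load_tile(gmem, stride, 8, 0, 0, 32, swizzle=sw, **args)
print(lds[:4], lds == legacy)                      # [544, 545, 546, 547] True
```

Here 16 threads with 8-element loads cover 4 rows per iteration, so the 8-row tile takes two iterations, matching the rows_per_iteration and num_iterations formulas.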