Mojo struct
TileLoaderLDS
struct TileLoaderLDS[dtype: DType, tile_rows: Int, tile_cols: Int, stride: Int, num_loading_warps: Int, swizzle: Optional[Swizzle] = Optional(), load_width: Int = simd_width_of[dtype](), use_full_tile_width: Bool = False]
Cooperative global→LDS loader that uses AMDBufferResource.load_to_lds for direct DRAM→LDS DMA with out-of-bounds (OOB) clamping.
Parameters
- dtype (DType): Element data type.
- tile_rows (Int): Height of each half-tile to load.
- tile_cols (Int): Width (K dimension) of each half-tile.
- stride (Int): Row stride of the source GMEM tensor.
- num_loading_warps (Int): Number of warps cooperating on each load (typically 8).
- swizzle (Optional): Optional byte-space swizzle to reduce LDS bank conflicts. Defaults to an empty Optional (no swizzle).
- load_width (Int): Elements per load (SIMD width). Defaults to simd_width_of[dtype]().
- use_full_tile_width (Bool): FP8 row-major mode. When True, subtile_cols equals tile_cols instead of 32. Defaults to False.
Fields
- buffer (AMDBufferResource)
- thread_row (Int)
- thread_col (Int)
- warp_id (Int)
- lane_id (Int)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
elements_per_warp
comptime elements_per_warp = (WARP_SIZE * load_width)
lane_load_bytes
comptime lane_load_bytes = (load_width * size_of[dtype]())
loading_threads
comptime loading_threads = (num_loading_warps * WARP_SIZE)
num_iterations
comptime num_iterations = (tile_rows // TileLoaderLDS[dtype, tile_rows, tile_cols, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].rows_per_iteration)
num_warp_cols
comptime num_warp_cols = (tile_cols // TileLoaderLDS[dtype, tile_rows, tile_cols, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].subtile_cols)
num_warp_rows
comptime num_warp_rows = (num_loading_warps // TileLoaderLDS[dtype, tile_rows, tile_cols, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].num_warp_cols)
row_bytes
comptime row_bytes = (tile_cols * size_of[dtype]())
rows_per_iteration
comptime rows_per_iteration = (TileLoaderLDS[dtype, tile_rows, tile_cols, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].loading_threads // (tile_cols // load_width))
rows_per_warp
comptime rows_per_warp = (TileLoaderLDS[dtype, tile_rows, tile_cols, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].elements_per_warp // tile_cols)
subtile_cols
comptime subtile_cols = tile_cols if use_full_tile_width else 32
thread_rows
comptime thread_rows = (WARP_SIZE // TileLoaderLDS[dtype, tile_rows, tile_cols, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].threads_per_row)
threads_per_row
comptime threads_per_row = (TileLoaderLDS[dtype, tile_rows, tile_cols, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].subtile_cols // load_width)
warp_subtile_bytes
comptime warp_subtile_bytes = ((TileLoaderLDS[dtype, tile_rows, tile_cols, stride, num_loading_warps, swizzle, load_width, use_full_tile_width].rows_per_warp * tile_cols) * size_of[dtype]())
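The formulas above are all integer arithmetic on the struct parameters, so a worked example may help when sizing tiles. The following is a minimal sketch in plain Python (not Mojo) that mirrors each comptime member. The function name derive_constants, the elem_bytes argument (standing in for size_of[dtype]()), and the WARP_SIZE value of 64 are assumptions made for illustration; the page itself does not state the warp size.

```python
# Assumption: a 64-lane wavefront. The page does not state WARP_SIZE.
WARP_SIZE = 64


def derive_constants(elem_bytes, tile_rows, tile_cols, num_loading_warps,
                     load_width, use_full_tile_width=False,
                     warp_size=WARP_SIZE):
    """Mirror the comptime member formulas using Python integer arithmetic.

    elem_bytes is a hypothetical stand-in for size_of[dtype]().
    """
    c = {}
    c["elements_per_warp"] = warp_size * load_width
    c["lane_load_bytes"] = load_width * elem_bytes
    c["loading_threads"] = num_loading_warps * warp_size
    c["row_bytes"] = tile_cols * elem_bytes
    c["subtile_cols"] = tile_cols if use_full_tile_width else 32
    c["threads_per_row"] = c["subtile_cols"] // load_width
    c["thread_rows"] = warp_size // c["threads_per_row"]
    c["num_warp_cols"] = tile_cols // c["subtile_cols"]
    c["num_warp_rows"] = num_loading_warps // c["num_warp_cols"]
    c["rows_per_iteration"] = c["loading_threads"] // (tile_cols // load_width)
    c["num_iterations"] = tile_rows // c["rows_per_iteration"]
    c["rows_per_warp"] = c["elements_per_warp"] // tile_cols
    c["warp_subtile_bytes"] = c["rows_per_warp"] * tile_cols * elem_bytes
    return c


# Example: 2-byte elements, a 128 x 64 half-tile, 8 loading warps,
# 8 elements per load, default 32-column subtiles.
consts = derive_constants(elem_bytes=2, tile_rows=128, tile_cols=64,
                          num_loading_warps=8, load_width=8)
for name, value in consts.items():
    print(f"{name} = {value}")
```

With these example values, each lane moves 16 bytes per load, the 8 warps form a 4 x 2 grid of 32-column subtiles, and the 512 loading threads cover 64 rows per iteration, so the 128-row half-tile takes 2 iterations. Because every division is a floor division, parameter combinations that do not divide evenly (for example, a tile_cols that is not a multiple of load_width) would silently truncate in this model; whether the struct itself rejects such combinations is not stated on this page.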
Methods
__init__
__init__(src: TileTensor[dtype, src.LayoutType, src.origin], warp_id: Int, lane_id: Int) -> Self
Build from a GMEM tile (block-level A or B tile).
load_tile
load_tile(self, dst: TileTensor[dtype, dst.LayoutType, dst.origin, address_space=AddressSpace.SHARED], src_row: Int, src_col: Int)
Load from GMEM at (src_row, src_col) into SMEM dst via load_to_lds.
Args:
- dst (TileTensor): Destination TileTensor in SHARED (half-tile sized).
- src_row (Int): Row offset in the block's GMEM tile.
- src_col (Int): Column (K) offset.