
Mojo struct

TileLoaderLDS

struct TileLoaderLDS[dtype: DType, tile_rows: Int, tile_cols: Int, stride: Int, num_loading_warps: Int, swizzle: Optional[Swizzle] = Optional(), load_width: Int = simd_width_of[dtype](), use_full_tile_width: Bool = False]

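Cooperative global→LDS loader.

All threads in the participating warps cooperate to copy a tile from global memory (GMEM) into LDS (local data share, AMD's on-chip shared memory). Each copy uses AMDBufferResource.load_to_lds, which performs a direct DRAM→LDS DMA transfer without staging data in registers, and the buffer resource clamps out-of-bounds (OOB) accesses.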
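To illustrate the OOB clamping behavior, the Python sketch below models a buffer load in which elements past the buffer's valid range read back as zero instead of faulting. It is a simplified model, not the Mojo API: `buffer_load` and `num_records` are hypothetical names, the zero-fill assumes the usual AMD buffer-load behavior for out-of-range offsets, and the check is done per element, which may not match the hardware's exact granularity.

```python
def buffer_load(memory: list[float], num_records: int,
                elem_offset: int, width: int) -> list[float]:
    """Model of an OOB-clamped buffer load (hypothetical, simplified).

    Reads `width` consecutive elements starting at `elem_offset`.
    Elements at or beyond `num_records` (the assumed valid-size field of
    the buffer descriptor) read back as 0.0 instead of faulting.
    """
    result = []
    for i in range(width):
        idx = elem_offset + i
        # Per-element bounds check; real hardware granularity may differ.
        in_bounds = 0 <= idx < num_records and idx < len(memory)
        result.append(memory[idx] if in_bounds else 0.0)
    return result


# Ten elements in memory, but only the first 6 are inside the buffer.
mem = [float(i) for i in range(10)]
print(buffer_load(mem, 6, 4, 4))  # [4.0, 5.0, 0.0, 0.0]
```

Clamping this way lets edge tiles be loaded with the same code path as interior tiles: the padding simply arrives as zeros.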

Parameters

  • dtype (DType): Element data type.
  • tile_rows (Int): Height of each half-tile to load.
  • tile_cols (Int): Width (K dimension) of each half-tile.
  • stride (Int): Row stride of the source GMEM tensor.
  • num_loading_warps (Int): Warps cooperating on each load (typically 8).
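  • swizzle (Optional[Swizzle]): Optional swizzle applied to LDS byte offsets to reduce bank conflicts. Defaults to no swizzle.
  • load_width (Int): Elements loaded per lane per load instruction (the SIMD width). Defaults to simd_width_of[dtype]().
  • use_full_tile_width (Bool): FP8 row-major mode. When True, each warp subtile spans the full tile_cols width; when False (the default), subtiles are 32 columns wide (see subtile_cols).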
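A swizzle for bank conflicts typically XORs higher address bits into lower ones so that threads writing the same column of consecutive rows land in different banks. The Python sketch below is a CuTe-style XOR swizzle on byte offsets, written for illustration only; the `xor_swizzle` function and its `bits`/`base`/`shift` parameters are assumptions and are not claimed to match the Mojo Swizzle type exactly.

```python
def xor_swizzle(byte_offset: int, bits: int, base: int, shift: int) -> int:
    """CuTe-style XOR swizzle on a byte offset (assumed formulation).

    The `bits` address bits starting at position `base + shift` are XORed
    into the `bits` bits starting at position `base`.
    """
    mask = ((1 << bits) - 1) << (base + shift)
    return byte_offset ^ ((byte_offset & mask) >> shift)


ROW_BYTES = 128   # e.g. 64 bf16 elements per row (assumed)
CHUNK_BYTES = 16  # one 16-byte lane load (assumed)

# Chunk 0 of rows 0..7 after swizzling (3 bits, base 4, shift 3).
for row in range(8):
    swz = xor_swizzle(row * ROW_BYTES, bits=3, base=4, shift=3)
    print(row, (swz % ROW_BYTES) // CHUNK_BYTES)  # chunk index == row
```

With these settings, 16-byte chunk c of row r moves to chunk c XOR (r mod 8), so chunk 0 of rows 0 to 7 lands in eight distinct positions. Because the XORed bits never overlap the source bits, the mapping is its own inverse and therefore a permutation of offsets.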

Fields

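  • buffer (AMDBufferResource): Buffer resource over the source GMEM tile passed to __init__.
  • thread_row (Int)
  • thread_col (Int)
  • warp_id (Int): Warp index of the calling thread, as passed to __init__.
  • lane_id (Int): Lane index within the warp, as passed to __init__.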

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

elements_per_warp

comptime elements_per_warp = (WARP_SIZE * load_width)

lane_load_bytes

comptime lane_load_bytes = (load_width * size_of[dtype]())

loading_threads

comptime loading_threads = (num_loading_warps * WARP_SIZE)
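As a worked example of the three members above, the Python sketch below evaluates them for an assumed configuration: WARP_SIZE = 64 (the AMD CDNA wavefront size), bfloat16 elements (2 bytes), load_width = 8, and 8 loading warps. The `load_sizes` helper is hypothetical.

```python
WARP_SIZE = 64  # assumed: AMD CDNA wavefront size


def load_sizes(load_width: int, elem_bytes: int,
               num_loading_warps: int) -> dict[str, int]:
    """Evaluate elements_per_warp, lane_load_bytes, loading_threads."""
    return {
        # Elements one warp moves per load instruction.
        "elements_per_warp": WARP_SIZE * load_width,
        # Bytes each lane moves per load instruction.
        "lane_load_bytes": load_width * elem_bytes,
        # Threads taking part in the cooperative load.
        "loading_threads": num_loading_warps * WARP_SIZE,
    }


print(load_sizes(load_width=8, elem_bytes=2, num_loading_warps=8))
# {'elements_per_warp': 512, 'lane_load_bytes': 16, 'loading_threads': 512}
```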

num_iterations

comptime num_iterations = (tile_rows // Self.rows_per_iteration)

num_warp_cols

comptime num_warp_cols = (tile_cols // Self.subtile_cols)

num_warp_rows

comptime num_warp_rows = (num_loading_warps // Self.num_warp_cols)
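The warp grid follows from subtile_cols: each warp covers a strip subtile_cols wide, so num_warp_cols warps span the K dimension and the remaining warps stack into num_warp_rows rows. The Python sketch below reproduces that arithmetic; `warp_grid` is a hypothetical helper.

```python
def warp_grid(tile_cols: int, num_loading_warps: int,
              use_full_tile_width: bool) -> tuple[int, int, int]:
    """Return (subtile_cols, num_warp_cols, num_warp_rows)."""
    # Each warp covers 32 columns unless full-tile-width (FP8) mode is on.
    subtile_cols = tile_cols if use_full_tile_width else 32
    num_warp_cols = tile_cols // subtile_cols
    num_warp_rows = num_loading_warps // num_warp_cols
    return subtile_cols, num_warp_cols, num_warp_rows


print(warp_grid(64, 8, False))   # (32, 2, 4): 4 rows x 2 columns of warps
print(warp_grid(128, 8, True))   # (128, 1, 8): one warp column spanning K
```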

row_bytes

comptime row_bytes = (tile_cols * size_of[dtype]())

rows_per_iteration

comptime rows_per_iteration = (Self.loading_threads // (tile_cols // load_width))
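rows_per_iteration counts how many tile rows all loading threads cover in one pass, given that each row needs tile_cols // load_width lanes; num_iterations is then the number of passes per half-tile. A Python sketch of that arithmetic, assuming a warp size of 64 and using a hypothetical `iteration_plan` helper:

```python
def iteration_plan(tile_rows: int, tile_cols: int, load_width: int,
                   num_loading_warps: int,
                   warp_size: int = 64) -> tuple[int, int]:
    """Return (rows_per_iteration, num_iterations); warp_size 64 is assumed."""
    loading_threads = num_loading_warps * warp_size
    lanes_per_row = tile_cols // load_width       # lanes needed for one row
    rows_per_iteration = loading_threads // lanes_per_row
    num_iterations = tile_rows // rows_per_iteration
    return rows_per_iteration, num_iterations


print(iteration_plan(tile_rows=128, tile_cols=64, load_width=8,
                     num_loading_warps=8))  # (64, 2)
```

For a 128 × 64 half-tile with load_width = 8 and 8 warps, each row needs 8 lanes, so 512 threads cover 64 rows per pass and the tile takes 2 passes. Floor division silently drops remainders, so the sketch assumes the sizes divide evenly.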

rows_per_warp

comptime rows_per_warp = (Self.elements_per_warp // tile_cols)

subtile_cols

comptime subtile_cols = tile_cols if use_full_tile_width else 32

thread_rows

comptime thread_rows = (WARP_SIZE // Self.threads_per_row)

threads_per_row

comptime threads_per_row = (Self.subtile_cols // load_width)
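Within one warp subtile, threads_per_row lanes cover a row of subtile_cols elements and thread_rows rows fill out the warp. The Python sketch below computes both and, purely for illustration, maps a lane to a (row, column) position in row-major order. That mapping (`lane_coords`) is a hypothetical assumption, not necessarily how thread_row and thread_col are derived; WARP_SIZE = 64 is also assumed.

```python
WARP_SIZE = 64  # assumed AMD wavefront size


def thread_layout(subtile_cols: int, load_width: int) -> tuple[int, int]:
    """Return (threads_per_row, thread_rows) for one warp subtile."""
    threads_per_row = subtile_cols // load_width
    thread_rows = WARP_SIZE // threads_per_row
    return threads_per_row, thread_rows


def lane_coords(lane_id: int, subtile_cols: int,
                load_width: int) -> tuple[int, int]:
    """Hypothetical row-major lane mapping: (row, first element column)."""
    threads_per_row, _ = thread_layout(subtile_cols, load_width)
    return lane_id // threads_per_row, (lane_id % threads_per_row) * load_width


print(thread_layout(32, 8))     # (4, 16)
print(lane_coords(5, 32, 8))    # (1, 8)
print(lane_coords(63, 32, 8))   # (15, 24)
```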

warp_subtile_bytes

comptime warp_subtile_bytes = ((Self.rows_per_warp * tile_cols) * size_of[dtype]())
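The byte-size members can be checked for an assumed configuration (WARP_SIZE = 64, 2-byte elements, load_width = 8, tile_cols = 64). In this Python sketch, built around the hypothetical helper `warp_subtile_sizes`, one warp covers 8 rows, a row is 128 bytes, and a warp subtile is 1024 bytes, which equals WARP_SIZE × lane_load_bytes (64 × 16).

```python
WARP_SIZE = 64  # assumed AMD wavefront size


def warp_subtile_sizes(tile_cols: int, load_width: int,
                       elem_bytes: int) -> tuple[int, int, int]:
    """Return (rows_per_warp, row_bytes, warp_subtile_bytes)."""
    elements_per_warp = WARP_SIZE * load_width
    rows_per_warp = elements_per_warp // tile_cols
    row_bytes = tile_cols * elem_bytes
    warp_subtile_bytes = rows_per_warp * tile_cols * elem_bytes
    return rows_per_warp, row_bytes, warp_subtile_bytes


rows, row_b, sub_b = warp_subtile_sizes(tile_cols=64, load_width=8, elem_bytes=2)
print(rows, row_b, sub_b)  # 8 128 1024
# When tile_cols divides elements_per_warp, a subtile is one warp-wide load.
assert sub_b == WARP_SIZE * (8 * 2)
```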

Methods

__init__

__init__(src: TileTensor[dtype, src.LayoutType, src.origin], warp_id: Int, lane_id: Int) -> Self

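Builds a loader over a GMEM tile (the block-level A or B tile).

Args:

  • src (TileTensor): Source tile in global memory.
  • warp_id (Int): Index of the calling warp.
  • lane_id (Int): Index of the calling thread within its warp.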

load_tile

load_tile(self, dst: TileTensor[dtype, dst.LayoutType, dst.origin, address_space=AddressSpace.SHARED], src_row: Int, src_col: Int)

Loads one half-tile from GMEM, starting at (src_row, src_col), into the shared-memory (LDS) tile dst via load_to_lds.

Args:

  • dst (TileTensor): Destination TileTensor in the SHARED address space, sized to one half-tile (tile_rows × tile_cols).
  • src_row (Int): Row offset in the block's GMEM tile.
  • src_col (Int): Column (K-dimension) offset in the block's GMEM tile.
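The addressing implied by load_tile can be sketched in Python under a simplifying assumption: loading thread t handles load_width contiguous elements at row t // (tile_cols // load_width) and column (t % (tile_cols // load_width)) × load_width, shifted down by rows_per_iteration on each pass. The real loader partitions work by warp subtiles (num_warp_rows × num_warp_cols), so this flat mapping and the `gmem_offsets` helper are hypothetical; the sketch only shows that every element of the half-tile is addressed exactly once relative to (src_row, src_col) and the row stride.

```python
def gmem_offsets(src_row: int, src_col: int, stride: int, tile_rows: int,
                 tile_cols: int, load_width: int,
                 loading_threads: int) -> list[int]:
    """Starting element offset of every lane load (hypothetical flat mapping)."""
    lanes_per_row = tile_cols // load_width
    rows_per_iteration = loading_threads // lanes_per_row
    num_iterations = tile_rows // rows_per_iteration
    offsets = []
    for it in range(num_iterations):
        for t in range(loading_threads):
            row = it * rows_per_iteration + t // lanes_per_row
            col = (t % lanes_per_row) * load_width
            # Row-major GMEM address relative to the block tile's origin.
            offsets.append((src_row + row) * stride + src_col + col)
    return offsets


# Tiny config: 4x8 tile, 2 elements per load, 8 threads, row stride 16.
offs = gmem_offsets(2, 4, stride=16, tile_rows=4, tile_cols=8,
                    load_width=2, loading_threads=8)
covered = sorted(o + k for o in offs for k in range(2))
expected = sorted((2 + r) * 16 + 4 + c for r in range(4) for c in range(8))
print(covered == expected)  # True: each element is loaded exactly once
```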
