Mojo struct
TileLoaderLDS
@register_passable("trivial")
struct TileLoaderLDS[dtype: DType, src_layout: Layout, src_tile_layout: Layout, num_loading_warps: Int, swizzle: OptionalReg[Swizzle] = OptionalReg[Swizzle](), load_width: Int = simd_width_of[dtype]()]
Encapsulates load_to_lds with pre-computed thread positions and swizzle.
Fields
- buffer (AMDBufferResource)
- thread_row (Int)
- thread_col (Int)
- warp_id (Int)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
Movable,
UnknownDestructibility
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
elements_per_warp
comptime elements_per_warp = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width].threads_per_warp * load_width)
loading_threads
comptime loading_threads = (num_loading_warps * TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width].threads_per_warp)
loads_per_row
comptime loads_per_row = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width].tile_cols // load_width)
num_iterations
comptime num_iterations = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width].tile_rows // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width].rows_per_iteration)
rows_per_iteration
comptime rows_per_iteration = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width].loading_threads // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width].loads_per_row)
rows_per_warp
comptime rows_per_warp = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width].elements_per_warp // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width].tile_cols)
stride
comptime stride = src_layout.shape[1].value()
subtile_cols
comptime subtile_cols = 32
thread_rows
comptime thread_rows = (WARP_SIZE // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width].threads_per_row)
threads_per_row
comptime threads_per_row = (32 // load_width)
threads_per_warp
comptime threads_per_warp = WARP_SIZE
tile_cols
comptime tile_cols = src_tile_layout.shape[1].value()
tile_rows
comptime tile_rows = src_tile_layout.shape[0].value()
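The derived members above are plain integer arithmetic over the struct parameters, so a worked example can make them concrete. The following is a minimal sketch in plain Python, not Mojo and not part of the library. The class name `TileLoaderConstants`, the field `src_cols`, and the value `WARP_SIZE = 64` are assumptions made for illustration, as are the sample tile shape and `load_width`.

```python
from dataclasses import dataclass

WARP_SIZE = 64  # assumption: 64-lane wavefront; the real value depends on the target GPU


@dataclass(frozen=True)
class TileLoaderConstants:
    """Hypothetical Python model of the comptime members (illustration only)."""

    src_cols: int           # stands in for src_layout.shape[1]
    tile_rows: int          # stands in for src_tile_layout.shape[0]
    tile_cols: int          # stands in for src_tile_layout.shape[1]
    num_loading_warps: int
    load_width: int         # elements moved per thread per load

    @property
    def stride(self) -> int:
        return self.src_cols

    @property
    def threads_per_warp(self) -> int:
        return WARP_SIZE

    @property
    def elements_per_warp(self) -> int:
        return self.threads_per_warp * self.load_width

    @property
    def loading_threads(self) -> int:
        return self.num_loading_warps * self.threads_per_warp

    @property
    def loads_per_row(self) -> int:
        return self.tile_cols // self.load_width

    @property
    def rows_per_iteration(self) -> int:
        return self.loading_threads // self.loads_per_row

    @property
    def num_iterations(self) -> int:
        return self.tile_rows // self.rows_per_iteration

    @property
    def rows_per_warp(self) -> int:
        return self.elements_per_warp // self.tile_cols

    @property
    def threads_per_row(self) -> int:
        return 32 // self.load_width  # literal 32, as in the listing

    @property
    def thread_rows(self) -> int:
        return WARP_SIZE // self.threads_per_row


# Sample configuration (all values are assumptions chosen for the example).
example = TileLoaderConstants(
    src_cols=4096, tile_rows=128, tile_cols=64, num_loading_warps=4, load_width=8
)
print(example.loads_per_row, example.rows_per_iteration, example.num_iterations)
```

With these sample values, 256 loading threads issue 8 loads per row, so one iteration covers 32 rows and the 128-row tile takes 4 iterations. All divisions are integer divisions, so a configuration whose sizes do not divide evenly would truncate rather than round up.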
Methods
__init__
__init__(src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], warp_id: Int, lane_id: Int) -> Self
Pre-compute thread position with swizzle inversion for bank-conflict-free reads.
load_tile
load_tile[dst_layout: Layout, //](self, dst: LayoutTensor[dtype, dst_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src_row: Int, src_col: Int)
Load a tile from source coordinates to LDS.
Combines the pre-computed thread position with the given source coordinates and uses the buffer resource stored at init time. Only warps 0 through num_loading_warps - 1 participate; all other warps return immediately.
Args:
- dst (LayoutTensor): Destination LDS tile.
- src_row (Int): Starting row in source tensor.
- src_col (Int): Starting column in source tensor (typically k_offset).
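The participation rule for load_tile can be modeled in a few lines. This is a rough sketch in plain Python, not the Mojo implementation. The function name `iteration_row_offsets` is hypothetical. The idea that each iteration advances by rows_per_iteration rows is an assumption inferred from the definition of num_iterations. Per-thread offsets and the swizzle are deliberately left out.

```python
def iteration_row_offsets(
    warp_id: int,
    num_loading_warps: int,
    tile_rows: int,
    rows_per_iteration: int,
) -> list[int]:
    """Return the tile-row offset of each load iteration for one warp.

    Hypothetical model: warps outside [0, num_loading_warps) do no work,
    mirroring the documented early return.
    """
    if warp_id >= num_loading_warps:
        return []  # non-loading warps return immediately

    num_iterations = tile_rows // rows_per_iteration
    # Assumption: iteration i starts rows_per_iteration * i rows into the tile.
    return [i * rows_per_iteration for i in range(num_iterations)]


# A loading warp steps through the tile; a non-loading warp does nothing.
print(iteration_row_offsets(0, 4, 128, 32))
print(iteration_row_offsets(5, 4, 128, 32))
```

In this model a caller would add src_row and src_col to the per-iteration offsets to locate the source data. How the real loader combines them with thread_row and thread_col is not specified on this page.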