Mojo struct
TileLoaderLDSIm2col
struct TileLoaderLDSIm2col[dtype_: DType, tile_rows_: Int, tile_cols_: Int, C: Int, num_loading_warps: Int, H: Int = 1, W: Int = 1, H_out: Int = 1, W_out: Int = 1, R: Int = 1, S: Int = 1, stride_h: Int = 1, stride_w: Int = 1, dilation_h: Int = 1, dilation_w: Int = 1, pad_h: Int = 0, pad_w: Int = 0, Q: Int = 1, D: Int = 1, D_out: Int = 1, stride_d: Int = 1, dilation_d: Int = 1, pad_d: Int = 0, swizzle: Optional[Swizzle] = Optional(), load_width: Int = simd_width_of[dtype_](), use_full_tile_width: Bool = False, use_runtime_hw: Bool = False]
DRAM->LDS DMA expert for implicit-GEMM convolution, NHWC inputs.
Sibling of TileLoaderLDS (linear GEMM source). Each iteration
issues one buffer_load_*_lds per lane, same vmcnt cost as the
matmul loader. The kernel's K-loop iterates flat k_offset ∈ [0, R*S*C)
in steps of tile_cols; the loader internally decomposes
k_offset → (kh, kw, c_offset) and per-lane m_lane → (n, h_out, w_out),
then computes addr = ((n*H + h_in)*W + w_in)*C + c for each lane.
The body picks one of three comptime sub-paths at instantiation:
pure-pointwise (R=S=1, no pad: math collapses to m*C + k);
uniform-substrip (general R×S with tile_cols ≤ C and
C % tile_cols == 0: one (kh, kw) per call); per-lane substrip
(otherwise: each lane decomposes its own k_lane). Pad > 0
additionally routes halo lanes (h_in or w_in outside [0, H)/[0, W))
to the SRD-OOB sentinel.
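The sub-path choice described above can be modelled in a few lines of Python. This is only a sketch of the stated conditions, not the loader's actual source: `select_sub_path` and the returned path names are hypothetical, and the pointwise test also requires unit stride, following the stricter condition given in the `load_tile` description.

```python
def select_sub_path(R, S, C, tile_cols, stride_h=1, stride_w=1, pad_h=0, pad_w=0):
    """Model of the comptime sub-path choice (hypothetical helper).

    Returns (path_name, needs_halo_check).
    """
    has_pad = pad_h > 0 or pad_w > 0
    unit_stride = stride_h == 1 and stride_w == 1
    if R == 1 and S == 1 and unit_stride and not has_pad:
        # 1x1 filter, no pad: the address collapses to m*C + k.
        path = "pure-pointwise"
    elif tile_cols <= C and C % tile_cols == 0:
        # Every K-tile sits inside a single (kh, kw) substrip.
        path = "uniform-substrip"
    else:
        # A K-tile may straddle substrips, so each lane decomposes its own k.
        path = "per-lane-substrip"
    # Pad > 0 additionally routes halo lanes to the out-of-bounds sentinel.
    return path, has_pad


print(select_sub_path(R=3, S=3, C=64, tile_cols=32, pad_h=1, pad_w=1))
```

For a 3×3 filter with C = 64 and tile_cols = 32, each K-tile covers half of one (kh, kw) substrip, so the uniform path applies and the padding enables the halo check.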
Parameters
- dtype_ (DType): Element data type. Re-bound to `dtype` at body scope to match the `TileLoader` trait alias.
- tile_rows_ (Int): Height of each half-tile to load (in `M = N*H_out*W_out` space). Re-bound to `tile_rows` at body scope.
- tile_cols_ (Int): Width of each half-tile (in `K = R*S*C` space). Re-bound to `tile_cols` at body scope. Must satisfy `tile_cols ≤ C` and `C % tile_cols == 0` so each `load_tile` call lives inside one (kh, kw) substrip.
- C (Int): Input channel count.
- num_loading_warps (Int): Warps cooperating on each load.
- H (Int): Input spatial height.
- W (Int): Input spatial width.
- H_out (Int): Output spatial height (with stride=1, dilation=1, no pad: `H - R + 1`).
- W_out (Int): Output spatial width.
- R (Int): Filter height.
- S (Int): Filter width.
- stride_h (Int): Vertical conv stride (>= 1).
- stride_w (Int): Horizontal conv stride (>= 1).
- dilation_h (Int): Vertical conv dilation (>= 1).
- dilation_w (Int): Horizontal conv dilation (>= 1).
- pad_h (Int): Vertical pad (>= 0). Halo lanes route to the SRD-OOB sentinel when pad > 0.
- pad_w (Int): Horizontal pad (>= 0).
- Q (Int): Filter temporal extent (3D-only). `Q == 1` (default) keeps the loader in 2D mode (4D NHWC input). `Q > 1` activates 3D mode (5D NDHWC input, `K = Q*R*S*C`).
- D (Int): Input temporal depth (3D-only; unused when Q == 1).
- D_out (Int): Output temporal depth (3D-only).
- stride_d (Int): Temporal conv stride (3D-only, >= 1).
- dilation_d (Int): Temporal conv dilation (3D-only, >= 1).
- pad_d (Int): Temporal pad (3D-only, >= 0). Halo lanes route to the SRD-OOB sentinel when pad_d > 0.
- swizzle (Optional[Swizzle]): Optional byte-space swizzle for LDS bank conflicts.
- load_width (Int): Elements per load (SIMD width).
- use_full_tile_width (Bool): FP8 row-major mode (matches `TileLoaderLDS.use_full_tile_width`).
- use_runtime_hw (Bool): When True, H/W/H_out/W_out (and D/D_out in 3D mode) come from runtime constructor args instead of the comptime template params above. Used for graph-compiled callers with dynamic image resolution (e.g. FLUX VAE). The K-decomposition and conv params (Q, R, S, stride, dilation, pad) stay comptime.
Fields
- buffer (AMDBufferResource):
- thread_row (Int):
- thread_col (Int):
- warp_id (Int):
- lane_id (Int):
- num_records (Int):
- m_anchor (Int):
- k_anchor (Int):
- rt_h (Int):
- rt_w (Int):
- rt_h_out (Int):
- rt_w_out (Int):
- rt_spatial (Int):
- rt_d (Int):
- rt_d_out (Int):
- rt_spatial_dhw (Int):
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TileLoader,
TrivialRegisterPassable
comptime members
dtype
comptime dtype = dtype_
elements_per_warp
comptime elements_per_warp = (WARP_SIZE * load_width)
lane_load_bytes
comptime lane_load_bytes = (load_width * size_of[TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].dtype]())
loading_threads
comptime loading_threads = (num_loading_warps * WARP_SIZE)
num_iterations
comptime num_iterations = ceildiv(TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].tile_rows, TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].rows_per_iteration)
num_warp_cols
comptime num_warp_cols = (TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].tile_cols // TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].subtile_cols)
num_warp_rows
comptime num_warp_rows = (num_loading_warps // TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].num_warp_cols)
row_bytes
comptime row_bytes = (TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].tile_cols * size_of[TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].dtype]())
rows_per_iteration
comptime rows_per_iteration = (TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].loading_threads // (TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].tile_cols // load_width))
rows_per_warp
comptime rows_per_warp = (TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].elements_per_warp // TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].tile_cols)
subtile_cols
comptime subtile_cols = TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].tile_cols if use_full_tile_width else 32
thread_rows
comptime thread_rows = (WARP_SIZE // TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].threads_per_row)
threads_per_row
comptime threads_per_row = (TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].subtile_cols // load_width)
tile_cols
comptime tile_cols = tile_cols_
tile_rows
comptime tile_rows = tile_rows_
total_warp_rows
comptime total_warp_rows = ceildiv(TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].tile_rows, TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].rows_per_warp)
warp_subtile_bytes
comptime warp_subtile_bytes = ((TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].rows_per_warp * TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].tile_cols) * size_of[TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].dtype]())
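To see how the comptime members above relate to one another, the following Python sketch evaluates the same formulas for one configuration. It is an illustration only: `derived_constants` is a hypothetical helper, `warp_size=64` is an assumed value for `WARP_SIZE`, and `load_width` and `elem_bytes` are passed explicitly rather than derived from `simd_width_of` and `size_of`.

```python
def derived_constants(tile_rows, tile_cols, num_loading_warps, load_width,
                      elem_bytes, use_full_tile_width=False, warp_size=64):
    """Evaluate the documented comptime formulas (hypothetical helper)."""
    def ceildiv(a, b):
        return -(-a // b)

    c = {}
    c["elements_per_warp"] = warp_size * load_width
    c["lane_load_bytes"] = load_width * elem_bytes
    c["loading_threads"] = num_loading_warps * warp_size
    c["subtile_cols"] = tile_cols if use_full_tile_width else 32
    c["threads_per_row"] = c["subtile_cols"] // load_width
    c["thread_rows"] = warp_size // c["threads_per_row"]
    c["rows_per_warp"] = c["elements_per_warp"] // tile_cols
    c["rows_per_iteration"] = c["loading_threads"] // (tile_cols // load_width)
    c["num_iterations"] = ceildiv(tile_rows, c["rows_per_iteration"])
    c["num_warp_cols"] = tile_cols // c["subtile_cols"]
    c["num_warp_rows"] = num_loading_warps // c["num_warp_cols"]
    c["row_bytes"] = tile_cols * elem_bytes
    c["total_warp_rows"] = ceildiv(tile_rows, c["rows_per_warp"])
    c["warp_subtile_bytes"] = c["rows_per_warp"] * tile_cols * elem_bytes
    return c


# Example: a 128x64 half-tile of 2-byte elements, 4 loading warps, 8 elements per load.
consts = derived_constants(tile_rows=128, tile_cols=64, num_loading_warps=4,
                           load_width=8, elem_bytes=2)
for name, value in consts.items():
    print(f"{name} = {value}")
```

Under these assumed inputs each lane moves 16 bytes per load and the 256 loading threads cover 32 tile rows per iteration, so the half-tile takes 4 iterations.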
Methods
__init__
__init__[InLayout: TensorLayout](src_nhwc: TileTensor[TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].dtype, InLayout], warp_id: Int, lane_id: Int, *, m_anchor: Int = 0, k_anchor: Int = 0) -> Self
Builds the loader from a 4D NHWC input TileTensor.
The SRD covers the entire NHWC tensor (N*H*W*C elements).
Per-block addressing is split between m_anchor/k_anchor
(per-block origin in GEMM space, set at construction) and the
m_offset/k_offset args of load_tile (within-block).
This overload is for the comptime-HW path (the default); the runtime conv geometry fields are populated with zeros and the loader uses the comptime template params instead.
Args:
- src_nhwc (TileTensor[TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].dtype, InLayout]): 4D NHWC input tensor of shape `(N, H, W, C)`.
- warp_id (Int): Warp identifier within the loading warp group.
- lane_id (Int): Lane identifier within the warp.
- m_anchor (Int): M-coordinate (= flat `N*H_out*W_out` index) of the block origin. Added to `m_offset` at load time. Defaults to 0; pass the per-block origin from the kernel.
- k_anchor (Int): K-coordinate (= flat (kh, kw, c) index) of the block origin. Added to `k_offset` at load time. Defaults to 0; conv split-K is not yet supported, so callers typically leave this at the default.
__init__[InLayout: TensorLayout](src_nhwc: TileTensor[TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].dtype, InLayout], warp_id: Int, lane_id: Int, *, runtime_h: Int, runtime_w: Int, runtime_h_out: Int, runtime_w_out: Int, m_anchor: Int = 0, k_anchor: Int = 0) -> Self
Runtime-HW overload: H/W/H_out/W_out from runtime args.
Equivalent to the comptime-HW overload except the conv input
/ output spatial dims are runtime values (typically read from
input.dim() / output.dim() by the launcher). Use when the
graph compiler can't pin the resolution.
Args:
- src_nhwc (TileTensor[TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].dtype, InLayout]): 4D NHWC input tensor of shape `(N, H, W, C)`.
- warp_id (Int): Warp identifier within the loading warp group.
- lane_id (Int): Lane identifier within the warp.
- runtime_h (Int): Runtime input height.
- runtime_w (Int): Runtime input width.
- runtime_h_out (Int): Runtime output height (must equal `(runtime_h + 2*pad_h - dilation_h*(R-1) - 1) // stride_h + 1`).
- runtime_w_out (Int): Runtime output width.
- m_anchor (Int): M-coordinate of the block origin. Added to `m_offset` at load time. Defaults to 0.
- k_anchor (Int): K-coordinate of the block origin. Added to `k_offset` at load time. Defaults to 0.
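A launcher that passes runtime dimensions may want to check them before constructing the loader. The sketch below uses the standard convolution output-size formula, in the same form the page gives for `runtime_d_out`; `conv_out_dim` and `check_runtime_hw` are hypothetical helper names, not part of the library.

```python
def conv_out_dim(in_dim, kernel, stride=1, dilation=1, pad=0):
    """Output extent along one axis for the given conv parameters."""
    return (in_dim + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1


def check_runtime_hw(runtime_h, runtime_w, runtime_h_out, runtime_w_out, *,
                     R, S, stride_h=1, stride_w=1, dilation_h=1, dilation_w=1,
                     pad_h=0, pad_w=0):
    """Return True when the runtime output dims match the conv geometry."""
    expected_h = conv_out_dim(runtime_h, R, stride_h, dilation_h, pad_h)
    expected_w = conv_out_dim(runtime_w, S, stride_w, dilation_w, pad_w)
    return (runtime_h_out, runtime_w_out) == (expected_h, expected_w)


# With stride=1, dilation=1, no pad the result reduces to H - R + 1.
print(conv_out_dim(56, 3))
print(check_runtime_hw(224, 224, 112, 112, R=7, S=7,
                       stride_h=2, stride_w=2, pad_h=3, pad_w=3))
```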
__init__[InLayout: TensorLayout](src_ndhwc: TileTensor[TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].dtype, InLayout], warp_id: Int, lane_id: Int, *, runtime_d: Int, runtime_h: Int, runtime_w: Int, runtime_d_out: Int, runtime_h_out: Int, runtime_w_out: Int, m_anchor: Int = 0, k_anchor: Int = 0) -> Self
3D runtime-HW overload: D/H/W/D_out/H_out/W_out from runtime args.
Equivalent to the 2D runtime-HW overload but for Q > 1 mode:
accepts a rank-5 NDHWC TileTensor and runtime D / D_out args
in addition to the spatial H/W ones. The K-decomposition and
conv params (Q, R, S, stride_d, stride_h, stride_w, dilation_*,
pad_*, C) stay comptime.
Args:
- src_ndhwc (TileTensor[TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].dtype, InLayout]): 5D NDHWC input tensor of shape `(N, D, H, W, C)`.
- warp_id (Int): Warp identifier within the loading warp group.
- lane_id (Int): Lane identifier within the warp.
- runtime_d (Int): Runtime input depth.
- runtime_h (Int): Runtime input height.
- runtime_w (Int): Runtime input width.
- runtime_d_out (Int): Runtime output depth (must equal `(runtime_d + 2*pad_d - dilation_d*(Q-1) - 1) // stride_d + 1`).
- runtime_h_out (Int): Runtime output height.
- runtime_w_out (Int): Runtime output width.
- m_anchor (Int): M-coordinate of the block origin in GEMM space (= flat `N*D_out*H_out*W_out` index). Added to `m_offset` at load time. Defaults to 0.
- k_anchor (Int): K-coordinate of the block origin in GEMM space (= flat `Q*R*S*C` index). Added to `k_offset` at load time. Defaults to 0.
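The page does not spell out the 3D index decomposition, so the following Python sketch rests on an assumption: that the flat `N*D_out*H_out*W_out` and `Q*R*S*C` indices are row-major in exactly that order, by analogy with the 2D case. All function names are hypothetical.

```python
def decompose_m_3d(m, D_out, H_out, W_out):
    """Flat M index -> (n, d_out, h_out, w_out), assuming w_out is innermost."""
    m, w_out = divmod(m, W_out)
    m, h_out = divmod(m, H_out)
    n, d_out = divmod(m, D_out)
    return n, d_out, h_out, w_out


def decompose_k_3d(k, R, S, C):
    """Flat K index -> (q, kh, kw, c), assuming c is innermost."""
    k, c = divmod(k, C)
    k, kw = divmod(k, S)
    q, kh = divmod(k, R)
    return q, kh, kw, c


def ndhwc_address(n, d_in, h_in, w_in, c, D, H, W, C):
    """Flat element offset into a contiguous NDHWC tensor."""
    return (((n * D + d_in) * H + h_in) * W + w_in) * C + c


print(decompose_m_3d(47, D_out=2, H_out=3, W_out=4))
print(decompose_k_3d(341, R=3, S=3, C=16))
```

By the same analogy, the input depth coordinate would be `d_in = d_out * stride_d + q * dilation_d - pad_d`, with out-of-range values treated as halo lanes.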
load_tile
load_tile(self, dst: TileTensor[TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].dtype, address_space=AddressSpace.SHARED], m_offset: Int, k_offset: Int)
Loads a half-tile from NHWC global memory into the SMEM dst.
Two paths:
- Pure-pointwise fast path (R=S=1, stride=1, dilation=1, pad=0): GEMM address `addr = ((n*H + h)*W + w)*C + c` collapses to `addr = m * C + k`. Identical to `TileLoaderLDS.load_tile` with `stride = C`; the per-lane vs uniform offset split is preserved so each iteration issues one `buffer_load_*_lds` per lane, matching the matmul's vmcnt accounting.
- General R×S path (M2): the lane's `(m_lane, k_lane)` are decomposed at runtime: `k_lane → (kh, kw, c)` via comptime R, S, C divisors (constant-folded to multiply-by-magic); `m_lane → (n, h_out, w_out)` via comptime H_out, W_out divisors. Then `h_in = h_out * stride_h + kh * dilation_h - pad_h` (similarly for w_in) and `addr = ((n*H + h_in)*W + w_in)*C + c`. The full per-lane address goes into `vector_offset`; `scalar_offset = 0`. Costs more VGPRs per load than the fast path because the address decomposition can't be cleanly split into a uniform + per-lane pair (the m → (n, h_out, w_out) decomposition is non-linear).
Args:
- dst (TileTensor[TileLoaderLDSIm2col[dtype_, tile_rows_, tile_cols_, C, num_loading_warps, H, W, H_out, W_out, R, S, stride_h, stride_w, dilation_h, dilation_w, pad_h, pad_w, Q, D, D_out, stride_d, dilation_d, pad_d, swizzle, load_width, use_full_tile_width, use_runtime_hw].dtype, address_space=AddressSpace.SHARED]): Destination half-tile in SHARED address space.
- m_offset (Int): M-coordinate within the block (added to `self.m_anchor` to form the absolute GEMM M coord = flat `N*H_out*W_out` index).
- k_offset (Int): K-coordinate within the block (added to `self.k_anchor` to form the absolute GEMM K coord = flat (kh, kw, c) index ∈ `[0, R*S*C)`). Must be a multiple of `tile_cols`.
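The address arithmetic described for the two paths can be checked with a small Python model. This is a sketch of the formulas quoted above, not the kernel code: `im2col_address` is a hypothetical name, and returning `None` stands in for routing a halo lane to the out-of-bounds sentinel.

```python
def im2col_address(m, k, *, H, W, C, H_out, W_out, R, S,
                   stride_h=1, stride_w=1, dilation_h=1, dilation_w=1,
                   pad_h=0, pad_w=0):
    """Flat NHWC element offset for GEMM coordinate (m, k), or None for a halo."""
    assert 0 <= k < R * S * C, "k must lie in [0, R*S*C)"
    # k -> (kh, kw, c), with the channel index innermost.
    kh_kw, c = divmod(k, C)
    kh, kw = divmod(kh_kw, S)
    # m -> (n, h_out, w_out), with w_out innermost.
    n_h, w_out = divmod(m, W_out)
    n, h_out = divmod(n_h, H_out)
    h_in = h_out * stride_h + kh * dilation_h - pad_h
    w_in = w_out * stride_w + kw * dilation_w - pad_w
    if not (0 <= h_in < H and 0 <= w_in < W):
        return None  # halo lane: the real loader uses an out-of-bounds sentinel
    return ((n * H + h_in) * W + w_in) * C + c


# Pointwise case (R=S=1, stride=1, pad=0): the address collapses to m*C + k.
geom = dict(H=4, W=4, C=8, H_out=4, W_out=4, R=1, S=1)
print(all(im2col_address(m, k, **geom) == m * 8 + k
          for m in range(32) for k in range(8)))

# 3x3 filter with pad=1: the top-left tap of the first output pixel is a halo.
print(im2col_address(0, 0, H=4, W=4, C=8, H_out=4, W_out=4, R=3, S=3,
                     pad_h=1, pad_w=1))
```

A kernel K-loop would call the real `load_tile` with `k_offset` stepping through `[0, R*S*C)` in multiples of `tile_cols`; each step corresponds to a contiguous range of `k` values in this model.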