IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

TileLoaderLDSIm2col

struct TileLoaderLDSIm2col[dtype_: DType, tile_rows_: Int, tile_cols_: Int, C: Int, num_loading_warps: Int, H: Int = 1, W: Int = 1, H_out: Int = 1, W_out: Int = 1, R: Int = 1, S: Int = 1, stride_h: Int = 1, stride_w: Int = 1, dilation_h: Int = 1, dilation_w: Int = 1, pad_h: Int = 0, pad_w: Int = 0, Q: Int = 1, D: Int = 1, D_out: Int = 1, stride_d: Int = 1, dilation_d: Int = 1, pad_d: Int = 0, swizzle: Optional[Swizzle] = Optional(), load_width: Int = simd_width_of[dtype_](), use_full_tile_width: Bool = False, use_runtime_hw: Bool = False]

DRAM->LDS DMA expert for implicit-GEMM convolution, NHWC inputs.

Sibling of TileLoaderLDS (linear GEMM source). Each iteration issues one buffer_load_*_lds per lane, with the same vmcnt cost as the matmul loader. The kernel's K-loop iterates over a flat k_offset ∈ [0, R*S*C) in steps of tile_cols. The loader internally decomposes k_offset → (kh, kw, c_offset) and each lane's m_lane → (n, h_out, w_out). It then derives (h_in, w_in) from these using the stride, dilation, and pad parameters, and computes addr = ((n*H + h_in)*W + w_in)*C + c for each lane.
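The index arithmetic above can be checked on the host. The following Python function is a hypothetical reference model, not the Mojo kernel. It assumes K is laid out as [R, S, C] (C fastest, kh outermost) and M as [N, H_out, W_out] (w_out fastest). That layout is consistent with the decomposition order described here, but this page does not spell it out. The name `im2col_addr` is illustrative.

```python
def im2col_addr(m, k, *, H, W, C, H_out, W_out, S,
                stride_h=1, stride_w=1, dilation_h=1, dilation_w=1,
                pad_h=0, pad_w=0):
    """Flat NHWC element offset for implicit-GEMM coordinate (m, k).

    Returns None when (h_in, w_in) lands in the padding halo.
    Assumed layouts: K = [R, S, C] with C fastest, M = [N, H_out, W_out].
    """
    # k -> (kh, kw, c)
    kh, rem = divmod(k, S * C)
    kw, c = divmod(rem, C)
    # m -> (n, h_out, w_out)
    n, rem = divmod(m, H_out * W_out)
    h_out, w_out = divmod(rem, W_out)
    # Output position + filter tap -> input position.
    h_in = h_out * stride_h + kh * dilation_h - pad_h
    w_in = w_out * stride_w + kw * dilation_w - pad_w
    if not (0 <= h_in < H and 0 <= w_in < W):
        return None  # halo lane: the loader sends these to the OOB sentinel
    return ((n * H + h_in) * W + w_in) * C + c
```

With S = 1 and no padding or striding, the result reduces to `m * C + k`. For example, `im2col_addr(7, 2, H=4, W=5, C=3, H_out=4, W_out=5, S=1)` gives 23.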

The body picks one of three comptime sub-paths at instantiation:

- pure-pointwise (R=S=1 and no pad), where the math collapses to m*C + k;
- uniform-substrip (general R×S with tile_cols ≤ C and C % tile_cols == 0), where each call covers one (kh, kw);
- per-lane substrip (all other cases), where each lane decomposes its own k_lane.

When pad > 0, halo lanes (those with h_in outside [0, H) or w_in outside [0, W)) are additionally routed to the SRD-OOB sentinel.
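The following Python functions model two things: the comptime dispatch, and the invariant that makes the uniform-substrip path possible. Both helper names are hypothetical. The pure-pointwise test also requires stride 1 and dilation 1, because the load_tile description on this page lists those conditions. The K layout [R, S, C] with C fastest is an assumption.

```python
def select_subpath(R, S, C, tile_cols, *, pad_h=0, pad_w=0,
                   stride_h=1, stride_w=1, dilation_h=1, dilation_w=1):
    """Which of the three sub-paths a configuration would take (sketch)."""
    pointwise = (R == 1 and S == 1 and pad_h == 0 and pad_w == 0
                 and stride_h == 1 and stride_w == 1
                 and dilation_h == 1 and dilation_w == 1)
    if pointwise:
        return "pure-pointwise"
    if tile_cols <= C and C % tile_cols == 0:
        return "uniform-substrip"
    return "per-lane-substrip"


def substrips_touched(k_offset, tile_cols, S, C):
    """Distinct (kh, kw) pairs covered by K columns [k_offset, k_offset + tile_cols)."""
    # With K laid out as [R, S, C], k // C == kh * S + kw.
    return {divmod(k // C, S) for k in range(k_offset, k_offset + tile_cols)}
```

The K-loop steps k_offset by tile_cols, and tile_cols divides C. As a result, every tile falls inside a single C-block and touches exactly one (kh, kw). When tile_cols > C, a tile can straddle two substrips, which is why each lane must then decompose its own k.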

Parameters

  • dtype_ (DType): Element data type. Re-bound to dtype at body scope to match the TileLoader trait alias.
  • tile_rows_ (Int): Height of each half-tile to load, in M = N*H_out*W_out space. Re-bound to tile_rows at body scope.
  • tile_cols_ (Int): Width of each half-tile, in K = R*S*C space. Re-bound to tile_cols at body scope. Must satisfy tile_cols ≤ C and C % tile_cols == 0 so that each load_tile call stays inside one (kh, kw) substrip.
  • C (Int): Input channel count.
  • num_loading_warps (Int): Warps cooperating on each load.
  • H (Int): Input spatial height.
  • W (Int): Input spatial width.
  • H_out (Int): Output spatial height (with stride=1, dilation=1, no pad: H - R + 1).
  • W_out (Int): Output spatial width.
  • R (Int): Filter height.
  • S (Int): Filter width.
  • stride_h (Int): Vertical conv stride (>= 1).
  • stride_w (Int): Horizontal conv stride (>= 1).
  • dilation_h (Int): Vertical conv dilation (>= 1).
  • dilation_w (Int): Horizontal conv dilation (>= 1).
  • pad_h (Int): Vertical pad (>= 0). Halo lanes route to the SRD-OOB sentinel when pad > 0.
  • pad_w (Int): Horizontal pad (>= 0).
  • Q (Int): Filter temporal extent (3D only). Q == 1 (the default) keeps the loader in 2D mode (4D NHWC input); Q > 1 activates 3D mode (5D NDHWC input, K = Q*R*S*C).
  • D (Int): Input temporal depth (3D only; unused when Q == 1).
  • D_out (Int): Output temporal depth (3D only).
  • stride_d (Int): Temporal conv stride (3D only, >= 1).
  • dilation_d (Int): Temporal conv dilation (3D only, >= 1).
  • pad_d (Int): Temporal pad (3D only, >= 0). Halo lanes route to the SRD-OOB sentinel when pad_d > 0.
  • swizzle (Optional[Swizzle]): Optional byte-space swizzle to avoid LDS bank conflicts.
  • load_width (Int): Elements per load (SIMD width).
  • use_full_tile_width (Bool): FP8 row-major mode (matches TileLoaderLDS.use_full_tile_width).
  • use_runtime_hw (Bool): When True, H/W/H_out/W_out (and D/D_out in 3D mode) come from runtime constructor args instead of the comptime template params above. Used for graph-compiled callers with dynamic image resolution (e.g. FLUX VAE). The K-decomposition and conv params (Q, R, S, stride, dilation, pad) stay comptime.
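The loader takes H_out and W_out as parameters rather than deriving them, so the caller has to supply values consistent with the other conv parameters. The following Python function computes those values using the usual convolution output-size formula. It assumes symmetric padding (pad_h rows above and below), which this page does not state. The sketch also computes the resulting GEMM extents. Taking M = N*D_out*H_out*W_out in 3D mode is an assumption by analogy with the 2D case; only K = Q*R*S*C is given above.

```python
from math import prod


def conv_out_size(in_size, k, stride=1, dilation=1, pad=0):
    """Standard convolution output length along one axis.

    Assumes `pad` elements on both sides of the axis (symmetric padding).
    """
    return (in_size + 2 * pad - dilation * (k - 1) - 1) // stride + 1


def implicit_gemm_mk(N, C, out_spatial, filt_spatial):
    """(M, K) of the implicit GEMM.

    out_spatial: (H_out, W_out) in 2D mode or (D_out, H_out, W_out) in 3D mode.
    filt_spatial: (R, S) in 2D mode or (Q, R, S) in 3D mode.
    """
    return N * prod(out_spatial), prod(filt_spatial) * C
```

With stride=1, dilation=1, and pad=0, `conv_out_size(H, R)` reduces to H - R + 1, matching the H_out note above.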

Fields

  • buffer (AMDBufferResource):
  • thread_row (Int):
  • thread_col (Int):
  • warp_id (Int):
  • lane_id (Int):
  • num_records (Int):
  • m_anchor (Int):
  • k_anchor (Int):
  • rt_h (Int):
  • rt_w (Int):
  • rt_h_out (Int):
  • rt_w_out (Int):
  • rt_spatial (Int):
  • rt_d (Int):
  • rt_d_out (Int):
  • rt_spatial_dhw (Int):

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TileLoader, TrivialRegisterPassable

comptime members

dtype

comptime dtype = dtype_

elements_per_warp

comptime elements_per_warp = (WARP_SIZE * load_width)

lane_load_bytes

comptime lane_load_bytes = (load_width * size_of[Self.dtype]())

loading_threads

comptime loading_threads = (num_loading_warps * WARP_SIZE)

num_iterations

comptime num_iterations = ceildiv(Self.tile_rows, Self.rows_per_iteration)

num_warp_cols

comptime num_warp_cols = (Self.tile_cols // Self.subtile_cols)

num_warp_rows

comptime num_warp_rows = (num_loading_warps // Self.num_warp_cols)

row_bytes

comptime row_bytes = (Self.tile_cols * size_of[Self.dtype]())

rows_per_iteration

comptime rows_per_iteration = (Self.loading_threads // (Self.tile_cols // load_width))

rows_per_warp

comptime rows_per_warp = (Self.elements_per_warp // Self.tile_cols)

subtile_cols

comptime subtile_cols = Self.tile_cols if use_full_tile_width else 32

thread_rows

comptime thread_rows = (WARP_SIZE // Self.threads_per_row)

threads_per_row

comptime threads_per_row = (Self.subtile_cols // load_width)

tile_cols

comptime tile_cols = tile_cols_

tile_rows

comptime tile_rows = tile_rows_

total_warp_rows

comptime total_warp_rows = ceildiv(Self.tile_rows, Self.rows_per_warp)

warp_subtile_bytes

comptime warp_subtile_bytes = ((Self.rows_per_warp * Self.tile_cols) * size_of[Self.dtype]())
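The following Python function evaluates the comptime expressions above for one configuration, which shows how the derived constants fit together. The sample values are assumptions chosen for illustration: WARP_SIZE = 64 (a CDNA wavefront), a 2-byte element type such as bf16, and load_width = 8. The real default load_width comes from simd_width_of[dtype_](). The function name `loader_constants` is hypothetical.

```python
def ceildiv(a, b):
    return -(-a // b)


def loader_constants(tile_rows, tile_cols, num_loading_warps, load_width,
                     elem_bytes, warp_size=64, use_full_tile_width=False):
    """Mirror of the comptime member expressions (illustrative sketch)."""
    subtile_cols = tile_cols if use_full_tile_width else 32
    threads_per_row = subtile_cols // load_width
    elements_per_warp = warp_size * load_width
    rows_per_warp = elements_per_warp // tile_cols
    loading_threads = num_loading_warps * warp_size
    rows_per_iteration = loading_threads // (tile_cols // load_width)
    num_warp_cols = tile_cols // subtile_cols
    return {
        "subtile_cols": subtile_cols,
        "threads_per_row": threads_per_row,
        "thread_rows": warp_size // threads_per_row,
        "elements_per_warp": elements_per_warp,
        "loading_threads": loading_threads,
        "lane_load_bytes": load_width * elem_bytes,
        "row_bytes": tile_cols * elem_bytes,
        "rows_per_warp": rows_per_warp,
        "rows_per_iteration": rows_per_iteration,
        "num_iterations": ceildiv(tile_rows, rows_per_iteration),
        "num_warp_cols": num_warp_cols,
        "num_warp_rows": num_loading_warps // num_warp_cols,
        "total_warp_rows": ceildiv(tile_rows, rows_per_warp),
        "warp_subtile_bytes": rows_per_warp * tile_cols * elem_bytes,
    }


c = loader_constants(tile_rows=128, tile_cols=64, num_loading_warps=4,
                     load_width=8, elem_bytes=2)
```

With these values, each iteration moves loading_threads * load_width = 2048 elements. That equals rows_per_iteration * tile_cols (32 * 64), so four iterations cover the 128-row half-tile.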

Methods

__init__

__init__[InLayout: TensorLayout](src_nhwc: TileTensor[Self.dtype, InLayout], warp_id: Int, lane_id: Int, *, m_anchor: Int = 0, k_anchor: Int = 0) -> Self

Builds the loader from a 4D NHWC input TileTensor.

The SRD covers the entire NHWC tensor (N*H*W*C elements). Per-block addressing is split between m_anchor/k_anchor (per-block origin in GEMM space, set at construction) and the m_offset/k_offset args of load_tile (within-block).

This overload is for the comptime-HW path (the default); the runtime conv geometry fields are populated with zeros and the loader uses the comptime template params instead.
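The following Python sketch models two ideas from this description. The first is how the construction-time anchors combine with load_tile's offsets. The second is why pointing halo lanes past the end of the SRD behaves like zero padding. The sketch rests on two assumptions: the buffer resource's range check returns zero for out-of-range offsets, and num_records is counted in elements for simplicity (hardware descriptors typically count bytes). `global_mk`, `buffer_load`, and the lane_row/lane_col split are hypothetical names.

```python
def global_mk(m_anchor, k_anchor, m_offset, k_offset, lane_row, lane_col):
    """GEMM-space coordinate a lane targets (hypothetical decomposition).

    m_anchor/k_anchor: per-block origin fixed at construction.
    m_offset/k_offset: within-block offset passed to load_tile.
    lane_row/lane_col: the lane's position inside the tile (illustrative).
    """
    return m_anchor + m_offset + lane_row, k_anchor + k_offset + lane_col


def buffer_load(buf, offset, num_records):
    """Simplified range-checked load: out-of-range offsets read as zero."""
    return buf[offset] if 0 <= offset < num_records else 0
```

Under this model, routing a halo lane to any offset at or beyond num_records yields 0, which is the value zero padding would contribute.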

Args:

__init__[InLayout: TensorLayout](src_nhwc: TileTensor[Self.dtype, InLayout], warp_id: Int, lane_id: Int, *, runtime_h: Int, runtime_w: Int, runtime_h_out: Int, runtime_w_out: Int, m_anchor: Int = 0, k_anchor: Int = 0) -> Self

Runtime-HW overload: H/W/H_out/W_out from runtime args.

Equivalent to the comptime-HW overload except the conv input / output spatial dims are runtime values (typically read from input.dim() / output.dim() by the launcher). Use when the graph compiler can't pin the resolution.
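A launcher would derive the runtime arguments from the actual tensor shapes. The following Python function shows that mapping, assuming both tensors are NHWC. The keyword names match this overload, but `runtime_hw_args` itself is hypothetical and not part of the library.

```python
def runtime_hw_args(in_shape, out_shape):
    """Map NHWC input/output shapes to the runtime-HW keyword arguments."""
    n_in, h, w, _c = in_shape
    n_out, h_out, w_out, _k = out_shape
    if n_in != n_out:
        raise ValueError("input and output batch sizes differ")
    return {
        "runtime_h": h,
        "runtime_w": w,
        "runtime_h_out": h_out,
        "runtime_w_out": w_out,
    }
```

Only the spatial dims travel at runtime. C and the filter geometry remain template parameters, so the K-side decomposition is unchanged.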

Args:

__init__[InLayout: TensorLayout](src_ndhwc: TileTensor[Self.dtype, InLayout], warp_id: Int, lane_id: Int, *, runtime_d: Int, runtime_h: Int, runtime_w: Int, runtime_d_out: Int, runtime_h_out: Int, runtime_w_out: Int, m_anchor: Int = 0, k_anchor: Int = 0) -> Self

3D runtime-HW overload: D/H/W/D_out/H_out/W_out from runtime args.

Equivalent to the 2D runtime-HW overload but for Q > 1 mode: accepts a rank-5 NDHWC TileTensor and runtime D / D_out args in addition to the spatial H/W ones. The K-decomposition and conv params (Q, R, S, stride_d, stride_h, stride_w, dilation_d, dilation_h, dilation_w, pad_d, pad_h, pad_w, C) stay comptime.
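In 3D mode, the decomposition gains a temporal axis. The following Python function is a hypothetical host-side reference for it. It assumes K is laid out as [Q, R, S, C] (C fastest) and M as [N, D_out, H_out, W_out], and it returns None for halo positions. Per-axis parameters are passed as (d, h, w) tuples for brevity.

```python
def im2col_addr_3d(m, k, *, D, H, W, C, D_out, H_out, W_out, R, S,
                   stride=(1, 1, 1), dilation=(1, 1, 1), pad=(0, 0, 0)):
    """Flat NDHWC offset for implicit-GEMM coordinate (m, k), or None in the halo."""
    # k -> (kd, kh, kw, c), assuming K = [Q, R, S, C].
    kd, rem = divmod(k, R * S * C)
    kh, rem = divmod(rem, S * C)
    kw, c = divmod(rem, C)
    # m -> (n, d_out, h_out, w_out), assuming M = [N, D_out, H_out, W_out].
    n, rem = divmod(m, D_out * H_out * W_out)
    d_out, rem = divmod(rem, H_out * W_out)
    h_out, w_out = divmod(rem, W_out)
    d_in = d_out * stride[0] + kd * dilation[0] - pad[0]
    h_in = h_out * stride[1] + kh * dilation[1] - pad[1]
    w_in = w_out * stride[2] + kw * dilation[2] - pad[2]
    if not (0 <= d_in < D and 0 <= h_in < H and 0 <= w_in < W):
        return None
    return (((n * D + d_in) * H + h_in) * W + w_in) * C + c
```

With Q = R = S = 1 and no padding, this reduces to `m * C + k`, just as the 2D formula does.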

Args:

load_tile

load_tile(self, dst: TileTensor[Self.dtype, address_space=AddressSpace.SHARED], m_offset: Int, k_offset: Int)

Loads a half-tile from NHWC global memory into the SMEM dst.

Two paths:

  • Pure-pointwise fast path (R=S=1, stride=1, dilation=1, pad=0): GEMM address addr = ((n*H + h)*W + w)*C + c collapses to addr = m * C + k. Identical to TileLoaderLDS.load_tile with stride = C; the per-lane vs uniform offset split is preserved so each iteration issues one buffer_load_*_lds per lane, matching the matmul's vmcnt accounting.

  • General R×S path (M2): each lane's (m_lane, k_lane) pair is decomposed at runtime. k_lane → (kh, kw, c) uses comptime R, S, C divisors (constant-folded to multiply-by-magic), and m_lane → (n, h_out, w_out) uses comptime H_out, W_out divisors. Then h_in = h_out * stride_h + kh * dilation_h - pad_h (and similarly for w_in), and addr = ((n*H + h_in)*W + w_in)*C + c. The full per-lane address goes into vector_offset, with scalar_offset = 0. This path costs more VGPRs per load than the fast path because the address cannot be cleanly split into a uniform + per-lane pair: the m → (n, h_out, w_out) decomposition is non-linear.
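The following Python model shows why the two paths differ in register cost. On the pointwise path, addr = m * C + k is affine in (m, k), so it splits exactly into a warp-uniform part and a per-lane part. On the general path, the address step between consecutive m values changes whenever w_out wraps to the next row, so no single per-lane stride exists. The function names and the `m_base`/`lane_row` split are illustrative assumptions. To keep the model short, `general_addr` fixes stride = dilation = 1 and pad = 0.

```python
def pointwise_split(m_base, k_base, lane_row, lane_col, C):
    """Uniform (scalar) + per-lane (vector) offsets for addr = m * C + k."""
    scalar_offset = m_base * C + k_base      # same for every lane in the warp
    vector_offset = lane_row * C + lane_col  # differs per lane
    return scalar_offset, vector_offset


def general_addr(m, k, H, W, C, H_out, W_out, S):
    """Full per-lane address on the general path (stride 1, dilation 1, no pad)."""
    kh, rem = divmod(k, S * C)
    kw, c = divmod(rem, C)
    n, rem = divmod(m, H_out * W_out)
    h_out, w_out = divmod(rem, W_out)
    return ((n * H + h_out + kh) * W + w_out + kw) * C + c


# Address deltas between consecutive m for H=W=4, C=2, S=3, H_out=W_out=2.
deltas = [general_addr(m + 1, 0, 4, 4, 2, 2, 2, 3) - general_addr(m, 0, 4, 4, 2, 2, 2, 3)
          for m in range(3)]
```

Here `deltas` is [2, 6, 2]. The jump to 6 occurs where w_out wraps and h_out advances, which is the non-linearity that forces the full address into vector_offset.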

Args: