
Mojo struct

Conv2DKernelConfig

struct Conv2DKernelConfig

Conv-specific geometry for AMD4WaveMatmul's conv2d entry point.

Companion to MatmulKernelConfig: the matmul shape config drives the 4-wave block/warp/MMA tiling, while this struct carries the extra conv2d parameters the im2col loader needs to materialize the A operand from a 4D NHWC input. Pass both at the kernel callsite.
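To make the im2col idea concrete, here is a minimal pure-Python reference sketch of how an A operand could be materialized from an NHWC input. It is not the kernel's loader: the function name `im2col_nhwc`, the nested-list input, and the (r, s, c) column ordering are assumptions chosen for illustration, consistent with K = R*S*C. Padded (halo) positions contribute zeros.

```python
def im2col_nhwc(x, R, S, stride_h=1, stride_w=1,
                dilation_h=1, dilation_w=1, pad_h=0, pad_w=0):
    """Reference im2col over an NHWC input given as nested lists [N][H][W][C].

    Returns (rows, H_out, W_out). There is one row per output pixel and
    K = R * S * C columns per row, ordered (r, s, c) (assumed ordering).
    """
    N, H, W, C = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    H_out = (H + 2 * pad_h - dilation_h * (R - 1) - 1) // stride_h + 1
    W_out = (W + 2 * pad_w - dilation_w * (S - 1) - 1) // stride_w + 1
    rows = []
    for n in range(N):
        for ho in range(H_out):
            for wo in range(W_out):
                row = []
                for r in range(R):
                    for s in range(S):
                        # Input coordinate touched by filter tap (r, s).
                        h = ho * stride_h - pad_h + r * dilation_h
                        w = wo * stride_w - pad_w + s * dilation_w
                        inside = 0 <= h < H and 0 <= w < W
                        for c in range(C):
                            # Halo positions outside the input read as zero.
                            row.append(x[n][h][w][c] if inside else 0)
                rows.append(row)
    return rows, H_out, W_out


# A 1x3x3x1 input holding 1..9, with a 3x3 filter and pad 1.
x = [[[[1], [2], [3]], [[4], [5], [6]], [[7], [8], [9]]]]
rows, H_out, W_out = im2col_nhwc(x, R=3, S=3, pad_h=1, pad_w=1)
print(H_out, W_out, len(rows))  # 3 3 9
print(rows[0])                  # [0, 0, 0, 0, 1, 2, 0, 4, 5]
```

Multiplying each row by a filter laid out as [Cout, K] then yields one output pixel per row, which is why the conv can reuse the matmul tiling.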

Fields​

  • ​H (Int): Input spatial height (NHWC dim 1). Used as H_eff when use_runtime_hw=False; ignored otherwise (loader reads runtime H).
  • ​W (Int): Input spatial width (NHWC dim 2). Same semantics as H.
  • ​H_out (Int): Output spatial height = (H + 2*pad_h - dilation_h*(R-1) - 1) // stride_h + 1. Comptime-only on the static-HW path.
  • ​W_out (Int): Output spatial width = (W + 2*pad_w - dilation_w*(S-1) - 1) // stride_w + 1.
  • ​R (Int): Filter spatial height.
  • ​S (Int): Filter spatial width.
  • ​stride_h (Int): Conv vertical stride (>= 1).
  • ​stride_w (Int): Conv horizontal stride (>= 1).
  • ​dilation_h (Int): Conv vertical dilation (>= 1).
  • ​dilation_w (Int): Conv horizontal dilation (>= 1).
  • ​pad_h (Int): Vertical pad (>= 0). When > 0, halo lanes route to the SRD-OOB sentinel for zero-clamp behavior.
  • ​pad_w (Int): Horizontal pad (>= 0).
  • C_in (Int): Real input channel count. When > 0, lets the caller K-pad the filter (allocate filter as [Cout, K_padded] where K_padded = round_up(R*S*C_in, 2*BK) and zero-fill the trailing K rows). When 0, the kernel derives C_in = K_filter // (R*S) and asserts exact divisibility. In 3D mode the divisor is Q*R*S.
  • ​use_runtime_hw (Bool): When True, H/W/H_out/W_out (and D/D_out in 3D mode) are runtime values read from the input tensor (typically via graph-compiler symbolic resolution). When False, all conv geometry is comptime.
  • Q (Int): Filter temporal extent (3D only). Q == 1 keeps the kernel in 2D mode (4D NHWC input, K = R*S*C). Q > 1 activates 3D mode (5D NDHWC input, K = Q*R*S*C).
  • ​D (Int): Input temporal depth (NDHWC dim 1). 2D mode: unused.
  • ​D_out (Int): Output temporal depth. 2D mode: unused.
  • ​stride_d (Int): Conv temporal stride (>= 1). 2D mode: unused.
  • ​dilation_d (Int): Conv temporal dilation (>= 1). 2D mode: unused.
  • ​pad_d (Int): Temporal pad (>= 0). 2D mode: unused. When > 0, halo lanes route to the SRD-OOB sentinel for zero-clamp behavior.
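The C_in and Q rules above can be sketched as plain arithmetic. The helper names (`round_up`, `padded_k`, `resolve_c_in`) are hypothetical, BK is assumed to be the K-tile size from the matmul config, and applying the same padding rule to Q*R*S*C_in in 3D mode is an assumption; this is a sketch of the documented behavior, not the kernel's code.

```python
def round_up(value, multiple):
    """Round value up to the nearest multiple of `multiple`."""
    return ((value + multiple - 1) // multiple) * multiple


def padded_k(R, S, C_in, BK, Q=1):
    """K_padded = round_up(Q*R*S*C_in, 2*BK); Q == 1 gives the 2D case."""
    return round_up(Q * R * S * C_in, 2 * BK)


def resolve_c_in(K_filter, R, S, C_in=0, Q=1):
    """Mirror the documented C_in rule.

    C_in > 0: the caller supplies the real channel count, so the filter's
              K dimension may be zero-padded beyond Q*R*S*C_in.
    C_in == 0: derive C_in from K_filter, requiring exact divisibility.
    """
    if C_in > 0:
        return C_in
    taps = Q * R * S
    if K_filter % taps != 0:
        raise ValueError("K_filter must be divisible by Q*R*S when C_in == 0")
    return K_filter // taps


# 3x3 filter over 3 channels with a hypothetical BK of 64.
print(padded_k(R=3, S=3, C_in=3, BK=64))           # 128
print(resolve_c_in(K_filter=128, R=3, S=3, C_in=3))  # 3
print(resolve_c_in(K_filter=576, R=3, S=3))          # 64
```

Passing C_in explicitly is what makes the padded case work: with K_filter = 128 and a 3x3 filter, 128 is not divisible by 9, so the derive-from-K path would fail its divisibility check.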

Implemented traits​

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, Writable

Methods​

__init__​

def __init__(out self, *, H: Int = 1, W: Int = 1, H_out: Int = 1, W_out: Int = 1, R: Int = 1, S: Int = 1, stride_h: Int = 1, stride_w: Int = 1, dilation_h: Int = 1, dilation_w: Int = 1, pad_h: Int = 0, pad_w: Int = 0, C_in: Int = 0, use_runtime_hw: Bool = False, Q: Int = 1, D: Int = 1, D_out: Int = 1, stride_d: Int = 1, dilation_d: Int = 1, pad_d: Int = 0)

Constructs a Conv2DKernelConfig from the conv geometry.

Args:

  • ​H (Int): Input spatial height.
  • ​W (Int): Input spatial width.
  • ​H_out (Int): Output spatial height.
  • ​W_out (Int): Output spatial width.
  • ​R (Int): Filter spatial height.
  • ​S (Int): Filter spatial width.
  • ​stride_h (Int): Conv vertical stride.
  • ​stride_w (Int): Conv horizontal stride.
  • ​dilation_h (Int): Conv vertical dilation.
  • ​dilation_w (Int): Conv horizontal dilation.
  • ​pad_h (Int): Vertical pad.
  • ​pad_w (Int): Horizontal pad.
  • ​C_in (Int): Real input channel count.
  • ​use_runtime_hw (Bool): When True, treat H/W/H_out/W_out (and D/D_out in 3D mode) as placeholders to be replaced at kernel-launch time with values read from the input tensor.
  • ​Q (Int): Filter temporal extent. Q == 1 (default) keeps the kernel in 2D mode; Q > 1 activates 3D mode.
  • ​D (Int): Input temporal depth (3D mode).
  • ​D_out (Int): Output temporal depth (3D mode).
  • ​stride_d (Int): Conv temporal stride (3D mode).
  • ​dilation_d (Int): Conv temporal dilation (3D mode).
  • ​pad_d (Int): Temporal pad (3D mode).
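Since H_out and W_out are passed in explicitly, a caller would presumably compute them first from the formula given under Fields. A minimal sketch follows; the helper name `conv_out_dim` and the `geometry` dictionary are illustrative only, with keys chosen to match the constructor's keyword names.

```python
def conv_out_dim(size, kernel, stride=1, dilation=1, pad=0):
    """Output extent along one axis:
    (size + 2*pad - dilation*(kernel-1) - 1) // stride + 1
    """
    return (size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1


# Example: 224x224 input, 7x7 filter, stride 2, pad 3.
H, W, R, S = 224, 224, 7, 7
geometry = dict(
    H=H, W=W, R=R, S=S,
    stride_h=2, stride_w=2,
    pad_h=3, pad_w=3,
    H_out=conv_out_dim(H, R, stride=2, pad=3),
    W_out=conv_out_dim(W, S, stride=2, pad=3),
)
print(geometry["H_out"], geometry["W_out"])  # 112 112
```

The same values would be passed as keyword arguments to the constructor; arguments left out keep their defaults (stride and dilation 1, pad 0, Q == 1 for 2D mode).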

write_to​

def write_to(self, mut writer: T)

Writes a compact conv geometry tag to writer.

Args:

  • ​writer (T): Sink for the rendered tag.

write_repr_to​

def write_repr_to(self, mut writer: T)

Writes a debug representation of this conv config.

Args:

  • writer (T): Sink for the debug representation.