Mojo struct

Conv2DKernelConfig

struct Conv2DKernelConfig

Conv-specific geometry for AMD4WaveMatmul's conv2d entry point.

Companion to MatmulKernelConfig: the matmul shape config drives the 4-wave block/warp/MMA tiling, while this struct carries the extra conv2d parameters the im2col loader needs to materialize the A operand from a 4D NHWC input. Pass both at the kernel callsite.
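
As a rough illustration of what an im2col loader has to compute, the sketch below maps one element (m, k) of the implicit A operand back to an NHWC input coordinate. It is a plain-Python model, not the kernel's code: the function name `im2col_coord` and the index orderings (m decomposed as (n, h_out, w_out), k decomposed as (r, s, c), last index fastest) are assumptions made for the example.

```python
def im2col_coord(m, k, *, H, W, H_out, W_out, R, S, C_in,
                 stride_h=1, stride_w=1, dilation_h=1, dilation_w=1,
                 pad_h=0, pad_w=0):
    """Map implicit-GEMM element A[m, k] to an NHWC input coordinate.

    Returns (n, h, w, c), or None when the filter tap falls in the padding
    halo, in which case the element is treated as zero.
    Assumed orderings (not confirmed by the source):
      m -> (n, h_out, w_out), k -> (r, s, c), last index fastest.
    """
    # Row index: which batch image and which output pixel.
    n, rem = divmod(m, H_out * W_out)
    h_out, w_out = divmod(rem, W_out)
    # Column index: which filter tap and which input channel.
    r, rem = divmod(k, S * C_in)
    s, c = divmod(rem, C_in)
    # Input pixel touched by tap (r, s) for this output pixel.
    h = h_out * stride_h - pad_h + r * dilation_h
    w = w_out * stride_w - pad_w + s * dilation_w
    if 0 <= h < H and 0 <= w < W:
        return (n, h, w, c)
    return None  # halo element: contributes zero


# 3x3 filter, pad 1, 4x4 input, 2 channels.
geom = dict(H=4, W=4, H_out=4, W_out=4, R=3, S=3, C_in=2, pad_h=1, pad_w=1)
corner_tap = im2col_coord(0, 0, **geom)   # top-left tap of the top-left output: in the halo
center_tap = im2col_coord(0, 9, **geom)   # center tap (r=1, s=1), channel 1
```

With padding, the top-left tap of the top-left output pixel lands outside the input and yields `None`, while the center tap reads input pixel (0, 0).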

Fields​

  • ​H (Int): Input spatial height (NHWC dim 1). Used as H_eff when use_runtime_hw=False; ignored otherwise (loader reads runtime H).
  • ​W (Int): Input spatial width (NHWC dim 2). Same semantics as H.
  • ​H_out (Int): Output spatial height = (H + 2*pad_h - dilation_h*(R-1) - 1) // stride_h + 1. Comptime-only on the static-HW path.
  • ​W_out (Int): Output spatial width = (W + 2*pad_w - dilation_w*(S-1) - 1) // stride_w + 1.
  • ​R (Int): Filter spatial height.
  • ​S (Int): Filter spatial width.
  • ​stride_h (Int): Conv vertical stride (>= 1).
  • ​stride_w (Int): Conv horizontal stride (>= 1).
  • ​dilation_h (Int): Conv vertical dilation (>= 1).
  • ​dilation_w (Int): Conv horizontal dilation (>= 1).
  • ​pad_h (Int): Vertical pad (>= 0). When > 0, halo lanes route to the SRD-OOB sentinel for zero-clamp behavior.
  • ​pad_w (Int): Horizontal pad (>= 0).
  • C_in (Int): Real input channel count. When > 0, lets the caller K-pad the filter: allocate the filter as [Cout, K_padded], where K_padded = round_up(R*S*C_in, 2*BK), and zero-fill the trailing K rows. When 0, the kernel derives C_in = K_filter // (R*S) and asserts exact divisibility.
  • ​use_runtime_hw (Bool): When True, H/W/H_out/W_out are runtime values read from the input tensor (typically via graph-compiler symbolic resolution). When False, all conv geometry is comptime.
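
The arithmetic in the field descriptions above can be checked with a short plain-Python sketch. The helper names (`conv_out_dim`, `round_up`, `padded_k`, `derive_c_in`) are hypothetical and not part of the Mojo API, and `BK` is assumed here to be the matmul K-tile size, which this page does not define.

```python
def conv_out_dim(size, pad, dilation, filt, stride):
    """Output extent along one spatial axis (the H_out / W_out formula)."""
    return (size + 2 * pad - dilation * (filt - 1) - 1) // stride + 1


def round_up(x, multiple):
    """Smallest multiple of `multiple` that is >= x."""
    return (x + multiple - 1) // multiple * multiple


def padded_k(R, S, C_in, BK):
    """K extent of a K-padded filter: round_up(R*S*C_in, 2*BK)."""
    return round_up(R * S * C_in, 2 * BK)


def derive_c_in(K_filter, R, S):
    """C_in recovered from the filter's K extent when the C_in field is 0."""
    assert K_filter % (R * S) == 0, "K_filter must be divisible by R*S"
    return K_filter // (R * S)


# 3x3 conv, stride 1, pad 1 preserves a 56-pixel extent.
same_size = conv_out_dim(56, pad=1, dilation=1, filt=3, stride=1)
# 7x7 conv, stride 2, pad 3 halves a 224-pixel extent.
halved = conv_out_dim(224, pad=3, dilation=1, filt=7, stride=2)
# 3x3 filter over 3 channels: K = 27, padded up to one 2*BK block for BK = 64.
k_small = padded_k(3, 3, 3, BK=64)
# Unpadded 3x3x64 filter: K = 576, so C_in can be derived exactly.
c_in = derive_c_in(576, 3, 3)
```

The small-channel case shows why the explicit C_in field matters: once K is padded from 27 to 128, K_filter // (R*S) no longer recovers the real channel count, so the caller has to supply it.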

Implemented traits​

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, Writable

Methods​

__init__​

__init__(out self, *, H: Int = 1, W: Int = 1, H_out: Int = 1, W_out: Int = 1, R: Int = 1, S: Int = 1, stride_h: Int = 1, stride_w: Int = 1, dilation_h: Int = 1, dilation_w: Int = 1, pad_h: Int = 0, pad_w: Int = 0, C_in: Int = 0, use_runtime_hw: Bool = False)

Constructs a Conv2DKernelConfig from the conv geometry.

Args:

  • ​H (Int): Input spatial height.
  • ​W (Int): Input spatial width.
  • ​H_out (Int): Output spatial height.
  • ​W_out (Int): Output spatial width.
  • ​R (Int): Filter spatial height.
  • ​S (Int): Filter spatial width.
  • ​stride_h (Int): Conv vertical stride.
  • ​stride_w (Int): Conv horizontal stride.
  • ​dilation_h (Int): Conv vertical dilation.
  • ​dilation_w (Int): Conv horizontal dilation.
  • ​pad_h (Int): Vertical pad.
  • ​pad_w (Int): Horizontal pad.
  • ​C_in (Int): Real input channel count.
  • ​use_runtime_hw (Bool): When True, treat H/W/H_out/W_out as placeholders to be replaced at kernel-launch time with values read from the input tensor.
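
One way to picture the use_runtime_hw flag is as a resolution step applied to the config at launch time. The sketch below is a plain-Python model of that idea, not the Mojo implementation: `ConvGeometry` and `resolve` are hypothetical names, only the fields needed for the example are mirrored, and recomputing H_out/W_out from the output-size formula is an assumption about how the runtime values are obtained.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ConvGeometry:
    """Hypothetical Python stand-in mirroring a subset of the constructor arguments."""
    H: int = 1
    W: int = 1
    H_out: int = 1
    W_out: int = 1
    R: int = 1
    S: int = 1
    stride_h: int = 1
    stride_w: int = 1
    dilation_h: int = 1
    dilation_w: int = 1
    pad_h: int = 0
    pad_w: int = 0
    use_runtime_hw: bool = False


def out_dim(size, pad, dilation, filt, stride):
    return (size + 2 * pad - dilation * (filt - 1) - 1) // stride + 1


def resolve(cfg, input_shape):
    """Return the geometry in effect for an NHWC input of shape `input_shape`.

    Static path: the config's own H/W/H_out/W_out are used unchanged.
    Runtime path: H and W come from the tensor; H_out/W_out are recomputed
    (an assumption made for this sketch).
    """
    if not cfg.use_runtime_hw:
        return cfg
    _, H, W, _ = input_shape
    return replace(
        cfg,
        H=H,
        W=W,
        H_out=out_dim(H, cfg.pad_h, cfg.dilation_h, cfg.R, cfg.stride_h),
        W_out=out_dim(W, cfg.pad_w, cfg.dilation_w, cfg.S, cfg.stride_w),
    )


# Placeholders (defaults of 1) are replaced by the tensor's spatial extents.
dynamic = ConvGeometry(R=3, S=3, pad_h=1, pad_w=1, use_runtime_hw=True)
resolved = resolve(dynamic, (2, 32, 48, 16))
```

On the static path the same call returns the config unchanged, which matches the description of H and W being used directly when use_runtime_hw=False.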

write_to​

write_to(self, mut writer: T)

Writes a compact conv geometry tag to writer.

Args:

  • ​writer (T): Sink for the rendered tag.

write_repr_to​

write_repr_to(self, mut writer: T)

Writes a debug representation of this conv config.

Args:

  • writer (T): Sink for the rendered debug representation.