Mojo struct
Conv2DKernelConfig
struct Conv2DKernelConfig
Conv-specific geometry for AMD4WaveMatmul's conv2d entry point.
Companion to MatmulKernelConfig: the matmul shape config drives
the 4-wave block/warp/MMA tiling, while this struct carries the
extra conv2d parameters the im2col loader needs to materialize the
A operand from a 4D NHWC input. Pass both configs to the kernel at the call site.
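The following sketch shows what "materialize the A operand from a 4D NHWC input" means for a single element. It is plain Python, not Mojo, and does not call Conv2DKernelConfig. The GEMM row m is assumed to index (n, oh, ow) and the GEMM column k to index (r, s, c) with the channel varying fastest. That ordering, the function name im2col_element, and the nested-list input are all illustrative assumptions, not the kernel's actual layout.

```python
# Hypothetical sketch of the im2col mapping for one element of the implicit
# GEMM A operand. Assumptions (not taken from the kernel): GEMM row m indexes
# (n, oh, ow) row-major, and GEMM column k indexes (r, s, c) row-major with the
# channel c varying fastest. The real loader's ordering may differ.


def im2col_element(x, m, k, *, H, W, H_out, W_out, S, C_in,
                   stride_h=1, stride_w=1, dilation_h=1, dilation_w=1,
                   pad_h=0, pad_w=0):
    """Return A[m, k] for an NHWC input stored as nested lists x[n][h][w][c]."""
    # Split the GEMM row into batch index and output pixel (oh, ow).
    n, pixel = divmod(m, H_out * W_out)
    oh, ow = divmod(pixel, W_out)
    # Split the GEMM column into filter tap (r, s) and input channel c.
    tap, c = divmod(k, C_in)
    r, s = divmod(tap, S)
    # Input coordinate read by this tap; it may land in the padding halo.
    h = oh * stride_h - pad_h + r * dilation_h
    w = ow * stride_w - pad_w + s * dilation_w
    if 0 <= h < H and 0 <= w < W:
        return x[n][h][w][c]
    # Halo positions contribute zero, which is what zero padding means.
    return 0


# 1x2x2x1 input, 3x3 filter, pad 1, stride 1 -> 2x2 output, K = 9.
x = [[[[1], [2]], [[3], [4]]]]
geom = dict(H=2, W=2, H_out=2, W_out=2, S=3, C_in=1, pad_h=1, pad_w=1)
row0 = [im2col_element(x, 0, k, **geom) for k in range(9)]
print(row0)  # -> [0, 0, 0, 0, 1, 2, 0, 3, 4]
```

Row 0 corresponds to the top-left output pixel, so five of its nine taps fall in the halo and read as zero.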
Fields
- H (Int): Input spatial height (NHWC dim 1). Used as H_eff when use_runtime_hw=False; ignored otherwise (the loader reads the runtime H).
- W (Int): Input spatial width (NHWC dim 2). Same semantics as H.
- H_out (Int): Output spatial height = (H + 2*pad_h - dilation_h*(R-1) - 1) // stride_h + 1. Comptime-only on the static-HW path.
- W_out (Int): Output spatial width = (W + 2*pad_w - dilation_w*(S-1) - 1) // stride_w + 1.
- R (Int): Filter spatial height.
- S (Int): Filter spatial width.
- stride_h (Int): Conv vertical stride (>= 1).
- stride_w (Int): Conv horizontal stride (>= 1).
- dilation_h (Int): Conv vertical dilation (>= 1).
- dilation_w (Int): Conv horizontal dilation (>= 1).
- pad_h (Int): Vertical pad (>= 0). When > 0, halo lanes are routed to the SRD out-of-bounds (OOB) sentinel, so their loads return zero (zero padding).
- pad_w (Int): Horizontal pad (>= 0).
- C_in (Int): Real input channel count. When > 0, lets the caller K-pad the filter: allocate the filter as [Cout, K_padded], where K_padded = round_up(R*S*C_in, 2*BK), and zero-fill the trailing K_padded - R*S*C_in entries along K. When 0, the kernel derives C_in = K_filter // (R*S) and asserts exact divisibility.
- use_runtime_hw (Bool): When True, H, W, H_out, and W_out are runtime values read from the input tensor (typically via graph-compiler symbolic resolution). When False, all conv geometry is comptime.
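The output-size and K-padding formulas in the field list can be checked with a few lines of plain Python. This is an illustrative sketch, not kernel code. BK (the matmul K-tile size) is passed in explicitly, and the value 32 below is an arbitrary example rather than a kernel default. The helper names are hypothetical.

```python
# Illustrative helpers for the conv geometry arithmetic in the field list.
# BK (the matmul K-tile size) is supplied by the caller here; the value used
# below is an arbitrary example, not a kernel default.


def conv_out_size(size, pad, dilation, filt, stride):
    """Output extent along one axis, per the H_out / W_out formulas."""
    return (size + 2 * pad - dilation * (filt - 1) - 1) // stride + 1


def k_padded(R, S, C_in, BK):
    """K extent of a K-padded filter: round_up(R*S*C_in, 2*BK)."""
    k, multiple = R * S * C_in, 2 * BK
    return (k + multiple - 1) // multiple * multiple


def k_pad_filter(filt, K_pad):
    """Zero-fill each row of a [Cout, K] filter out to [Cout, K_pad]."""
    return [row + [0] * (K_pad - len(row)) for row in filt]


def derive_c_in(K_filter, R, S):
    """Mirror of the C_in == 0 path: exact division of K_filter by R*S."""
    assert K_filter % (R * S) == 0, "K_filter must be a multiple of R*S"
    return K_filter // (R * S)


# 7x7 stride-2 conv with pad 3 on a 224x224 input.
print(conv_out_size(224, pad=3, dilation=1, filt=7, stride=2))  # 112
# 3x3 filter over 3 input channels: K = 27, padded up to 2*BK = 64.
print(k_padded(3, 3, 3, BK=32))  # 64
```

The small-C_in case is where K-padding matters most: 27 real K entries become 64, and the 37 trailing entries per filter row must be zero so they contribute nothing to the dot product.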
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, Writable
Methods
__init__
__init__(out self, *, H: Int = 1, W: Int = 1, H_out: Int = 1, W_out: Int = 1, R: Int = 1, S: Int = 1, stride_h: Int = 1, stride_w: Int = 1, dilation_h: Int = 1, dilation_w: Int = 1, pad_h: Int = 0, pad_w: Int = 0, C_in: Int = 0, use_runtime_hw: Bool = False)
Constructs a Conv2DKernelConfig from the conv geometry.
Args:
- H (Int): Input spatial height.
- W (Int): Input spatial width.
- H_out (Int): Output spatial height.
- W_out (Int): Output spatial width.
- R (Int): Filter spatial height.
- S (Int): Filter spatial width.
- stride_h (Int): Conv vertical stride.
- stride_w (Int): Conv horizontal stride.
- dilation_h (Int): Conv vertical dilation.
- dilation_w (Int): Conv horizontal dilation.
- pad_h (Int): Vertical pad.
- pad_w (Int): Horizontal pad.
- C_in (Int): Real input channel count.
- use_runtime_hw (Bool): When True, treat H, W, H_out, and W_out as placeholders that are replaced at kernel-launch time with values read from the input tensor.
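The constructor takes H_out and W_out as explicit arguments next to H, W, and the filter geometry, and the signature alone does not show whether it cross-checks them. A caller may therefore want to validate the arguments before building the config. The hypothetical checker below is plain Python, not part of any Mojo API. It mirrors the keyword names and defaults from the __init__ signature and recomputes H_out and W_out from the formulas in the field list.

```python
# Hypothetical argument checker, written in Python for illustration. It mirrors
# the keyword names and defaults of the documented __init__ signature and
# recomputes H_out / W_out from the documented formulas.

DEFAULTS = dict(H=1, W=1, H_out=1, W_out=1, R=1, S=1, stride_h=1, stride_w=1,
                dilation_h=1, dilation_w=1, pad_h=0, pad_w=0, C_in=0,
                use_runtime_hw=False)


def check_conv_geometry(**kwargs):
    """Return a list of problems with a proposed set of constructor arguments."""
    unknown = sorted(set(kwargs) - set(DEFAULTS))
    if unknown:
        raise TypeError(f"unknown arguments: {unknown}")
    g = {**DEFAULTS, **kwargs}
    problems = []
    for name in ("stride_h", "stride_w", "dilation_h", "dilation_w"):
        if g[name] < 1:
            problems.append(f"{name} must be >= 1")
    for name in ("pad_h", "pad_w"):
        if g[name] < 0:
            problems.append(f"{name} must be >= 0")
    # Skip the output-size check if the inputs are already invalid, or if
    # H/W/H_out/W_out are runtime placeholders (use_runtime_hw=True).
    if problems or g["use_runtime_hw"]:
        return problems
    expected = {
        "H_out": (g["H"] + 2 * g["pad_h"] - g["dilation_h"] * (g["R"] - 1) - 1)
        // g["stride_h"] + 1,
        "W_out": (g["W"] + 2 * g["pad_w"] - g["dilation_w"] * (g["S"] - 1) - 1)
        // g["stride_w"] + 1,
    }
    for name, value in expected.items():
        if g[name] != value:
            problems.append(f"{name} should be {value}, got {g[name]}")
    return problems


# A 3x3 "same" conv on 56x56 needs pad 1; forgetting it is caught.
print(check_conv_geometry(H=56, W=56, H_out=56, W_out=56, R=3, S=3))
```

With the documented defaults (a 1x1 filter over a 1x1 input, no padding), the recomputed output size is 1x1, so an all-default call reports no problems.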
write_to
write_to(self, mut writer: T)
Writes a compact conv geometry tag to writer.
Args:
- writer (T): Sink for the rendered tag.
write_repr_to
write_repr_to(self, mut writer: T)
Writes a debug representation of this conv config.
Args:
- writer (T): Sink for the rendered debug representation.