Python class

Conv2d

class max.nn.Conv2d(kernel_size, in_channels, out_channels, dtype, stride=1, padding=0, dilation=1, num_groups=1, device=None, has_bias=False, permute=False, name=None)

Bases: Module, Shardable

A 2D convolution over an input signal composed of several input planes.

When called, Conv2d accepts a TensorValue of shape (batch, height, width, in_channels) and returns a TensorValue of shape (batch, new_height, new_width, out_channels). If permute=True, the input and output follow PyTorch channel-first layout: (batch, in_channels, height, width) and (batch, out_channels, new_height, new_width).

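from max import nn
from max.dtype import DType
from max.graph import DeviceRef

# 3x3 convolution mapping 64 input channels to 128 output channels.
conv = nn.Conv2d(
    kernel_size=3,
    in_channels=64,
    out_channels=128,
    dtype=DType.float32,
    stride=1,
    padding=0,
    has_bias=False,
    name="conv2d_weight",
    device=DeviceRef.GPU(),
)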

Initializes the Conv2d layer with weights and optional bias.

Parameters:

  • kernel_size (int | tuple[int, int]) – Size of the convolving kernel. Can be a single int (square kernel) or tuple (height, width).
  • in_channels (int) – Number of channels in the input image.
  • out_channels (int) – Number of channels produced by the convolution.
  • dtype (DType) – The data type for both weights and bias.
  • stride (int | tuple[int, int]) – Stride of the convolution for the height and width dimensions. Can be an int (applied to both dimensions) or a tuple (stride_h, stride_w). Default: 1
  • padding (int | tuple[int, int] | tuple[int, int, int, int]) – Padding added to the input. Can be an int (applied to all sides), a tuple of 2 ints (pad_h, pad_w), or a tuple of 4 ints (pad_top, pad_bottom, pad_left, pad_right) to support asymmetric padding. Default: 0
  • dilation (int | tuple[int, int]) – Spacing between kernel elements for the height and width dimensions. Can be an int (applied to both dimensions) or a tuple (dilation_h, dilation_w). Default: 1
  • num_groups (int) – Number of blocked connections from input channels to output channels. Input channels and output channels are divided into groups. Default: 1
  • device (DeviceRef | None) – The target device for computation. If None, defaults to CPU. Weights are initially stored on CPU and moved to target device during computation.
  • name (str | None) – Base name for weights. If provided, weights are named {name}.weight and {name}.bias (if bias is enabled). If None, uses "weight" and "bias".
  • has_bias (bool) – If true, adds a learnable bias vector to the layer. Defaults to False.
  • permute (bool) – If true, permutes weights from PyTorch format to MAX format. PyTorch order: (out_channels, in_channels / num_groups, height, width). MAX API order: (height, width, in_channels / num_groups, out_channels). Defaults to False.
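
The new_height and new_width of the output follow from kernel_size, stride, padding, and dilation. The sketch below assumes the standard convolution output-size arithmetic (floor division, no implicit "same" padding); this page does not state the formula, so treat it as an illustration rather than the library's exact rule. The helper name conv2d_output_hw is hypothetical and is not part of max.nn.

```python
def conv2d_output_hw(height, width, kernel_size, stride=(1, 1),
                     padding=(0, 0, 0, 0), dilation=(1, 1)):
    """Estimate (new_height, new_width) for a 2D convolution.

    Assumes the standard formula; padding is given as
    (pad_top, pad_bottom, pad_left, pad_right).
    """
    kernel_h, kernel_w = kernel_size
    stride_h, stride_w = stride
    dilation_h, dilation_w = dilation
    pad_top, pad_bottom, pad_left, pad_right = padding

    # A dilated kernel covers dilation * (k - 1) + 1 input positions.
    span_h = dilation_h * (kernel_h - 1) + 1
    span_w = dilation_w * (kernel_w - 1) + 1

    new_h = (height + pad_top + pad_bottom - span_h) // stride_h + 1
    new_w = (width + pad_left + pad_right - span_w) // stride_w + 1
    return new_h, new_w


# 3x3 kernel, stride 1, no padding on a 32x32 input.
print(conv2d_output_hw(32, 32, (3, 3)))  # (30, 30)
# Stride 2 with one pixel of padding on every side.
print(conv2d_output_hw(32, 32, (3, 3), stride=(2, 2), padding=(1, 1, 1, 1)))  # (16, 16)
```

Under these assumptions, the constructor example above (kernel_size=3, stride=1, padding=0) would shrink each spatial dimension by 2.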

bias​

bias: Weight | None = None

The optional bias vector stored on CPU with shape (out_channels,). Model init moves the bias to device if present.

device​

device: DeviceRef | None

The device where matrix operations are performed.

dilation​

dilation: tuple[int, int]

Spacing between kernel elements for the height and width dimensions, stored as (dilation_h, dilation_w).

filter​

filter: Weight

The weight tensor, stored on CPU with shape (height, width, in_channels / num_groups, out_channels). Model init moves the weight to device.
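
To make the filter shape concrete, here is a minimal sketch that computes it from the constructor arguments. The helper conv2d_filter_shape is hypothetical, and the divisibility check reflects the usual grouped-convolution convention rather than anything stated on this page.

```python
def conv2d_filter_shape(kernel_size, in_channels, out_channels, num_groups=1):
    """Return the MAX-order filter shape:
    (height, width, in_channels / num_groups, out_channels)."""
    # Assumption: both channel counts must divide evenly into groups.
    if in_channels % num_groups != 0 or out_channels % num_groups != 0:
        raise ValueError("in_channels and out_channels must be divisible by num_groups")

    # An int kernel_size means a square kernel.
    if isinstance(kernel_size, int):
        kernel_h = kernel_w = kernel_size
    else:
        kernel_h, kernel_w = kernel_size

    return (kernel_h, kernel_w, in_channels // num_groups, out_channels)


print(conv2d_filter_shape(3, 64, 128))                # (3, 3, 64, 128)
print(conv2d_filter_shape(3, 64, 128, num_groups=4))  # (3, 3, 16, 128)
```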

num_groups​

num_groups: int

Number of blocked connections from input channels to output channels.

padding​

padding: tuple[int, int, int, int]

Controls the amount of padding applied before and after the input for height and width dimensions.

Format: (pad_top, pad_bottom, pad_left, pad_right).
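
The constructor accepts padding as an int, a 2-tuple, or a 4-tuple, while this attribute always holds four values. The sketch below shows one plausible normalization. The function name normalize_padding is hypothetical, and expanding (pad_h, pad_w) into (pad_h, pad_h, pad_w, pad_w) is an assumption based on the parameter description, not a confirmed implementation detail.

```python
def normalize_padding(padding):
    """Expand padding to (pad_top, pad_bottom, pad_left, pad_right)."""
    if isinstance(padding, int):
        # A single int applies to all four sides.
        return (padding, padding, padding, padding)
    if len(padding) == 2:
        # Assumed: pad_h covers top and bottom, pad_w covers left and right.
        pad_h, pad_w = padding
        return (pad_h, pad_h, pad_w, pad_w)
    if len(padding) == 4:
        # Already in (pad_top, pad_bottom, pad_left, pad_right) form.
        return tuple(padding)
    raise ValueError("padding must be an int, a 2-tuple, or a 4-tuple")


print(normalize_padding(1))             # (1, 1, 1, 1)
print(normalize_padding((1, 2)))        # (1, 1, 2, 2)
print(normalize_padding((0, 1, 2, 3)))  # (0, 1, 2, 3)
```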

permute​

permute: bool = False

Controls whether self.filter is permuted from PyTorch order to MAX order. PyTorch order: (out_channels, in_channels / num_groups, height, width). MAX API order: (height, width, in_channels / num_groups, out_channels).
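
The two filter layouts differ only in axis order, so the mapping can be written as a single axis permutation. This is a pure-Python sketch on shape tuples; the names PYTORCH_TO_MAX_AXES and permute_shape are hypothetical, and the code illustrates the layout relationship rather than how Conv2d performs the permutation internally.

```python
# Source axis for each MAX axis, indexing into the PyTorch order
# (out_channels, in_channels / num_groups, height, width).
PYTORCH_TO_MAX_AXES = (2, 3, 1, 0)


def permute_shape(shape, axes):
    """Reorder a shape tuple the way an axis transpose would."""
    return tuple(shape[axis] for axis in axes)


# 128 output channels, 16 input channels per group, 3x3 kernel.
pytorch_shape = (128, 16, 3, 3)
max_shape = permute_shape(pytorch_shape, PYTORCH_TO_MAX_AXES)
print(max_shape)  # (3, 3, 16, 128)
```

With NumPy arrays, the same axis tuple can be passed to numpy.transpose to reorder the weight data itself.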

shard()​

shard(devices)

Creates sharded views of this Conv2d layer across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of devices to place the shards on.

Returns:

List of sharded Conv2d instances, one for each device.

Return type:

list[Conv2d]

sharding_strategy​

property sharding_strategy: ShardingStrategy | None

Get the Conv2d sharding strategy.

stride​

stride: tuple[int, int]

Controls the stride for the cross-correlation, stored as (stride_h, stride_w).