Conv2d

class max.nn.Conv2d(kernel_size, in_channels, out_channels, dtype, stride=1, padding=0, dilation=1, num_groups=1, device=None, has_bias=False, permute=False, name=None)

Bases: Module, Shardable

A 2D convolution over an input signal composed of several input planes.

When called, Conv2d accepts a TensorValue of shape (batch, height, width, in_channels) and returns a TensorValue of shape (batch, new_height, new_width, out_channels). If permute=True, the input and output follow PyTorch channel-first layout: (batch, in_channels, height, width) and (batch, out_channels, new_height, new_width).

from max import nn
from max.dtype import DType
from max.graph import DeviceRef

conv = nn.Conv2d(
    kernel_size=3,
    in_channels=64,
    out_channels=128,
    dtype=DType.float32,
    stride=1,
    padding=0,
    has_bias=False,
    name="conv2d_weight",
    device=DeviceRef.GPU(),
)

Initializes the Conv2d layer with weights and optional bias.

Parameters:

  • kernel_size (int | tuple[int, int]) – Size of the convolving kernel. Can be a single int (square kernel) or tuple (height, width).
  • in_channels (int) – Number of channels in the input image.
  • out_channels (int) – Number of channels produced by the convolution.
  • dtype (DType) – The data type for both weights and bias.
  • stride (int | tuple[int, int]) – Stride of the convolution for the height and width dimensions. Can be an int (applied to both dimensions) or a tuple (stride_h, stride_w). Default: 1
  • padding (int | tuple[int, int] | tuple[int, int, int, int]) – Padding added to the input. Can be an int (applied to all sides), a tuple of 2 ints (pad_h, pad_w), or a tuple of 4 ints (pad_top, pad_bottom, pad_left, pad_right) to support asymmetric padding. Default: 0
  • dilation (int | tuple[int, int]) – Spacing between kernel elements for the height and width dimensions. Can be an int (applied to both dimensions) or a tuple (dilation_h, dilation_w). Default: 1
  • num_groups (int) – Number of blocked connections from input channels to output channels. Input channels and output channels are divided into groups. Default: 1
  • device (DeviceRef | None) – The target device for computation. If None, defaults to CPU. Weights are initially stored on CPU and moved to target device during computation.
  • name (str | None) – Base name for weights. If provided, weights are named {name}.weight and {name}.bias (if bias is enabled). If None, uses “weight” and “bias”.
  • has_bias (bool) – If true, adds a learnable bias vector to the layer. Defaults to False.
  • permute (bool) – If true, permutes weights from PyTorch format to MAX format. PyTorch order: (out_channels, in_channels / num_groups, height, width). MAX API order: (height, width, in_channels / num_groups, out_channels). Defaults to False.
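
The output spatial dimensions follow standard convolution arithmetic. A minimal pure-Python sketch (independent of the MAX API; the helper name is hypothetical) shows how kernel_size, stride, padding, and dilation combine:

```python
def conv2d_out_size(size, kernel, stride=1, pad_before=0, pad_after=0, dilation=1):
    """Output length along one spatial dimension (standard conv arithmetic)."""
    # A dilated kernel covers dilation * (kernel - 1) + 1 input positions.
    effective_kernel = dilation * (kernel - 1) + 1
    return (size + pad_before + pad_after - effective_kernel) // stride + 1

# 3x3 kernel, stride 1, no padding on a 64-wide input -> 62.
print(conv2d_out_size(64, 3))
# Same kernel, stride 2, one pixel of padding on each side -> 32.
print(conv2d_out_size(64, 3, stride=2, pad_before=1, pad_after=1))
```

Apply the helper once per spatial dimension to get (new_height, new_width).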

bias

bias: Weight | None = None

The optional bias vector, stored on CPU with shape (out_channels,). Model init moves the bias to the target device if present.

device

device: DeviceRef | None

The device where matrix operations are performed.

dilation

dilation: tuple[int, int]

Controls the dilation rate.

filter

filter: Weight

The weight matrix, stored on CPU with shape (height, width, in_channels / num_groups, out_channels). Model init moves the weight to the target device.

num_groups

num_groups: int

Number of blocked connections from input channels to output channels.
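
With num_groups = g, each group convolves in_channels / g input channels, so both channel counts must be divisible by g and the filter's third dimension shrinks accordingly. A small illustrative sketch (the helper name is hypothetical, not part of the MAX API):

```python
def grouped_filter_shape(kernel_h, kernel_w, in_channels, out_channels, num_groups=1):
    """Filter shape in the MAX layout: (height, width, in_channels / num_groups, out_channels)."""
    assert in_channels % num_groups == 0 and out_channels % num_groups == 0, \
        "channel counts must be divisible by num_groups"
    return (kernel_h, kernel_w, in_channels // num_groups, out_channels)

print(grouped_filter_shape(3, 3, 64, 128))                # (3, 3, 64, 128)
print(grouped_filter_shape(3, 3, 64, 128, num_groups=4))  # (3, 3, 16, 128)
```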

padding

padding: tuple[int, int, int, int]

Controls the amount of padding applied before and after the input for height and width dimensions.

Format: (pad_top, pad_bottom, pad_left, pad_right).
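
The int, 2-tuple, and 4-tuple forms accepted by the constructor all normalize to this 4-tuple. A hypothetical helper sketching the documented expansion rules:

```python
def normalize_padding(padding):
    """Expand an int, (pad_h, pad_w), or 4-tuple into
    (pad_top, pad_bottom, pad_left, pad_right)."""
    if isinstance(padding, int):
        # One value applied to all four sides.
        return (padding, padding, padding, padding)
    if len(padding) == 2:
        # Symmetric per-dimension padding.
        pad_h, pad_w = padding
        return (pad_h, pad_h, pad_w, pad_w)
    if len(padding) == 4:
        # Already fully specified (possibly asymmetric).
        return tuple(padding)
    raise ValueError("padding must be an int, a 2-tuple, or a 4-tuple")

print(normalize_padding(1))       # (1, 1, 1, 1)
print(normalize_padding((1, 2)))  # (1, 1, 2, 2)
```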

permute

permute: bool = False

Whether self.filter is permuted from PyTorch order to MAX order. PyTorch order: (out_channels, in_channels / num_groups, height, width). MAX API order: (height, width, in_channels / num_groups, out_channels).
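
The layout change described above is a transpose with axis order (2, 3, 1, 0). A NumPy sketch of the equivalent operation (illustrative only; Conv2d performs this internally when permute=True):

```python
import numpy as np

# PyTorch layout: (out_channels, in_channels / num_groups, height, width)
torch_weight = np.zeros((128, 64, 3, 3), dtype=np.float32)

# MAX layout: (height, width, in_channels / num_groups, out_channels)
max_weight = torch_weight.transpose(2, 3, 1, 0)

print(max_weight.shape)  # (3, 3, 64, 128)
```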

shard()

shard(devices)

Creates sharded views of this Conv2d layer across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of devices to place the shards on.

Returns:

List of sharded Conv2d instances, one for each device.

Return type:

list[Conv2d]

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the Conv2d sharding strategy.

stride

stride: tuple[int, int]

Controls the stride for the cross-correlation.