Python class

ColumnParallelLinear

class max.nn.ColumnParallelLinear(in_dim, out_dim, dtype, devices, tied_weight=None, **kwargs)

Bases: Linear

A Linear layer where the weight and bias are sharded onto multiple devices.

This layer first computes $y_i = xW_i^T + b_i$ for each device i in [0, …, num_devices - 1]:

+-----+       +-----+ T     +-----+       +-----+
|     |       | W_0 |       | b_0 |       | y_0 | GPU0
|     |       +-----+       +-----+       +-----+
|     |       | W_1 |       | b_1 |       | y_1 | GPU1
|  x  |   @   +-----+   +   +-----+   =   +-----+
|     |       | W_2 |       | b_2 |       | y_2 | GPU2
|     |       +-----+       +-----+       +-----+
|     |       | W_3 |       | b_3 |       | y_3 | GPU3
+-----+       +-----+       +-----+       +-----+

The values are then collected using an Allgather op, producing the same output tensor $y = xW^T + b$ on each device:

GPU0  GPU1  GPU2  GPU3                      GPU0  GPU1  GPU2  GPU3
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
| y_0 |  -  |  -  |  -  |                   | y_0 | y_0 | y_0 | y_0 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
|  -  | y_1 |  -  |  -  |                   | y_1 | y_1 | y_1 | y_1 |
+-----+-----+-----+-----+  -- Allgather --> +-----+-----+-----+-----+
|  -  |  -  | y_2 |  -  |                   | y_2 | y_2 | y_2 | y_2 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
|  -  |  -  |  -  | y_3 |                   | y_3 | y_3 | y_3 | y_3 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
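The two steps above can be checked numerically on the host. The following is a minimal NumPy sketch, not the MAX implementation: the helper name `column_parallel_linear` is hypothetical, the weight is assumed to be stored as `(out_dim, in_dim)` (consistent with $xW^T$), and shards are assumed to be split along the output dimension. The Allgather is simulated by concatenating the per-shard outputs and handing a copy to each simulated device.

```python
import numpy as np


def column_parallel_linear(x, weight, bias, num_devices):
    """Simulate a column-parallel linear layer on the host (illustrative only)."""
    # Shard the weight (out_dim, in_dim) and bias (out_dim,) along the output dimension.
    weight_shards = np.array_split(weight, num_devices, axis=0)
    bias_shards = np.array_split(bias, num_devices, axis=0)

    # Step 1: each simulated device i computes its slice y_i = x @ W_i^T + b_i.
    partial_outputs = [
        x @ w_i.T + b_i for w_i, b_i in zip(weight_shards, bias_shards)
    ]

    # Step 2: simulated Allgather. Every device receives all slices,
    # concatenated along the output (last) dimension.
    gathered = np.concatenate(partial_outputs, axis=-1)
    return [gathered.copy() for _ in range(num_devices)]


rng = np.random.default_rng(0)
batch, in_dim, out_dim, num_devices = 2, 8, 12, 4  # assumed example sizes

x = rng.standard_normal((batch, in_dim))
weight = rng.standard_normal((out_dim, in_dim))
bias = rng.standard_normal(out_dim)

outputs = column_parallel_linear(x, weight, bias, num_devices)

# Unsharded reference: y = x @ W^T + b.
reference = x @ weight.T + bias
matches = all(np.allclose(out, reference) for out in outputs)
```

Each entry of `outputs` stands for the tensor held by one device after the gather. Comparing against `reference` confirms that sharding along the output dimension and then gathering reproduces the unsharded result.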

Example usage:

from max.dtype import DType
from max.graph import DeviceRef
from max.nn import ColumnParallelLinear

# Example sizes, assumed for illustration.
in_dim = 1024
out_dim = 4096

num_devices = 4
distributed_linear = ColumnParallelLinear(
    in_dim,
    out_dim,
    DType.float32,
    devices=[DeviceRef.GPU(i) for i in range(num_devices)],
)

Initializes the column-parallel linear layer.

Parameters:

  • in_dim (int) – The dimensionality of the input space.
  • out_dim (int) – The dimensionality of the output space.
  • dtype (DType) – The DType for both weights and bias.
  • devices (Sequence[DeviceRef]) – The target DeviceRef devices for computation. Weights remain on CPU until sharded and moved to device during computation.
  • tied_weight (Weight | None) – Optional Weight to tie with this layer.
  • **kwargs – Additional keyword arguments passed to the Linear initializer.