ColumnParallelLinear
class max.nn.ColumnParallelLinear(in_dim, out_dim, dtype, devices, tied_weight=None, **kwargs)
Bases: Linear
A Linear layer where the weight and bias are sharded onto multiple devices.
This layer first computes, for each device i in [0, …, num_devices - 1]:
+-----+   +-----+T    +-----+   +-----+
|     |   | W_0 |     | b_0 |   | y_0 | GPU0
|     |   +-----+     +-----+   +-----+
|     |   | W_1 |     | b_1 |   | y_1 | GPU1
|  x  | @ +-----+  +  +-----+ = +-----+
|     |   | W_2 |     | b_2 |   | y_2 | GPU2
|     |   +-----+     +-----+   +-----+
|     |   | W_3 |     | b_3 |   | y_3 | GPU3
+-----+   +-----+     +-----+   +-----+

The values are then collected using an Allgather op, producing the same output tensor on each device:
 GPU0  GPU1  GPU2  GPU3                      GPU0  GPU1  GPU2  GPU3
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
| y_0 |  -  |  -  |  -  |                   | y_0 | y_0 | y_0 | y_0 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
|  -  | y_1 |  -  |  -  |                   | y_1 | y_1 | y_1 | y_1 |
+-----+-----+-----+-----+  -- Allgather --> +-----+-----+-----+-----+
|  -  |  -  | y_2 |  -  |                   | y_2 | y_2 | y_2 | y_2 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
|  -  |  -  |  -  | y_3 |                   | y_3 | y_3 | y_3 | y_3 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+

Example usage:
from max.dtype import DType
from max.graph import DeviceRef
from max.nn import ColumnParallelLinear

num_devices = 4
in_dim = 256    # example dimensions
out_dim = 1024

distributed_linear = ColumnParallelLinear(
    in_dim,
    out_dim,
    DType.float32,
    devices=[DeviceRef.GPU(i) for i in range(num_devices)],
)

Initializes the column-parallel linear layer.
Parameters:

- in_dim (int) – The dimensionality of the input space.
- out_dim (int) – The dimensionality of the output space.
- dtype (DType) – The DType for both weights and bias.
- devices (Sequence[DeviceRef]) – The target DeviceRef devices for computation. Weights remain on CPU until sharded and moved to device during computation.
- tied_weight (Weight | None) – Optional Weight to tie with this layer.
- **kwargs – Additional keyword arguments passed to the Linear initializer.
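The sharded computation and Allgather shown in the diagrams above can be sketched in plain NumPy, with no GPUs or MAX runtime involved. This is an illustrative model of the math only (the shard count, shapes, and variable names here are assumptions for the example, not part of the API): each "device" holds one row-slice W_i of the weight and b_i of the bias, computes its output slice y_i, and the slices are then concatenated so every device ends up with the full output.

```python
import numpy as np

# Illustrative sketch: simulate column-parallel linear over 4 "devices".
num_devices = 4
in_dim, out_dim = 8, 16
rng = np.random.default_rng(0)

x = rng.standard_normal((2, in_dim))        # a batch of 2 input rows
W = rng.standard_normal((out_dim, in_dim))  # full weight, shape (out_dim, in_dim)
b = rng.standard_normal(out_dim)            # full bias

# Shard the weight and bias row-wise: device i holds W_i and b_i.
W_shards = np.split(W, num_devices, axis=0)
b_shards = np.split(b, num_devices)

# Each device computes its slice of the output: y_i = x @ W_i.T + b_i.
y_shards = [x @ W_i.T + b_i for W_i, b_i in zip(W_shards, b_shards)]

# "Allgather": concatenate the slices so every device holds the full output.
y = np.concatenate(y_shards, axis=1)

# The result matches the unsharded computation.
assert np.allclose(y, x @ W.T + b)
```

Each per-device slice has shape (batch, out_dim // num_devices), which is why out_dim must be divisible by the number of devices for an even shard.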