Python class

Shardable

class max.nn.Shardable(*args, **kwargs)

Bases: Protocol

Protocol for objects that support sharding across multiple devices.

This protocol defines the interface that shardable components, such as Linear layers and Weight objects, must implement to participate in distributed computation.
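Because Shardable derives from typing.Protocol, conformance is structural: a class that exposes a compatible shard() method and sharding_strategy property satisfies the protocol without inheriting from it. The sketch below shows this with plain Python only. ShardableLike, FakeDevice, and ToyWeight are hypothetical stand-ins defined for illustration (FakeDevice replaces DeviceRef, and a plain string replaces ShardingStrategy); they are not part of max.nn. The local protocol is decorated with runtime_checkable so that isinstance() works on it; whether max.nn.Shardable itself supports runtime isinstance checks is not stated here.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, Optional, Protocol, Sequence, TypeVar, runtime_checkable

T = TypeVar("T", bound="ShardableLike")


@dataclass(frozen=True)
class FakeDevice:
    """Hypothetical stand-in for DeviceRef."""
    device_id: int


@runtime_checkable
class ShardableLike(Protocol):
    """Local mirror of the members documented for max.nn.Shardable."""

    @property
    def sharding_strategy(self) -> Optional[str]: ...

    def shard(self: T, devices: Iterable[FakeDevice]) -> Sequence[T]: ...


class ToyWeight:
    """Does not inherit from ShardableLike; it only matches its shape."""

    def __init__(self, values: list[float], strategy: Optional[str] = None) -> None:
        self.values = values
        self._strategy = strategy

    @property
    def sharding_strategy(self) -> Optional[str]:
        return self._strategy

    def shard(self, devices: Iterable[FakeDevice]) -> Sequence[ToyWeight]:
        n = len(list(devices))
        # Hypothetical strided split: element i goes to device i % n.
        return [ToyWeight(self.values[i::n], self._strategy) for i in range(n)]


w = ToyWeight([1.0, 2.0, 3.0, 4.0], strategy="rowwise")
print(isinstance(w, ShardableLike))  # True
print([s.values for s in w.shard([FakeDevice(0), FakeDevice(1)])])  # [[1.0, 3.0], [2.0, 4.0]]
```

Note that a runtime isinstance() check against a protocol only verifies that the members exist; it does not compare signatures. A static type checker is what catches a shard() with the wrong parameter or return types.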

shard()

shard(devices)

Creates sharded views of this object across the given devices.

Parameters:

  • devices (Iterable[DeviceRef]) – The devices across which this object should be sharded.

Returns:

A sequence of sharded instances of this object.

Return type:

Sequence[Self]
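To make the call shape concrete (an iterable of devices in, a sequence of instances of the same type out), the sketch below implements shard() for a toy object that splits its rows into contiguous blocks, one block per device, with any remainder going to the first shards. RowShardedMatrix and FakeDevice are hypothetical stand-ins, and the even row split is an assumed layout chosen for illustration; how a real MAX component divides its data depends on its sharding strategy and is not described here.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, Sequence


@dataclass(frozen=True)
class FakeDevice:
    """Hypothetical stand-in for DeviceRef."""
    device_id: int


@dataclass
class RowShardedMatrix:
    """Hypothetical Shardable-like object holding a list of rows."""
    rows: list[list[float]]
    device: FakeDevice | None = None

    def shard(self, devices: Iterable[FakeDevice]) -> Sequence[RowShardedMatrix]:
        devs = list(devices)
        if not devs:
            raise ValueError("shard() needs at least one device")
        base, extra = divmod(len(self.rows), len(devs))
        shards = []
        start = 0
        for i, dev in enumerate(devs):
            # The first `extra` shards each take one additional row.
            size = base + (1 if i < extra else 0)
            shards.append(RowShardedMatrix(self.rows[start:start + size], dev))
            start += size
        return shards


m = RowShardedMatrix([[float(i)] for i in range(5)])
shards = m.shard([FakeDevice(0), FakeDevice(1)])
print([len(s.rows) for s in shards])  # [3, 2]
```

Each returned shard is the same type as the original and records the device it belongs to, which mirrors the documented Sequence[Self] return type.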

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Gets the weight sharding strategy.
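Since the property is typed ShardingStrategy | None, callers have to handle the None case. The sketch below shows one way a Shardable-like object might consult its own sharding_strategy inside shard(): it raises when no strategy is set and otherwise picks a layout based on the strategy. ToyStrategy and ToyWeight are hypothetical stand-ins, plain integers stand in for DeviceRef, and both the error-on-None guard and the two layouts are assumptions made for illustration, not documented MAX behavior.

```python
from __future__ import annotations

from enum import Enum
from typing import Iterable, Optional, Sequence


class ToyStrategy(Enum):
    """Hypothetical stand-in for ShardingStrategy."""
    ROWWISE = "rowwise"
    REPLICATE = "replicate"


class ToyWeight:
    def __init__(self, values: list[float], strategy: Optional[ToyStrategy] = None) -> None:
        self.values = values
        self._strategy = strategy

    @property
    def sharding_strategy(self) -> Optional[ToyStrategy]:
        """Read-only view of the strategy; None means no strategy was chosen."""
        return self._strategy

    def shard(self, devices: Iterable[int]) -> Sequence[ToyWeight]:
        n = len(list(devices))
        strategy = self.sharding_strategy
        if strategy is None:
            # Assumed behavior for this sketch: refuse to guess a layout.
            raise ValueError("sharding_strategy is not set")
        if strategy is ToyStrategy.REPLICATE:
            # Every device receives a full copy of the values.
            return [ToyWeight(list(self.values), strategy) for _ in range(n)]
        # ROWWISE: strided split, element i goes to device i % n.
        return [ToyWeight(self.values[i::n], strategy) for i in range(n)]


w = ToyWeight([1.0, 2.0, 3.0], ToyStrategy.REPLICATE)
print([s.values for s in w.shard([0, 1])])  # [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
```

Keeping the strategy on the object lets callers invoke shard() uniformly on every component while each component decides its own layout.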