DistributedTransformerBlock

DistributedTransformerBlock

class max.nn.DistributedTransformerBlock(attention, mlp, attention_norm, mlp_norm, devices)

Bases: Module

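A transformer block that stacks attention, feed-forward (MLP), and RMSNorm layers, with the layers distributed across the devices listed in the devices parameter.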
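The following single-device sketch illustrates how such a block typically composes its four sublayers. It assumes a pre-norm residual layout (normalize, apply the sublayer, add the result back to the input), as used in Llama-style models; this page does not say where MAX applies the norms. The functions mean_attention and double_mlp are hypothetical toy stand-ins, and none of this is the MAX API.

```python
import math
from typing import Callable, List

Vector = List[float]
Sequence = List[Vector]  # one hidden-state vector per token


def rms_norm(x: Vector, weight: Vector, eps: float = 1e-6) -> Vector:
    """Divide x by its root-mean-square, then scale elementwise by a learned weight."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for v, w in zip(x, weight)]


def add(a: Sequence, b: Sequence) -> Sequence:
    """Elementwise residual add over every token vector."""
    return [[u + v for u, v in zip(ta, tb)] for ta, tb in zip(a, b)]


def transformer_block(
    x: Sequence,
    attention: Callable[[Sequence], Sequence],
    mlp: Callable[[Vector], Vector],
    attention_norm: Callable[[Vector], Vector],
    mlp_norm: Callable[[Vector], Vector],
) -> Sequence:
    # Attention sublayer (pre-norm layout assumed): normalize each token,
    # mix information across tokens, then add back onto the residual stream.
    h = add(x, attention([attention_norm(t) for t in x]))
    # Feed-forward sublayer: normalize and transform each token independently,
    # then add back onto the residual stream.
    return add(h, [mlp(mlp_norm(t)) for t in h])


# Hypothetical toy stand-ins, for illustration only.
def mean_attention(seq: Sequence) -> Sequence:
    """Attention with uniform scores: every token receives the mean of all tokens."""
    mean = [sum(col) / len(seq) for col in zip(*seq)]
    return [list(mean) for _ in seq]


def double_mlp(v: Vector) -> Vector:
    return [2.0 * e for e in v]


if __name__ == "__main__":
    unit = [1.0, 1.0]
    norm = lambda v: rms_norm(v, unit)
    print(transformer_block([[3.0, 4.0], [0.0, 1.0]], mean_attention, double_mlp, norm, norm))
```

Because each sublayer's output is added to its input, a sublayer that outputs zeros leaves the hidden states unchanged, which is what makes deep stacks of these blocks trainable.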

Parameters:

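  • attention (Module): the attention layer.
  • mlp (ShardableCallable): the feed-forward (MLP) layer.
  • attention_norm (ShardableCallable): the normalization layer for the attention sublayer.
  • mlp_norm (ShardableCallable): the normalization layer for the feed-forward sublayer.
  • devices (list[DeviceRef]): the devices the block is distributed across.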
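The ShardableCallable type on mlp and the norms suggests their weights can be split across the listed devices. The sketch below shows one common scheme, Megatron-style tensor parallelism for a two-layer MLP, purely as an assumption: this page does not state how MAX shards these layers. The helpers shard_mlp and sharded_mlp are hypothetical, "devices" are simulated as list entries, and ReLU stands in for whatever activation the real MLP uses.

```python
from typing import List, Tuple

Matrix = List[List[float]]  # row-major
Vector = List[float]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Compute w @ x for a row-major matrix w."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def relu(v: Vector) -> Vector:
    return [max(0.0, e) for e in v]


def mlp(x: Vector, w_up: Matrix, w_down: Matrix) -> Vector:
    """Unsharded two-layer MLP: w_down @ relu(w_up @ x)."""
    return matvec(w_down, relu(matvec(w_up, x)))


def shard_mlp(w_up: Matrix, w_down: Matrix, num_devices: int) -> List[Tuple[Matrix, Matrix]]:
    """Split the hidden dimension evenly across devices (hypothetical helper).

    w_up is split by rows (each device owns a slice of hidden units) and
    w_down by the matching columns, so each device runs independently.
    """
    hidden = len(w_up)
    if hidden % num_devices != 0:
        raise ValueError("hidden size must divide evenly across devices")
    step = hidden // num_devices
    shards = []
    for d in range(num_devices):
        lo, hi = d * step, (d + 1) * step
        shards.append((w_up[lo:hi], [row[lo:hi] for row in w_down]))
    return shards


def sharded_mlp(x: Vector, shards: List[Tuple[Matrix, Matrix]]) -> Vector:
    # Each simulated device computes a partial output from its hidden slice.
    partials = [mlp(x, up, down) for up, down in shards]
    # An all-reduce (here a plain sum) combines the partial outputs.
    return [sum(vals) for vals in zip(*partials)]
```

The split works because the activation is applied elementwise to each hidden unit, so summing the per-device partial outputs reproduces the unsharded result exactly; the only cross-device communication needed is the final sum.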