Python module

distributed_transformer

DistributedLogitsPostprocessMixin

class max.nn.transformer.distributed_transformer.DistributedLogitsPostprocessMixin

Mixin providing logits postprocessing for multi-device sharded models.

Requires: self.norm_shards, self.lm_head, self.return_logits, self.devices. Optional: self.return_hidden_states, self.logits_scaling.

devices

devices: list[DeviceRef]

lm_head

lm_head: Callable[[list[TensorValue], Sequence[BufferValue]], Sequence[TensorValue]]

logits_scaling

logits_scaling: float = 1.0

norm_shards

norm_shards: Sequence[Callable[[TensorValue], TensorValue]]

return_hidden_states

return_hidden_states: ReturnHiddenStates = 'none'

return_logits

return_logits: ReturnLogits
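
The mixin contract above can be illustrated with a minimal plain-Python sketch (illustrative only, not the MAX implementation; the class names, the `postprocess` method, and the toy stand-in ops are hypothetical). The point is that the mixin's logic reads attributes the host class is expected to define, with `logits_scaling` falling back to a class-level default:

```python
class LogitsPostprocessMixin:
    """Expects the host class to define `norm_shards`, `lm_head`,
    `return_logits`, and `devices`; `logits_scaling` is optional."""

    logits_scaling: float = 1.0  # optional attribute with a default

    def postprocess(self, hidden_per_device):
        # Apply each device's norm shard to that device's hidden state,
        # project to logits via lm_head, then apply optional scaling.
        normed = [norm(h) for norm, h in zip(self.norm_shards, hidden_per_device)]
        logits = self.lm_head(normed)
        return [x * self.logits_scaling for x in logits]


class ToyShardedModel(LogitsPostprocessMixin):
    def __init__(self):
        self.devices = ["gpu:0", "gpu:1"]
        # Identity "norms" and a summing "lm_head" stand in for real ops.
        self.norm_shards = [lambda h: h, lambda h: h]
        self.lm_head = lambda hs: [sum(hs)] * len(hs)
        self.return_logits = "last_token"
        self.logits_scaling = 0.5
```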

DistributedTransformer

class max.nn.transformer.distributed_transformer.DistributedTransformer(dim, n_heads, layers, norm, output, embedding, kv_params, devices, rope, return_logits=ReturnLogits.LAST_TOKEN, use_subgraphs=False, subgraph_layer_groups=None, logits_scaling=1.0)

Transformer model consisting of TransformerBlock layers.

Parameters:

DistributedTransformerBlock

class max.nn.transformer.distributed_transformer.DistributedTransformerBlock(attention, mlp, attention_norm, mlp_norm, devices)

Stack of Attention, FeedForward, and RMSNorm layers.

Parameters:
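
The attention/MLP/norm wiring of such a block can be sketched in plain Python on simple lists. This is a simplified model under an assumption: that the block uses Llama-style pre-norm residual wiring (`x + attn(norm(x))`, then `x + mlp(norm(x))`); the real MAX block operates on per-device `TensorValue`s:

```python
import math

def rms_norm(xs, eps=1e-6):
    # RMSNorm: divide by the root-mean-square of the vector's entries.
    rms = math.sqrt(sum(x * x for x in xs) / len(xs) + eps)
    return [x / rms for x in xs]

def block_forward(x, attention, mlp):
    # Assumed pre-norm residual ordering:
    #   h = x + attention(norm(x));  out = h + mlp(norm(h))
    h = [a + b for a, b in zip(x, attention(rms_norm(x)))]
    return [a + b for a, b in zip(h, mlp(rms_norm(h)))]
```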

ShardableCallable

class max.nn.transformer.distributed_transformer.ShardableCallable(*args, **kwargs)

distributed_logits_postprocess()

max.nn.transformer.distributed_transformer.distributed_logits_postprocess(h, input_row_offsets, return_n_logits, norm_shards, lm_head, signal_buffers, return_logits, device, return_hidden_states=ReturnHiddenStates.NONE, logits_scaling=1.0)

Common logits postprocessing for multi-device sharded models.

Handles last-token gathering, logits computation (VARIABLE/ALL/LAST_TOKEN), logits scaling, and hidden states return for models that use per-device sharded hidden states.

Parameters:

Returns:

Tuple of (last_logits, [logits, offsets], [hidden_states]).

Return type:

tuple[TensorValue, …]
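
The last-token gathering step can be shown with a toy version on plain lists (a sketch of the concept only; the real function works on per-device `TensorValue`s and signal buffers). Assuming the usual ragged-batch convention, where sequences are concatenated row-wise and `input_row_offsets` holds each sequence's start plus a final offset equal to the total row count, sequence `i`'s last token sits at row `offsets[i + 1] - 1`:

```python
def gather_last_tokens(h, input_row_offsets):
    # h: concatenated token rows of a ragged batch.
    # input_row_offsets: sequence start offsets, ending with len(h),
    # so each sequence's last row is the entry before the next offset.
    return [h[end - 1] for end in input_row_offsets[1:]]
```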

forward_sharded_layers()

max.nn.transformer.distributed_transformer.forward_sharded_layers(layers, xs)

Forward pass through sharded layers.

Parameters:

Returns:

List of output tensors from each layer

Raises:

AssertionError – If the number of layers does not match the number of input tensors

Return type:

list[TensorValue]
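
The behavior described above — one layer replica per device, each applied to its device's tensor, with a length check — can be sketched in plain Python (a simplified stand-in; the real function maps sharded layers over `TensorValue`s):

```python
def forward_sharded_layers(layers, xs):
    # One sharded layer per device: pair each layer with its input
    # tensor and apply them element-wise.
    assert len(layers) == len(xs), (
        f"got {len(layers)} layers but {len(xs)} inputs"
    )
    return [layer(x) for layer, x in zip(layers, xs)]
```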

take()

max.nn.transformer.distributed_transformer.take(it, n)

Return the next n items from it as a list.

Parameters:

Return type:

list[Value[Any]]
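
A helper with this signature is a thin wrapper over `itertools.islice`; a minimal sketch of the same behavior on a generic iterator (illustrative, not the MAX source) is:

```python
from itertools import islice

def take(it, n):
    # Consume and return the next n items from iterator `it` as a list;
    # returns fewer items if the iterator is exhausted first.
    return list(islice(it, n))
```

Because the iterator is consumed in place, successive calls continue where the previous one stopped.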
