
Python module

distributed_transformer

DistributedTransformer

class max.nn.transformer.distributed_transformer.DistributedTransformer(dim, n_heads, layers, norm, output, embedding, kv_params, kv_collection_constructor, devices, return_logits=ReturnLogits.LAST_TOKEN)

Transformer model consisting of TransformerBlock layers.
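The following is a minimal single-device NumPy sketch of how such a model is commonly composed. It assumes the embedding lookup, the stack of blocks, the final norm, and the output projection run in that order. The `embedding`, `layers`, `norm`, and `output` parameters suggest this order, but it is not a description of MAX's implementation. The hypothetical `last_token_only` flag stands in for the default `return_logits=ReturnLogits.LAST_TOKEN`, which, as its name suggests, keeps only the final position's logits. `transformer_forward`, `scale_block`, and `row_norm` are illustrative names. Sharding across `devices` and the KV cache (`kv_params`, `kv_collection_constructor`) are ignored.

```python
import numpy as np

def transformer_forward(tokens, embedding, layers, norm, output_weight, last_token_only=True):
    """Hypothetical single-device forward pass: embed -> blocks -> norm -> project."""
    # Look up one embedding row per token id: (seq_len,) -> (seq_len, dim).
    h = embedding[tokens]
    # Each block maps (seq_len, dim) -> (seq_len, dim).
    for layer in layers:
        h = layer(h)
    h = norm(h)
    if last_token_only:
        # Keep only the final position, mirroring the idea behind ReturnLogits.LAST_TOKEN.
        h = h[-1:]
    # Project hidden states onto the vocabulary: (positions, dim) @ (dim, vocab).
    return h @ output_weight

# Toy usage: random weights, with simple stand-ins for the blocks and the norm.
rng = np.random.default_rng(0)
vocab, dim = 10, 4
embedding = rng.standard_normal((vocab, dim))
output_weight = rng.standard_normal((dim, vocab))

def scale_block(h):
    # Stand-in block that just rescales the hidden states.
    return 1.1 * h

def row_norm(h):
    # Stand-in norm that scales each row to unit length.
    return h / np.linalg.norm(h, axis=-1, keepdims=True)

tokens = np.array([1, 5, 3])
blocks = [scale_block, np.tanh]
logits = transformer_forward(tokens, embedding, blocks, row_norm, output_weight)
all_logits = transformer_forward(
    tokens, embedding, blocks, row_norm, output_weight, last_token_only=False
)
print(logits.shape, all_logits.shape)  # (1, 10) (3, 10)
```

The slice happens after the final block, so the last-token path returns exactly the last row of the full logits. It skips the output projection for every other position.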

Parameters:

DistributedTransformerBlock

class max.nn.transformer.distributed_transformer.DistributedTransformerBlock(attention, mlp, attention_norm, mlp_norm, devices, use_subgraph=False)

Stack of Attention, FeedForward, and RMSNorm layers.
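The line above does not say how the three layer types are wired together. The following minimal single-device NumPy sketch assumes the common pre-norm residual arrangement: first `h = x + attention(norm(x))`, then `h + mlp(norm(h))`. Its arguments loosely mirror the `attention`, `mlp`, `attention_norm`, and `mlp_norm` constructor parameters. `rms_norm`, `transformer_block`, `toy_attention`, and `toy_mlp` are hypothetical illustrations, not MAX APIs. Multi-device execution and `use_subgraph` are not modeled.

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # Divide each row by its root-mean-square, then apply a per-feature scale.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight

def transformer_block(x, attention, mlp, attention_norm_weight, mlp_norm_weight):
    # Attention sub-layer on the normalized input, plus a residual connection.
    h = x + attention(rms_norm(x, attention_norm_weight))
    # Feed-forward sub-layer on the normalized result, plus a second residual.
    return h + mlp(rms_norm(h, mlp_norm_weight))

def toy_attention(x):
    # Unprojected, unmasked single-head softmax self-attention, as a stand-in only.
    scores = x @ x.T / np.sqrt(x.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x

def toy_mlp(x):
    # ReLU as a minimal stand-in for the feed-forward network.
    return np.maximum(x, 0.0)

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 4))  # (seq_len, dim)
ones = np.ones(4)
out = transformer_block(x, toy_attention, toy_mlp, ones, ones)
print(out.shape)  # (3, 4)
```

In a pre-norm layout, the residual path from input to output is never normalized. A post-norm block would instead normalize after each addition. This page does not state which layout the class uses.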

Parameters:

build_subgraph()

build_subgraph(name)

Parameters:

name (str)

Return type:

Module

distribute_value()

max.nn.transformer.distributed_transformer.distribute_value(v, devices)

Parameters:

devices (list[DeviceRef])

take()

max.nn.transformer.distributed_transformer.take(it, n)

Return the next n items from it as a list.

Parameters:

Return type:

list[Value]
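The description closely resembles the `take` recipe from the Python `itertools` documentation. The following minimal sketch assumes `it` is an iterator that is consumed as items are taken, so repeated calls continue where the previous call stopped. This page does not document what happens when fewer than `n` items remain; this version simply returns a shorter list. The generic signature is illustrative; the documented return type is `list[Value]`.

```python
from collections.abc import Iterator
from itertools import islice
from typing import TypeVar

T = TypeVar("T")

def take(it: Iterator[T], n: int) -> list[T]:
    # islice pulls at most n items, leaving the rest of the iterator untouched.
    return list(islice(it, n))

values = iter(range(7))
first = take(values, 3)   # [0, 1, 2]
second = take(values, 3)  # [3, 4, 5]
rest = take(values, 3)    # [6]: fewer than n items were left
```

Passing a list instead of an iterator would restart from the first element on every call. Sharing one iterator is what lets successive calls split a flat sequence into consecutive groups.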