Python module

max.nn.attention

The attention mechanism used within the model.

Attention layers

AttentionWithRope: Implementation of attention that uses Rotary Position Embedding (RoPE).
DistributedAttentionImpl: A generalized distributed attention interface.
GGUFQAttentionWithRope: Implementation of attention with GGUF quantized weights.
GPTQAttentionWithRope: Implementation of the GPTQ attention layer.
LatentAttentionWithRope: Implementation of latent attention with RoPE.
MultiheadAttention: Multihead attention that handles both single-device and distributed computation.
RaggedAttention: Layer that computes the self-attention score for ragged inputs.
TensorParallelAttentionWithRope: Tensor-parallel wrapper that delegates sharding to the base module.
TensorParallelLatentAttentionWithRope: Distributed tensor-parallel implementation of latent attention with RoPE.
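
Several of the layers above are built around RoPE. As a rough illustration of what rotary position embedding does (a conceptual sketch, not MAX's actual implementation or API), RoPE rotates consecutive pairs of query/key components by a position-dependent angle, so relative positions become visible to the attention dot product:

```python
import math

def rope_rotate(vec, pos, base=10000.0):
    """Illustrative RoPE: rotate pairs (vec[2i], vec[2i+1]) by an angle
    that depends on the token position `pos` and the pair index."""
    dim = len(vec)
    out = []
    for i in range(0, dim, 2):
        theta = pos * base ** (-i / dim)   # lower frequencies for later pairs
        c, s = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + 1]
        out.extend([x * c - y * s, x * s + y * c])
    return out

# Position 0 leaves the vector unchanged (all rotation angles are zero).
print(rope_rotate([1.0, 0.0, 1.0, 0.0], pos=0))  # → [1.0, 0.0, 1.0, 0.0]
```

Because each step is a pure rotation, the vector's norm is preserved at every position, which is why RoPE can be applied without rescaling queries or keys.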

Mask configuration

AttentionMaskVariant: Defines the string mask variant identifiers used in attention configuration.
MHAMaskVariant: Defines the integer mask variant codes used by multihead attention kernels.
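
As a plain-Python sketch of what a mask variant selects (illustrative only; the variant names and MAX's kernel-side representation are assumptions, not the library's API), a causal variant restricts each query position to itself and earlier key positions, while a full/null variant allows attention everywhere:

```python
def causal_mask(seq_len):
    """Boolean causal mask: entry [q][k] is True where attention is allowed,
    i.e. key position k is at or before query position q."""
    return [[k <= q for k in range(seq_len)] for q in range(seq_len)]

def full_mask(seq_len):
    """Unrestricted variant: every position may attend to every other."""
    return [[True] * seq_len for _ in range(seq_len)]

print(causal_mask(3))
# → [[True, False, False], [True, True, False], [True, True, True]]
```

In practice a kernel consumes the variant code rather than a materialized boolean matrix, but the allowed-position semantics are the same.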

Functions

num_heads_for_device: Computes the number of attention heads assigned to a specific device.
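
A minimal sketch of one way such a head assignment can work (a hypothetical near-even split, a common tensor-parallel layout; the signature and remainder rule here are assumptions, not necessarily MAX's exact behavior): divide the heads evenly and give the first `total_heads % num_devices` devices one extra head each.

```python
def heads_for_device(total_heads, device_idx, num_devices):
    """Hypothetical near-even partition of attention heads across devices.
    The first (total_heads % num_devices) devices each take one extra head."""
    base, rem = divmod(total_heads, num_devices)
    return base + (1 if device_idx < rem else 0)

# 10 heads over 4 devices:
print([heads_for_device(10, d, 4) for d in range(4)])  # → [3, 3, 2, 2]
```

Whatever the exact rule, the per-device counts must sum to the total head count so no head is dropped or duplicated across the tensor-parallel group.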