Python class

TensorParallelLatentAttentionWithRope

class max.nn.attention.TensorParallelLatentAttentionWithRope(*, skip_allreduce=False, **kwargs)

Bases: LatentAttentionWithRope

Distributed tensor-parallel implementation of latent attention with RoPE. Note that using tensor parallelism for MLA causes the KV cache to be duplicated across all devices, which is inefficient.

When skip_allreduce is True, the final all-reduce is skipped. This is intended for mixed TP-attention + EP-MoE configurations, where the communication is handled explicitly by the caller.

Parameters:

skip_allreduce (bool)
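
A minimal construction sketch, assuming the remaining keyword arguments are forwarded unchanged to the LatentAttentionWithRope base class. The contents of base_kwargs below are placeholders, not the documented base-class signature:

```python
from max.nn.attention import TensorParallelLatentAttentionWithRope

# Placeholder: fill with whatever keyword arguments LatentAttentionWithRope
# actually requires (model dimensions, RoPE configuration, devices,
# KV-cache parameters, ...). Shown empty here for illustration only.
base_kwargs = {}

# skip_allreduce=True leaves the per-device partial outputs unreduced so a
# mixed TP-attention + EP-MoE caller can handle the communication itself.
attn = TensorParallelLatentAttentionWithRope(
    skip_allreduce=True,
    **base_kwargs,  # forwarded to the LatentAttentionWithRope base class
)
```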

create_mla_prefill_metadata()

create_mla_prefill_metadata(input_row_offsets_, kv_collections)

Creates per-device MLA prefill metadata for tensor-parallel execution.

Parameters:

input_row_offsets_

kv_collections

Returns:

A list of MLAPrefillMetadata instances, one per device.

Return type:

list[MLAPrefillMetadata]
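
A usage sketch for this method, assuming the layer attn was constructed as above and that input_row_offsets_ and kv_collections have already been prepared by the caller; the argument names follow the signature above, but their exact types are not shown in this reference:

```python
# `input_row_offsets_` and `kv_collections` are stand-ins for the ragged-batch
# row offsets and the per-device KV-cache collections built elsewhere by the
# caller; they are assumptions here, not values this snippet constructs.
prefill_metadata = attn.create_mla_prefill_metadata(
    input_row_offsets_,
    kv_collections,
)

# The reference states one MLAPrefillMetadata instance per device; assuming
# `kv_collections` also holds one entry per device, the lengths should match.
assert len(prefill_metadata) == len(kv_collections)
```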