Python class
TensorParallelLatentAttentionWithRope
class max.nn.attention.TensorParallelLatentAttentionWithRope(*, skip_allreduce=False, **kwargs)
Bases: LatentAttentionWithRope
Distributed tensor-parallel implementation of Latent Attention with RoPE. Note that using tensor parallelism for MLA causes the KV cache to be duplicated across all devices, which is not efficient.
When skip_allreduce is True, the final all-reduce is skipped. This is
intended for mixed TP-attention + EP-MoE configurations, where the
communication is handled explicitly by the caller.
Parameters:
- skip_allreduce (bool)
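For orientation, here is a minimal construction sketch. The helper name and the mla_kwargs dictionary are hypothetical placeholders; the remaining keyword arguments are simply forwarded to LatentAttentionWithRope, whose exact constructor signature is not documented on this page.

```python
from max.nn.attention import TensorParallelLatentAttentionWithRope

def build_tp_mla(mla_kwargs, fuse_comm_with_moe=False):
    # Hypothetical helper: `mla_kwargs` stands in for whatever keyword
    # arguments your LatentAttentionWithRope configuration requires
    # (head dims, LoRA ranks, RoPE settings, devices, ...).
    #
    # When the caller fuses communication with an EP-MoE layer, pass
    # skip_allreduce=True and perform the all-reduce explicitly downstream.
    return TensorParallelLatentAttentionWithRope(
        skip_allreduce=fuse_comm_with_moe,
        **mla_kwargs,
    )
```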
create_mla_prefill_metadata()
create_mla_prefill_metadata(input_row_offsets_, kv_collections)
Creates per-device MLA prefill metadata for tensor-parallel execution.
Parameters:
- input_row_offsets_ (list[TensorValue]) – Per-device ragged row offset tensors.
- kv_collections (list[KVCacheInputsPerDevice[TensorValue, BufferValue]]) – Per-device paged KV cache values.
Returns:
A list of MLAPrefillMetadata instances, one per device.

Return type:
list[MLAPrefillMetadata]
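As a hedged usage sketch, assuming attn is an already constructed TensorParallelLatentAttentionWithRope and that per-device graph values are available; the helper and variable names below are illustrative, not part of the MAX API.

```python
def per_device_prefill_metadata(attn, input_row_offsets_, kv_collections):
    """Illustrative helper: forward per-device inputs to the layer.

    `input_row_offsets_` is a list[TensorValue] of ragged row offsets and
    `kv_collections` holds the per-device paged KV cache inputs, matching
    the signature documented above.
    """
    metadata = attn.create_mla_prefill_metadata(input_row_offsets_, kv_collections)
    # The documented result is a list[MLAPrefillMetadata], one entry per device.
    assert len(metadata) == len(input_row_offsets_)
    return metadata
```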