
Python function

causal_attention_mask_with_token_mask


max.pipelines.modeling.dataprocessing.causal_attention_mask_with_token_mask(original_start_pos, token_mask, *, mask_name='token_mask')


Builds a causal attention mask and additionally masks out every token that token_mask marks as invalid.

Parameters:

  • original_start_pos (list[int]) – Start position (context length) of each example in the batch.
  • token_mask (Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | complex | bytes | str | _NestedSequence[complex | bytes | str]) – Validity mask for the tokens in the current pass, with shape [seq_len] or [batch, seq_len]. True marks a valid token; False marks padding or any other token that should be hidden.
  • mask_name (str) – Name used to identify the mask in validation error messages.
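The parameters can be made concrete with a short NumPy sketch. It is a hypothetical re-creation under stated assumptions, not the MAX implementation. The helper name causal_mask_with_token_mask_sketch, the output layout [batch, seq_len, max(original_start_pos) + seq_len], the broadcasting of a 1-D token_mask across the batch, applying token_mask to key columns only, and the MASK_VALUE of -10000.0 are all assumptions.

```python
import numpy as np

# Assumed stand-in for the "large negative value"; the constant MAX uses may differ.
MASK_VALUE = np.float32(-10000.0)


def causal_mask_with_token_mask_sketch(original_start_pos, token_mask, *, mask_name="token_mask"):
    """Hypothetical NumPy re-creation, not the MAX implementation.

    Assumed layout: shape [batch, seq_len, max(start_pos) + seq_len], where row i
    is the i-th token of the current pass and column j is absolute position j
    (context tokens followed by current-pass tokens).
    """
    start = np.asarray(original_start_pos, dtype=np.int64)
    token_mask = np.asarray(token_mask, dtype=bool)
    if token_mask.ndim == 1:
        # Assumption: a [seq_len] mask is shared by every example in the batch.
        token_mask = np.broadcast_to(token_mask, (start.shape[0], token_mask.shape[0]))
    if token_mask.ndim != 2 or token_mask.shape[0] != start.shape[0]:
        # mask_name only labels the error message, mirroring the documented parameter.
        raise ValueError(
            f"{mask_name} must have shape [seq_len] or [batch, seq_len], got {token_mask.shape}"
        )
    batch, seq_len = token_mask.shape
    total_len = int(start.max()) + seq_len

    # Causal part: query i of example b sits at absolute position start[b] + i
    # and may attend to every key at or before that position.
    rows = np.arange(seq_len)[None, :, None]      # [1, seq_len, 1]
    cols = np.arange(total_len)[None, None, :]    # [1, 1, total_len]
    causal = cols <= start[:, None, None] + rows  # [batch, seq_len, total_len]

    # Token part: context keys stay valid; current-pass keys take token_mask.
    key_valid = np.ones((batch, total_len), dtype=bool)
    for b, s in enumerate(start):
        key_valid[b, s : s + seq_len] = token_mask[b]

    # A key is visible only if it passes both the causal and the token check.
    visible = causal & key_valid[:, None, :]
    return np.where(visible, np.float32(0.0), MASK_VALUE).astype(np.float32)


# One example with 2 context tokens and 3 new tokens, the middle one invalid.
mask = causal_mask_with_token_mask_sketch([2], [True, False, True])
print(mask.shape)  # (1, 3, 5)
print(mask[0, 2])  # last query sees every position except the invalid token at position 3
```

Keeping the causal check and the token check as two boolean arrays combined with `&` makes it easy to see that a key is hidden if either condition rejects it.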

Returns:

Float32 additive mask array where visible positions are 0 and masked positions are a large negative value.

Return type:

ndarray[tuple[Any, …], dtype[float32]]
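Because the returned mask is additive, it is typically added to the raw attention scores before the softmax. A 0 leaves a score unchanged, while a large negative value drives that position's weight to effectively zero. The following is a minimal sketch; the masked_softmax helper is hypothetical and -10000.0 is an assumed stand-in for the large negative value.

```python
import numpy as np


def masked_softmax(scores, additive_mask):
    """Hypothetical helper: softmax over the last axis after adding an additive mask."""
    logits = scores + additive_mask
    # Subtract the row maximum for numerical stability before exponentiating.
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)


scores = np.zeros((1, 3), dtype=np.float32)
# Assumed mask: the first two positions are visible, the third is masked.
mask = np.array([[0.0, 0.0, -10000.0]], dtype=np.float32)
w = masked_softmax(scores, mask)
print(w)  # w is [[0.5, 0.5, 0.0]]
```

In float32, exp(-10000) underflows to exactly 0, so the masked position receives no attention weight.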