Python function

causal_attention_mask_with_token_mask()

max.pipelines.modeling.dataprocessing.causal_attention_mask_with_token_mask(original_start_pos, token_mask, *, mask_name='token_mask')

Builds a causal attention mask that additionally masks out the tokens marked invalid in token_mask.

Parameters:

  • original_start_pos (list[int]) – Per-example start position (context length) in the batch.
  • token_mask (Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | complex | bytes | str | _NestedSequence[complex | bytes | str]) – Per-example validity mask for tokens in the current pass. Shape [seq_len] or [batch, seq_len]. True marks a valid token and False marks padding or any token that should be hidden.
  • mask_name (str) – Name used in validation errors.

Returns:

Float32 additive mask array where visible positions are 0 and masked positions are a large negative value.

Return type:

ndarray[tuple[Any, ...], dtype[float32]]
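
Example:

The following is a minimal NumPy sketch of the behavior described above, not the library's implementation. The helper name sketch_causal_mask_with_token_mask, the fill value of -10000.0, and the output shape [batch, seq_len, max(start_pos) + seq_len] are all assumptions made for illustration; the actual function may differ on each of these points.

```python
import numpy as np

# Assumed "large negative value"; the real fill value is not stated in the docs.
NEG = np.float32(-10000.0)


def sketch_causal_mask_with_token_mask(original_start_pos, token_mask):
    """Hypothetical reference sketch: causal mask plus per-token validity mask."""
    token_mask = np.asarray(token_mask, dtype=bool)
    if token_mask.ndim == 1:
        # Treat a [seq_len] mask as a batch of one.
        token_mask = token_mask[None, :]
    if token_mask.ndim != 2:
        raise ValueError("token_mask must have shape [seq_len] or [batch, seq_len]")

    batch, seq_len = token_mask.shape
    if len(original_start_pos) != batch:
        raise ValueError("original_start_pos must have one entry per example")

    # Assumed layout: key axis covers cached context plus the current pass.
    total_len = max(original_start_pos) + seq_len
    mask = np.zeros((batch, seq_len, total_len), dtype=np.float32)

    key_idx = np.arange(total_len)[None, :]
    for b, start in enumerate(original_start_pos):
        # Absolute position of each query token in this example.
        query_idx = np.arange(seq_len)[:, None] + start
        # Causal rule: a query may not attend to keys after its own position.
        hidden = key_idx > query_idx
        # Token rule: hide key columns of the current pass flagged as invalid.
        invalid_keys = np.zeros(total_len, dtype=bool)
        invalid_keys[start:start + seq_len] = ~token_mask[b]
        hidden = hidden | invalid_keys[None, :]
        mask[b][hidden] = NEG
    return mask


example = sketch_causal_mask_with_token_mask([0], [True, True, False])
print(example.shape, example.dtype)
```

In this sketch the third token is padding, so its key column is set to the large negative value for every query row, on top of the usual upper-triangular causal masking. Going by the signature documented above, the corresponding library call would take the same two positional arguments, with mask_name as an optional keyword.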