causal_attention_mask_with_token_mask
causal_attention_mask_with_token_mask()
max.pipelines.modeling.dataprocessing.causal_attention_mask_with_token_mask(original_start_pos, token_mask, *, mask_name='token_mask')
Builds a causal attention mask and additionally masks out every token that token_mask marks as invalid.
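The way the two masking rules combine can be illustrated with a short NumPy sketch. It is not the library's implementation. The helper name `sketch_causal_token_mask`, the fill value of -10000.0, the output layout [batch, seq_len, max(start_pos) + seq_len], and the choice to hide invalid tokens only on the key axis are all assumptions made for illustration.

```python
import numpy as np

# Assumed "large negative value"; the real function may use a different one.
NEG_FILL = np.float32(-10000.0)


def sketch_causal_token_mask(start_pos: list[int], token_mask) -> np.ndarray:
    """Hypothetical sketch: causal additive mask that also hides invalid tokens.

    Assumes token_mask has shape [batch, seq_len] and that, for example b,
    the key axis holds start_pos[b] context tokens followed by the seq_len
    tokens of the current pass.
    """
    token_mask = np.asarray(token_mask, dtype=bool)
    batch, seq_len = token_mask.shape
    # Key axis is padded to the longest context plus the current pass.
    total_len = max(start_pos) + seq_len
    mask = np.full((batch, seq_len, total_len), NEG_FILL, dtype=np.float32)
    for b, pos in enumerate(start_pos):
        for i in range(seq_len):
            # Causal rule: query i sees all context tokens plus current
            # tokens 0..i, i.e. key positions 0..pos + i.
            visible = np.zeros(total_len, dtype=bool)
            visible[: pos + i + 1] = True
            # Token rule (assumed): hide current-pass keys marked False.
            visible[pos : pos + seq_len] &= token_mask[b]
            mask[b, i, visible] = 0.0
    return mask
```

For example, with one example whose context holds 2 tokens and whose current pass is `[True, False, True]`, the last query row comes out as `[0, 0, 0, -10000, 0]`: everything up to and including itself is visible except the invalid middle token.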
Parameters:
- original_start_pos (list[int]) - Start position of each example in the batch, i.e. the length of its existing context before the current pass.
- token_mask (numpy.typing.ArrayLike) - Validity mask for the tokens of the current pass, with shape [seq_len] or [batch, seq_len]. True marks a valid token; False marks padding or any other token that should be hidden.
- mask_name (str) - Keyword-only. Name used to refer to the mask in validation error messages. Defaults to 'token_mask'.
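A hypothetical helper can show one way the two accepted token_mask shapes might be normalized, and how mask_name could surface in a validation error. The name `normalize_token_mask` is invented for this sketch, and treating a 1-D mask as shared by every example in the batch is an assumption; the actual function may handle that case differently.

```python
import numpy as np
import numpy.typing as npt


def normalize_token_mask(
    token_mask: npt.ArrayLike, batch_size: int, mask_name: str = "token_mask"
) -> np.ndarray:
    """Hypothetical helper: coerce an array-like mask to shape [batch, seq_len]."""
    arr = np.asarray(token_mask, dtype=bool)
    if arr.ndim == 1:
        # Assumption: a 1-D mask applies to every example in the batch.
        return np.broadcast_to(arr, (batch_size, arr.shape[0]))
    if arr.ndim != 2 or arr.shape[0] != batch_size:
        # mask_name only changes how the offending argument is reported.
        raise ValueError(
            f"{mask_name} must have shape [seq_len] or "
            f"[{batch_size}, seq_len], got {arr.shape}"
        )
    return arr
```

Passing a custom `mask_name` is useful when the mask reaches the function under a different name in the caller's API, so the error points at the argument the user actually supplied.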
Returns:
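A float32 additive mask in which visible positions are 0 and masked positions hold a large negative value. Adding it to the attention scores before the softmax drives the weights of masked positions to effectively zero.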
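The effect of an additive mask can be checked by adding it to a row of attention logits before a softmax: the exponential of a very negative score underflows to 0, so the masked key receives no weight while the remaining weights still sum to 1. The helper name `masked_softmax` and the -10000.0 fill value are illustrative assumptions.

```python
import numpy as np


def masked_softmax(logits: np.ndarray, additive_mask: np.ndarray) -> np.ndarray:
    """Add the mask to attention logits, then take a numerically stable softmax."""
    scores = logits + additive_mask
    # Subtract the row maximum so np.exp cannot overflow.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


# One query over three keys; the last key is masked with an assumed -10000.0.
logits = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)
mask = np.array([[0.0, 0.0, -10000.0]], dtype=np.float32)
weights = masked_softmax(logits, mask)
```

Although the masked key has the largest raw logit, its weight ends up at 0 and the other two keys share the probability mass exactly as a softmax over `[1.0, 2.0]` would.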
Return type: