
Python function

causal_attention_mask

causal_attention_mask()

max.pipelines.modeling.dataprocessing.causal_attention_mask(original_start_pos, original_seq_len)


Builds a causal attention mask for a batch of variable-length sequences.

Parameters:

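  • original_start_pos (list[int]) – Start position of each example in the batch, i.e., its context length (the number of tokens already processed before this pass).
  • original_seq_len (list[int]) – Sequence length of each example for this pass, i.e., the number of new tokens being processed.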
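The parameters above can be illustrated with a minimal NumPy sketch. It assumes three things. First, the new token at absolute position `start_pos + i` may attend to every key position up to and including its own. Second, per-example masks are padded to a common `(max seq_len, max start_pos + seq_len)` shape. Third, -10000.0 stands in for the "large negative value". The name `causal_mask_sketch`, the padding layout, and the fill value are illustrative assumptions, not the MAX implementation.

```python
import numpy as np

# Assumed stand-in for the "large negative value"; the real constant may differ.
NEG_LARGE = np.float32(-10000.0)


def causal_mask_sketch(start_pos: list[int], seq_len: list[int]) -> np.ndarray:
    """Hypothetical causal mask for a batch of variable-length sequences.

    Returns a float32 array of shape (batch, max_seq_len, max_total_len),
    where total_len = start_pos + seq_len for each example. Visible entries
    are 0.0; masked and padded entries are NEG_LARGE. Assumes a non-empty batch.
    """
    start = np.asarray(start_pos, dtype=np.int64)
    length = np.asarray(seq_len, dtype=np.int64)
    total = start + length  # context length after this pass

    batch = len(start)
    max_q = int(length.max())
    max_k = int(total.max())
    # Start fully masked so padding rows and columns stay hidden.
    mask = np.full((batch, max_q, max_k), NEG_LARGE, dtype=np.float32)

    for b in range(batch):
        q = np.arange(length[b])[:, None]  # index of each new token in this pass
        k = np.arange(total[b])[None, :]  # absolute index of each key token
        # New token i sits at absolute position start + i and may see keys j <= start + i.
        visible = k <= start[b] + q
        mask[b, : length[b], : total[b]] = np.where(visible, 0.0, NEG_LARGE)
    return mask
```

For example, with `start_pos=[2, 0]` and `seq_len=[2, 3]` the sketch returns an array of shape `(2, 3, 4)`. The first example's first new token sees its two context tokens and itself. Its padded third row is fully masked.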

Returns:

A float32 mask array in which visible positions are 0 and masked positions hold a large negative value. When the mask is added to the attention scores, softmax gives masked positions effectively zero weight, as if they were -inf.
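To see why a large finite negative value behaves like -inf, here is a hypothetical `masked_softmax` helper (not part of MAX). It adds an additive mask to raw scores and applies a numerically stable softmax. The constant -10000.0 is an assumed stand-in for the large negative value.

```python
import numpy as np

# Assumed stand-in for the "large negative value"; MAX may use a different constant.
NEG_LARGE = np.float32(-10000.0)


def masked_softmax(scores: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Hypothetical helper: add an additive mask to scores, then softmax over the last axis."""
    logits = (scores + mask).astype(np.float32)
    # Subtract the row max for numerical stability.
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)


scores = np.array([[1.0, 1.0, 1.0]], dtype=np.float32)
mask = np.array([[0.0, 0.0, NEG_LARGE]], dtype=np.float32)
weights = masked_softmax(scores, mask)
# weights is approximately [[0.5, 0.5, 0.0]]: the masked key gets effectively zero weight.
```

A finite value also has a practical benefit in general, independent of how MAX chooses its constant. A row that is masked everywhere, such as a padding row, still produces finite weights. With a literal -inf in every entry, the max-subtraction step would compute -inf - (-inf) = NaN.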

Return type:

ndarray[tuple[Any, ...], dtype[float32]]