
Mojo function

ChunkedCausalMask

ChunkedCausalMask[local_window_size: Int]() -> OrMask[CausalMask(), ChunkedMask()]

Mask implementing chunked causal attention for Llama4 models.

This mask splits the sequence into chunks of local_window_size positions and applies causal attention within each chunk. A query can attend only to keys that are in the same chunk and at or before its own position. Consider the following case:

  • Q_len = 7
  • K_len = 10
  • start_pos = 3
  • local_window_size = 4

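The mask will be applied as follows. Each row is a query index q at absolute position start_pos + q, and each column is a key index k. A 1 means the query may attend to that key, and a 0 means the score is masked:

    K > 0 1 2 3 4 5 6 7 8 9
    Q v x-----------------x
    0 | 1 1 1 1 0 0 0 0 0 0
    1 | 0 0 0 0 1 0 0 0 0 0
    2 | 0 0 0 0 1 1 0 0 0 0
    3 | 0 0 0 0 1 1 1 0 0 0
    4 | 0 0 0 0 1 1 1 1 0 0
    5 | 0 0 0 0 0 0 0 0 1 0
    6 | 0 0 0 0 0 0 0 0 1 1

With local_window_size = 4, positions 0-3, 4-7, and 8-9 fall in chunks 0, 1, and 2. Query 0 sits at position 3, so it sees keys 0-3. Query 1 sits at position 4, which starts a new chunk, so it sees only key 4.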
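The rule can be checked with a small reference model. The Python sketch below is not the Mojo implementation. It assumes that query row q sits at absolute position start_pos + q and that the two component masks are combined so a score is masked when either one masks it, which is consistent with the diagram. The helper names causal_masked, chunked_masked, and chunked_causal_mask are hypothetical.

```python
def causal_masked(q_pos: int, k_pos: int) -> bool:
    """Causal rule: a query may not attend to keys after its own position."""
    return k_pos > q_pos


def chunked_masked(q_pos: int, k_pos: int, local_window_size: int) -> bool:
    """Chunk rule: a query may not attend to keys outside its own chunk."""
    return q_pos // local_window_size != k_pos // local_window_size


def chunked_causal_mask(q_len: int, k_len: int, start_pos: int,
                        local_window_size: int) -> list:
    """Return a q_len x k_len grid where 1 = attend and 0 = masked.

    Assumption (hypothetical model, not the Mojo API): query row q is at
    absolute position start_pos + q, key column k is at absolute position k,
    and a score is masked if EITHER rule masks it (an OR of the two masks).
    """
    grid = []
    for q in range(q_len):
        q_pos = start_pos + q
        row = []
        for k in range(k_len):
            masked = causal_masked(q_pos, k) or chunked_masked(
                q_pos, k, local_window_size)
            row.append(0 if masked else 1)
        grid.append(row)
    return grid


if __name__ == "__main__":
    # Same parameters as the example above.
    mask = chunked_causal_mask(q_len=7, k_len=10, start_pos=3,
                               local_window_size=4)
    for q, row in enumerate(mask):
        print(q, "|", " ".join(str(v) for v in row))
```

Running it with the example parameters prints the same seven rows as the diagram. Keeping the causal and chunk conditions as two separate predicates appears to mirror the OrMask[CausalMask(), ChunkedMask()] return type, where each component contributes one masking condition.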
