
Mojo function

ChunkedCausalMask

def ChunkedCausalMask[local_window_size: Int]() -> OrMask[CausalMask(), ChunkedMask()]

Mask implementing Chunked Causal attention for Llama4 models.

This partitions the sequence into chunks of size local_window_size and applies causal attention within each chunk: a query attends only to keys that lie in its own chunk and at or before its own position. Consider the following case:

  • Q_len = 7
  • K_len = 10
  • start_pos = 3
  • local_window_size = 4

The mask will be applied as follows:

    K > 0 1 2 3 4 5 6 7 8 9
    Q v x--------------------x
    0 | 1 1 1 1 0 0 0 0 0 0
    1 | 0 0 0 0 1 0 0 0 0 0
    2 | 0 0 0 0 1 1 0 0 0 0
    3 | 0 0 0 0 1 1 1 0 0 0
    4 | 0 0 0 0 1 1 1 1 0 0
    5 | 0 0 0 0 0 0 0 0 1 0
    6 | 0 0 0 0 0 0 0 0 1 1
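The example grid can be reproduced with a small reference sketch in plain Python. It is illustrative only and does not use or model the Mojo mask types. The function name chunked_causal_mask is hypothetical. The sketch assumes that query row i sits at absolute position start_pos + i and that a 1 in the grid means the query may attend to that key, which is what the example implies.

```python
def chunked_causal_mask(q_len, k_len, start_pos, local_window_size):
    """Return a q_len x k_len grid of 0/1 values (1 = query may attend to key).

    Hypothetical reference helper, not part of the MAX API.
    Assumes query row i has absolute position start_pos + i.
    """
    mask = []
    for q_idx in range(q_len):
        q_pos = start_pos + q_idx  # absolute position of this query token
        row = []
        for k_pos in range(k_len):
            # Chunk membership: both positions fall in the same window.
            same_chunk = q_pos // local_window_size == k_pos // local_window_size
            # Causality: a key is visible only at or before the query.
            causal = k_pos <= q_pos
            row.append(1 if same_chunk and causal else 0)
        mask.append(row)
    return mask


# Parameters from the example: Q_len = 7, K_len = 10, start_pos = 3,
# local_window_size = 4.
mask = chunked_causal_mask(q_len=7, k_len=10, start_pos=3, local_window_size=4)
for row in mask:
    print(" ".join(str(v) for v in row))
```

With start_pos = 3, the first query row belongs to the chunk covering keys 0 to 3, the next four rows to the chunk covering keys 4 to 7, and the last two rows to the chunk starting at key 8.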

Returns:

OrMask[CausalMask(), ChunkedMask()]