Mojo function
ChunkedCausalMask
ChunkedCausalMask[local_window_size: Int]() -> OrMask[CausalMask(), ChunkedMask()]
Mask implementing chunked causal attention for Llama 4 models.
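This groups key positions into chunks of size local_window_size and performs causal attention within each chunk: a query at absolute position start_pos + i attends only to keys that are at or before that position and that fall in the same chunk. Consider the following case: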
- Q_len = 7
- K_len = 10
- start_pos = 3
- local_window_size = 4
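The mask will be applied as follows (1 = attend, 0 = masked):

    K >   0 1 2 3 4 5 6 7 8 9
    Q v x--------------------x
    0   | 1 1 1 1 0 0 0 0 0 0
    1   | 0 0 0 0 1 0 0 0 0 0
    2   | 0 0 0 0 1 1 0 0 0 0
    3   | 0 0 0 0 1 1 1 0 0 0
    4   | 0 0 0 0 1 1 1 1 0 0
    5   | 0 0 0 0 0 0 0 0 1 0
    6   | 0 0 0 0 0 0 0 0 1 1

Query row i sits at absolute position start_pos + i. Row 0 (position 3) is the last token of chunk 0 and sees keys 0 to 3, while row 1 (position 4) starts chunk 1 and sees only key 4. Likewise row 5 (position 8) starts chunk 2 and sees only key 8.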
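To make the rule concrete, here is a minimal plain-Python sketch that regenerates the matrix above. It is not the Mojo/MAX API: the helper names `causal_masked`, `chunk_masked`, and `chunked_causal_mask` are hypothetical, and it assumes that the OR composition in the return type masks a position whenever either component (causal or chunked) masks it, with query row i at absolute position start_pos + i.

```python
# Hypothetical reference model of chunked causal masking.
# NOT the Mojo/MAX implementation; names here are illustrative only.


def causal_masked(q_abs: int, k: int) -> bool:
    """True if the causal component masks key k (key lies after the query)."""
    return k > q_abs


def chunk_masked(q_abs: int, k: int, window: int) -> bool:
    """True if key k lies in a different chunk of size `window` than the query."""
    return k // window != q_abs // window


def chunked_causal_mask(q_len: int, k_len: int, start_pos: int, window: int) -> list[list[int]]:
    """Return a q_len x k_len matrix with 1 = attend and 0 = masked.

    Assumption: a position is masked if EITHER component masks it
    (an OR over "masked" flags), mirroring OrMask[CausalMask(), ChunkedMask()].
    """
    mask = []
    for i in range(q_len):
        q_abs = start_pos + i  # absolute position of query row i
        row = []
        for k in range(k_len):
            masked = causal_masked(q_abs, k) or chunk_masked(q_abs, k, window)
            row.append(0 if masked else 1)
        mask.append(row)
    return mask


if __name__ == "__main__":
    m = chunked_causal_mask(q_len=7, k_len=10, start_pos=3, window=4)
    for i, row in enumerate(m):
        print(i, "|", " ".join(str(v) for v in row))
```

Running it with Q_len = 7, K_len = 10, start_pos = 3 and local_window_size = 4 prints the same seven rows as the table. Expressing the result as two independent predicates combined with OR keeps each rule simple: the causal test only compares positions, and the chunk test only compares chunk indices.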