Mojo module
mha_mask
comptime values
MASK_VALUE
comptime MASK_VALUE = -10000
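MASK_VALUE is a large negative number. This page does not show how the kernels consume it, so the following is an assumption: a common pattern is to substitute such a value for the attention score of every masked (q, k) position before the softmax. After the row maximum is subtracted, the masked score's exponential is effectively zero. The sketch below shows that effect on three scores. `apply_mask` is a hypothetical helper and is not part of `mha_mask`.

```mojo
from math import exp
from testing import assert_true


# Hypothetical helper, not part of mha_mask: keep a visible score as-is and
# replace a masked one with the mask value (-10000, like MASK_VALUE).
fn apply_mask(score: Float64, visible: Bool, mask_value: Float64) -> Float64:
    if visible:
        return score
    return mask_value


def main():
    var mask_value: Float64 = -10000.0
    # One query's raw scores against three keys; key 2 is masked out.
    var m0 = apply_mask(1.0, True, mask_value)
    var m1 = apply_mask(2.0, True, mask_value)
    var m2 = apply_mask(3.0, False, mask_value)

    # Numerically stable softmax: subtract the row max before exponentiating.
    var row_max = max(m0, max(m1, m2))
    var e0 = exp(m0 - row_max)
    var e1 = exp(m1 - row_max)
    var e2 = exp(m2 - row_max)
    var total = e0 + e1 + e2

    var w2 = e2 / total
    print("weights:", e0 / total, e1 / total, w2)
    # The masked key receives a negligible share of the attention weight.
    assert_true(w2 < 1e-6)
```

Replacing the masked score with a finite value keeps the arithmetic free of infinities while still driving the masked weight to zero for practical purposes.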
Structs
- AndMask: Mask that's the AND of two masks.
- CausalMask: MHA causal mask ensures a token is only affected by itself and previous tokens.
- CausalPaddingMask: Causal mask combined with padding: a position (seq_id, head, q, k) is visible only when q >= k (causal) AND k < valid_lengths[seq_id] (padding).
- ChunkedMask: Mask implementing Chunked attention.
- MaskName: A tile's masking status.
- MaskStrategy
- MaterializedMask: Mask that's backed by a materialized tensor.
- NullMask: Mask that's effectively a no-op.
- OrMask: Mask that's the OR of two masks.
- SlidingWindowCausalMask: Mask implementing Sliding Window attention.
- TileMaskStatus: A tile's masking status.
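The struct descriptions above define which (q, k) positions a mask leaves visible. The sketch below writes three of them as plain predicates: `CausalMask` as q >= k, and `CausalPaddingMask` exactly as its description states (causal AND k < valid_length). The sliding-window rule used here (causal, and q - k < window) is an assumption about `SlidingWindowCausalMask`, not something this page specifies. `tile_status` classifies a bm x bn tile of the causal mask as fully masked, partially masked, or unmasked, the kind of per-tile answer a type like `TileMaskStatus` describes. All function names and the string labels are hypothetical, not the `mha_mask` API.

```mojo
from testing import assert_true, assert_false


# Hypothetical predicates, not the mha_mask API: True means the key at
# position k is visible to the query at position q.
fn causal_visible(q: Int, k: Int) -> Bool:
    return q >= k


fn causal_padding_visible(q: Int, k: Int, valid_length: Int) -> Bool:
    # Causal AND padding, per the CausalPaddingMask description;
    # valid_length plays the role of valid_lengths[seq_id].
    return causal_visible(q, k) and k < valid_length


fn sliding_window_visible(q: Int, k: Int, window: Int) -> Bool:
    # Assumed rule: causal, and at most `window` keys back (counting q itself).
    return causal_visible(q, k) and q - k < window


fn tile_status(q0: Int, k0: Int, bm: Int, bn: Int) -> String:
    # Classify the causal-mask tile covering queries [q0, q0 + bm)
    # and keys [k0, k0 + bn).
    var q_last = q0 + bm - 1
    var k_last = k0 + bn - 1
    if k0 > q_last:
        # Every key comes after every query: nothing is visible.
        return "FULL_MASK"
    if k_last <= q0:
        # Every key is at or before every query: nothing is masked.
        return "NO_MASK"
    return "PARTIAL_MASK"


def main():
    # Print the 4 x 4 causal pattern with padding at valid_length = 3.
    for q in range(4):
        var row = String("")
        for k in range(4):
            if causal_padding_visible(q, k, 3):
                row += "1"
            else:
                row += "0"
        print(row)

    assert_true(sliding_window_visible(5, 3, 3))
    assert_false(sliding_window_visible(5, 2, 3))
    assert_true(tile_status(0, 4, 4, 4) == "FULL_MASK")
    assert_true(tile_status(4, 0, 4, 4) == "NO_MASK")
    assert_true(tile_status(0, 0, 4, 4) == "PARTIAL_MASK")
```

With valid_length = 3, the last printed row is 1110 rather than 1111 because key 3 lies in the padding. Classifying whole tiles lets a kernel skip fully masked tiles and avoid per-element checks on unmasked ones.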
Traits
- MHAMask: The MHAMask trait describes masks for MHA kernels, such as the causal mask.
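Because MHAMask is a trait, a kernel can be written once and specialized for any conforming mask type. The sketch below illustrates that pattern with a deliberately simplified, hypothetical trait, `SimpleMask`, whose single `is_visible` method is an assumption. It is not the real MHAMask interface, whose methods are not listed on this page.

```mojo
from testing import assert_equal


# Hypothetical, simplified stand-in for a mask trait; not MHAMask itself.
trait SimpleMask:
    fn is_visible(self, q: Int, k: Int) -> Bool:
        ...


struct NoMask(SimpleMask):
    # Every position is visible, similar in spirit to NullMask.
    fn __init__(out self):
        pass

    fn is_visible(self, q: Int, k: Int) -> Bool:
        return True


struct Causal(SimpleMask):
    # A query sees itself and earlier keys.
    fn __init__(out self):
        pass

    fn is_visible(self, q: Int, k: Int) -> Bool:
        return q >= k


# Generic over the mask type: specialized at compile time for each mask.
fn count_visible[M: SimpleMask](mask: M, seq_len: Int) -> Int:
    var total = 0
    for q in range(seq_len):
        for k in range(seq_len):
            if mask.is_visible(q, k):
                total += 1
    return total


def main():
    var n = 4
    # All n * n = 16 positions are visible without a mask;
    # the causal mask keeps n * (n + 1) / 2 = 10.
    assert_equal(count_visible(NoMask(), n), 16)
    assert_equal(count_visible(Causal(), n), 10)
    print("ok")
```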
Functions
- ChunkedCausalMask: Mask implementing Chunked Causal attention for Llama4 models.
- naively_compute_total_iters
- naively_get_first_nonempty_mask_col
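ChunkedCausalMask combines chunked attention with causality. This page does not define the chunk rule or its size parameter, so the following is an assumption: positions are split into consecutive, non-overlapping chunks, and a query sees only keys in its own chunk. Under that assumption, a chunked causal position is visible when it is causal and both positions fall in the same chunk. The hypothetical `same_chunk` and `chunked_causal_visible` functions below sketch that rule. They are not the library's implementation.

```mojo
from testing import assert_true, assert_false


# Hypothetical predicate (not the mha_mask API). Assumes chunks are
# consecutive, non-overlapping blocks of `chunk_size` positions.
fn same_chunk(q: Int, k: Int, chunk_size: Int) -> Bool:
    return q // chunk_size == k // chunk_size


fn chunked_causal_visible(q: Int, k: Int, chunk_size: Int) -> Bool:
    # Visible only when causal (q >= k) and inside the same chunk.
    return q >= k and same_chunk(q, k, chunk_size)


def main():
    # Visibility pattern for 6 positions with chunk_size = 3.
    for q in range(6):
        var row = String("")
        for k in range(6):
            if chunked_causal_visible(q, k, 3):
                row += "1"
            else:
                row += "0"
        print(row)

    # Query 4 sees key 3 (same chunk) but not key 2 (previous chunk).
    assert_true(chunked_causal_visible(4, 3, 3))
    assert_false(chunked_causal_visible(4, 2, 3))
```

The printed pattern is two independent 3 x 3 lower-triangular blocks on the diagonal (100000, 110000, 111000, 000100, 000110, 000111), so no query looks back past the start of its chunk.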