
Mojo trait

MHAMask

The MHAMask trait describes masks for MHA kernels, such as the causal mask.

Implemented traits

AnyType, UnknownDestructibility

Aliases

apply_log2e_after_mask

alias apply_log2e_after_mask

Does the mask require log2e to be applied after the mask, or can it be fused with the scaling?
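One plausible reading of this alias: kernels that compute softmax with exp2 often fold log2(e) into the score scale, because exp(x) = 2^(x * log2(e)). That folding is harmless when the mask only sets scores to 0 or -inf. It is wrong when the mask adds a finite bias expressed in natural-log units, since the bias would then skip the log2(e) factor. The Python model below illustrates this numerically. It is not Mojo and not the library's kernel. The bias values and the scale are hypothetical.

```python
import math

LOG2E = math.log2(math.e)  # converts natural-log units to base-2 units


def softmax_exp(xs):
    """Reference softmax using the natural exponential."""
    m = max(xs)
    e = [math.exp(x - m) for x in xs]
    total = sum(e)
    return [v / total for v in e]


def softmax_exp2(xs):
    """Softmax using exp2; inputs must already be in base-2 units."""
    m = max(xs)
    e = [2.0 ** (x - m) for x in xs]
    total = sum(e)
    return [v / total for v in e]


raw = [1.0, 2.0, 0.5]    # raw Q.K^T scores for one query row
scale = 0.125            # e.g. 1/sqrt(head_dim) with head_dim = 64
bias = [0.0, -1.0, 0.3]  # hypothetical additive mask values

# Reference: scale, add the mask, then softmax with exp.
ref = softmax_exp([s * scale + b for s, b in zip(raw, bias)])

# log2e fused into the scale before masking: the bias stays in natural-log units.
fused = softmax_exp2([s * scale * LOG2E + b for s, b in zip(raw, bias)])

# log2e applied after masking: matches the reference.
after = softmax_exp2([(s * scale + b) * LOG2E for s, b in zip(raw, bias)])

# A pure 0/-inf mask is unaffected by the order, so fusing is safe.
hard = [0.0, float("-inf"), 0.0]
ref_hard = softmax_exp([s * scale + b for s, b in zip(raw, hard)])
fused_hard = softmax_exp2([s * scale * LOG2E + b for s, b in zip(raw, hard)])
```

Here `after` agrees with `ref`, but `fused` does not. `fused_hard` agrees with `ref_hard`.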

mask_out_of_bound

alias mask_out_of_bound

mask_safe_out_of_bounds

alias mask_safe_out_of_bounds

Is the mask safe to read out of bounds?

Methods

mask

mask[type: DType, width: Int, //, *, element_type: DType = uint32](self: _Self, coord: IndexList[4, element_type=element_type], score_vec: SIMD[type, width]) -> SIMD[type, width]

Return the mask vector at the given coordinates.

Arguments:
coord: the (seq_id, head, q_idx, k_idx) coordinate.
score_vec: the values of the score matrix at coord.
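The Python model below sketches what a causal mask's mask method could compute. It is not the Mojo API. It assumes that the lanes of score_vec cover consecutive key positions k_idx, k_idx + 1, ... for a single query row. It also assumes that masked positions are set to -inf. Both are assumptions.

```python
NEG_INF = float("-inf")


def causal_mask(coord, score_vec):
    """Model of a causal mask applied to one SIMD-width slice of scores.

    coord is (seq_id, head, q_idx, k_idx); lane i of score_vec is assumed
    to hold the score for key position k_idx + i.
    """
    seq_id, head, q_idx, k_idx = coord  # seq_id and head do not affect a causal mask
    out = []
    for lane, s in enumerate(score_vec):
        k = k_idx + lane
        # A query may attend only to keys at or before its own position.
        out.append(s if k <= q_idx else NEG_INF)
    return out


print(causal_mask((0, 0, 2, 0), [1.0, 2.0, 3.0, 4.0]))  # [1.0, 2.0, 3.0, -inf]
```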

The functor could capture a mask tensor and add it to the score, e.g. Replit.
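The sketch below models a functor that captures a mask tensor and adds it to the score. It is written in Python with a hypothetical `TensorMask` class and a nested-list tensor indexed as [seq_id][head][q_idx][k_idx]. It uses the same assumed lane layout as a contiguous run of key positions. It does not reproduce any specific library mask.

```python
class TensorMask:
    """Hypothetical functor that captures an additive mask tensor."""

    def __init__(self, mask_tensor):
        # mask_tensor[seq_id][head][q_idx][k_idx] holds an additive bias.
        self.mask_tensor = mask_tensor

    def mask(self, coord, score_vec):
        seq_id, head, q_idx, k_idx = coord
        row = self.mask_tensor[seq_id][head][q_idx]
        # Lane i is assumed to correspond to key position k_idx + i.
        return [s + row[k_idx + i] for i, s in enumerate(score_vec)]


m = TensorMask([[[[0.0, -1.0, -2.0], [0.0, 0.0, -1.0]]]])  # shape 1 x 1 x 2 x 3
print(m.mask((0, 0, 1, 1), [5.0, 5.0]))  # [5.0, 4.0]
```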

status

status[*, element_type: DType = uint32](self: _Self, tile_offset: IndexList[2, element_type=element_type], tile_size: IndexList[2, element_type=element_type]) -> TileMaskStatus

Given a tile's index range, return its masking status.
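For a causal mask, a tile's status can be decided from its corners alone, without visiting every element. The Python model below assumes that tile_offset is (q_start, k_start) and tile_size is (num_q, num_k). It also assumes a three-way status whose member names (NO_MASK, PARTIAL_MASK, FULL_MASK) are illustrative. It is a sketch, not the Mojo TileMaskStatus type.

```python
from enum import Enum


class TileMaskStatus(Enum):
    # Illustrative three-way status; member names are assumptions.
    NO_MASK = 0
    PARTIAL_MASK = 1
    FULL_MASK = 2


def causal_status(tile_offset, tile_size):
    """Classify a (query, key) tile under a causal mask."""
    q0, k0 = tile_offset
    nq, nk = tile_size
    q_last = q0 + nq - 1
    k_last = k0 + nk - 1
    if k_last <= q0:
        # Even the first query row may attend to every key in the tile.
        return TileMaskStatus.NO_MASK
    if k0 > q_last:
        # Even the last query row may attend to no key in the tile.
        return TileMaskStatus.FULL_MASK
    return TileMaskStatus.PARTIAL_MASK
```

A kernel could use such a status to skip fully masked tiles and to call mask only on partially masked ones.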