Mojo struct
SlidingWindowCausalMask
@register_passable("trivial")
struct SlidingWindowCausalMask[window_size: Int]
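Mask implementing sliding window causal attention: each query position attends to itself and to at most `window_size - 1` preceding key positions.
Consider the following case: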
- Q_len = 7
- K_len = 7
- window_size = 3
The mask is applied as follows (1 = score kept, 0 = score masked out):

    K >   0 1 2 3 4 5 6
    Q v x-------------x
      0 | 1 0 0 0 0 0 0
      1 | 1 1 0 0 0 0 0
      2 | 1 1 1 0 0 0 0
      3 | 0 1 1 1 0 0 0
      4 | 0 0 1 1 1 0 0
      5 | 0 0 0 1 1 1 0
      6 | 0 0 0 0 1 1 1
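Reading the table row by row, query position q keeps key position k exactly when q - window_size < k <= q. The `k <= q` half is the ordinary causal condition, and the `k > q - window_size` half is the sliding window. The short sketch below regenerates the 7 x 7 pattern from that rule. It uses plain Python as a language-neutral model, not the Mojo API, and `sliding_window_visible` and `build_mask` are hypothetical helper names.

```python
def sliding_window_visible(q: int, k: int, window_size: int) -> bool:
    """Return True if query position q may attend to key position k.

    Causal part: k <= q. Window part: k > q - window_size.
    """
    return q - window_size < k <= q


def build_mask(q_len: int, k_len: int, window_size: int) -> list:
    # 1 = score is kept, 0 = score is masked out
    return [
        [1 if sliding_window_visible(q, k, window_size) else 0 for k in range(k_len)]
        for q in range(q_len)
    ]


if __name__ == "__main__":
    # Print the Q_len = 7, K_len = 7, window_size = 3 example row by row.
    for q, row in enumerate(build_mask(7, 7, 3)):
        print(q, "|", " ".join(str(v) for v in row))
```

Each row contains at most `window_size` ones. Early rows contain fewer because there are not enough preceding keys to fill the window.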
Aliases
- apply_log2e_after_mask = False
- mask_out_of_bound = True
- mask_safe_out_of_bounds = True
Implemented traits
AnyType, Copyable, ExplicitlyCopyable, MHAMask, Movable, UnknownDestructibility
Methods
mask
mask[type: DType, width: Int, //, *, element_type: DType = uint32](self, coord: Index[4, element_type=element_type], score_vec: SIMD[type, width]) -> SIMD[type, width]
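The following plain-Python model illustrates what `mask` computes. It relies on three assumptions made for illustration only:

- `coord` is laid out as `(batch, head, q_idx, k_idx)`.
- The lanes of `score_vec` hold consecutive key positions starting at `k_idx`.
- Masked lanes are replaced with `-inf`.

The Mojo implementation may use a different coordinate layout or a finite large-negative masking constant. `apply_sliding_window_mask` is a hypothetical name, not part of the library.

```python
import math

# Assumption: masked scores become -inf in this model; the real kernel
# may use a different large negative constant.
MASKED = -math.inf


def apply_sliding_window_mask(coord, scores, window_size):
    """Model of mask(): scores[i] is assumed to belong to key k_idx + i."""
    _batch, _head, q_idx, k_idx = coord  # assumed coordinate layout
    out = []
    for i, s in enumerate(scores):
        k = k_idx + i
        # Keep the score only if k is causal (k <= q) and inside the window.
        visible = q_idx - window_size < k <= q_idx
        out.append(s if visible else MASKED)
    return out
```

For query 4 with a window of 3, only keys 2, 3 and 4 keep their scores: `apply_sliding_window_mask((0, 0, 4, 0), [0.5] * 7, 3)` returns `[-inf, -inf, 0.5, 0.5, 0.5, -inf, -inf]`.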
status
status[*, element_type: DType = uint32](self, tile_offset: Index[2, element_type=element_type], tile_size: Index[2, element_type=element_type]) -> TileMaskStatus
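`status` lets an attention kernel classify a whole tile of the score matrix at once, presumably so that fully masked tiles can be skipped and fully visible tiles can bypass per-element masking. The sketch below models that classification in Python. It assumes `tile_offset` and `tile_size` are ordered `(query, key)` and that the three outcomes correspond to `TileMaskStatus` values named `NO_MASK`, `PARTIAL_MASK` and `FULL_MASK`. The Python enum and `tile_status` are hypothetical stand-ins, not the Mojo API.

```python
from enum import Enum


class TileMaskStatus(Enum):
    # Assumed to mirror the three outcomes of the Mojo TileMaskStatus type.
    NO_MASK = "no_mask"       # every (q, k) in the tile is visible
    PARTIAL_MASK = "partial"  # some positions are masked, some are not
    FULL_MASK = "full"        # every position is masked


def tile_status(tile_offset, tile_size, window_size):
    q0, k0 = tile_offset      # first query row, first key column (assumed order)
    q_rows, k_cols = tile_size
    q_last = q0 + q_rows - 1
    k_last = k0 + k_cols - 1

    # Fully masked: the tile lies entirely above the causal diagonal,
    # or entirely to the left of the sliding window.
    if k0 > q_last or k_last <= q0 - window_size:
        return TileMaskStatus.FULL_MASK
    # Unmasked: the latest key is causal for the earliest query, and the
    # earliest key is inside the window of the latest query.
    if k_last <= q0 and k0 > q_last - window_size:
        return TileMaskStatus.NO_MASK
    return TileMaskStatus.PARTIAL_MASK
```

Because the visible region is a diagonal band, comparing the tile's corner coordinates against the band's two edges is enough, so no per-element loop is needed.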