Mojo struct

CausalMask

@register_passable(trivial) struct CausalMask

Multi-head attention (MHA) causal mask: ensures that each token attends only to itself and to earlier tokens in the sequence, never to later ones.
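
The rule above can be illustrated on a small score matrix. The sketch below is written in Python rather than Mojo so it runs on its own; it is not the CausalMask implementation. It assumes rows are query positions, columns are key positions, and query i and key i refer to the same token (no KV-cache offset). Masked entries are set to -inf so that softmax gives them zero weight.

```python
import math


def causal_mask(scores):
    """Return a copy of a [q][k] score matrix with future keys masked.

    Entry (q, k) is kept when k <= q (the key is the query token itself or
    an earlier token) and replaced with -inf when k > q.
    """
    return [
        [s if k <= q else -math.inf for k, s in enumerate(row)]
        for q, row in enumerate(scores)
    ]


def softmax(row):
    # Subtract the row maximum for numerical stability; exp(-inf) == 0.0.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


scores = [[1.0, 2.0, 3.0],
          [4.0, 5.0, 6.0],
          [7.0, 8.0, 9.0]]
masked = causal_mask(scores)
weights = [softmax(r) for r in masked]
print(weights[0])  # [1.0, 0.0, 0.0]: the first token can only see itself
```

The last query row is left untouched because every key position is at or before it.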

Aliases

  • apply_log2e_after_mask = False
  • mask_out_of_bound = is_nvidia_gpu()
  • mask_safe_out_of_bounds = True

Implemented traits

AnyType, Copyable, ExplicitlyCopyable, MHAMask, Movable, UnknownDestructibility

Methods

mask

mask[type: DType, width: Int, //, *, element_type: DType = uint32](self, coord: Index[4, element_type=element_type], score_vec: SIMD[type, width]) -> SIMD[type, width]
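
The signature suggests that mask operates on one SIMD vector of scores at a time. The following Python analogue is a hypothetical sketch, not the Mojo API. It assumes that coord is laid out as (batch, head, q_idx, k_idx) and that score_vec holds the scores of a single query row for the width consecutive keys starting at k_idx. It also uses -inf as the mask value, while the real kernel may use a large finite negative number instead.

```python
import math

# Assumption: -inf as the fill value; the actual kernel constant may differ.
MASK_VALUE = -math.inf


def mask(coord, score_vec):
    """Hypothetical analogue of CausalMask.mask for one score vector.

    coord is assumed to be (batch, head, q_idx, k_idx); lane i of score_vec
    is assumed to correspond to key position k_idx + i.
    """
    _batch, _head, q_idx, k_idx = coord
    return [
        s if k_idx + lane <= q_idx else MASK_VALUE
        for lane, s in enumerate(score_vec)
    ]


# Query 5 against keys 4..7: keys 4 and 5 survive, keys 6 and 7 are masked.
print(mask((0, 0, 5, 4), [0.1, 0.2, 0.3, 0.4]))
# [0.1, 0.2, -inf, -inf]
```

Because the mask depends only on the query and key positions, batch and head do not affect the result in this sketch.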

status

status[*, element_type: DType = uint32](self, tile_offset: Index[2, element_type=element_type], tile_size: Index[2, element_type=element_type]) -> TileMaskStatus
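
status classifies a whole tile of the score matrix at once. A plausible use, and a common pattern in tiled attention kernels, is to skip fully masked tiles and to avoid per-element masking in unmasked tiles. The Python sketch below is hypothetical. The enum member names NO_MASK, PARTIAL_MASK and FULL_MASK are assumptions (check the TileMaskStatus reference for the real ones), as is the (query, key) ordering of tile_offset and tile_size.

```python
from enum import Enum


class TileMaskStatus(Enum):
    # Hypothetical member names; the real TileMaskStatus may differ.
    NO_MASK = 0       # every (q, k) in the tile has k <= q
    PARTIAL_MASK = 1  # the tile straddles the diagonal
    FULL_MASK = 2     # every (q, k) in the tile has k > q


def status(tile_offset, tile_size):
    """Classify a tile under a causal mask.

    Assumes tile_offset = (q_start, k_start) and tile_size = (num_q, num_k),
    with the query axis first.
    """
    q_start, k_start = tile_offset
    num_q, num_k = tile_size
    q_end = q_start + num_q - 1  # last query row in the tile
    k_end = k_start + num_k - 1  # last key column in the tile
    if k_end <= q_start:
        # Even the largest key is at or before the smallest query.
        return TileMaskStatus.NO_MASK
    if k_start > q_end:
        # Even the smallest key is after the largest query.
        return TileMaskStatus.FULL_MASK
    return TileMaskStatus.PARTIAL_MASK


print(status((64, 0), (64, 64)))  # TileMaskStatus.NO_MASK
print(status((0, 64), (64, 64)))  # TileMaskStatus.FULL_MASK
print(status((0, 0), (64, 64)))   # TileMaskStatus.PARTIAL_MASK
```

Only tiles on the diagonal need the element-wise mask, so for long sequences most tiles fall into the two cheap cases.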
