Mojo struct
CausalPaddingMask
struct CausalPaddingMask[layout_: Layout, origin_: ImmutOrigin]
Causal mask combined with padding: a position (seq_id, head, q, k) is visible only when q >= k (causal) AND k < valid_lengths[seq_id] (padding).
valid_lengths is a tensor of shape [num_seqs] with one uint32 value per
sequence indicating the number of valid (non-padding) tokens.
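The visibility rule above can be written as a scalar predicate. This is a minimal plain-Mojo illustration of the documented condition, not the library's implementation. `is_visible` is a hypothetical helper, and `head` is left out because the rule does not depend on it.

```mojo
from testing import assert_true, assert_false


# Hypothetical helper: models the documented rule for one (seq_id, head, q, k)
# position, given valid_lengths[seq_id] as `valid_len`.
fn is_visible(q: Int, k: Int, valid_len: Int) -> Bool:
    var causal_ok = q >= k          # causal: a query never sees a later key
    var padding_ok = k < valid_len  # padding: keys at or past the valid length are hidden
    return causal_ok and padding_ok


fn main() raises:
    # Hypothetical valid lengths for two sequences.
    var valid_len_seq0 = 3
    var valid_len_seq1 = 5
    # seq 0, q=4, k=2: causal-visible and 2 < 3, so visible.
    assert_true(is_visible(4, 2, valid_len_seq0))
    # seq 0, q=4, k=3: causal-visible, but k=3 is padding in sequence 0.
    assert_false(is_visible(4, 3, valid_len_seq0))
    # seq 1, same position: visible, because sequence 1 has 5 valid tokens.
    assert_true(is_visible(4, 3, valid_len_seq1))
    # q=1, k=2: hidden by the causal rule regardless of padding.
    assert_false(is_visible(1, 2, valid_len_seq1))
```

Both conditions must hold, so a key is hidden if it is either in the future of the query or beyond the sequence's valid length.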
Fields
- valid_lengths (LayoutTensor[DType.uint32, layout_, origin_]): Number of valid (non-padding) tokens in each sequence, one value per sequence.
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
ImplicitlyDestructible,
MHAMask,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
apply_log2e_after_mask
comptime apply_log2e_after_mask = False
causal_mask
comptime causal_mask = CausalMask()
check_mask_during_decoding
comptime check_mask_during_decoding = True
device_type
comptime device_type = CausalPaddingMask[layout_, origin_]
mask_out_of_bound
comptime mask_out_of_bound = is_nvidia_gpu()
mask_safe_out_of_bounds
comptime mask_safe_out_of_bounds = True
Methods
__init__
__init__(valid_lengths: LayoutTensor[DType.uint32, layout_, origin_]) -> Self
get_type_name
name
mask
mask[dtype: DType, width: Int, //, *, element_type: DType = DType.uint32](self, coord: IndexList[4, element_type=element_type], score_vec: SIMD[dtype, width]) -> SIMD[dtype, width]
Returns:
SIMD[dtype, width]
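As a rough model of what `mask` computes, the sketch below assumes that `coord` holds (seq_id, head, q, k) for the first lane of `score_vec` and that the `width` lanes cover consecutive key positions k, k + 1, and so on. Both are assumptions, as is the fill value used for hidden positions, since this page documents only the signature. `apply_causal_padding` is a hypothetical helper that takes the query index, the starting key index, and the sequence's valid length directly.

```mojo
from testing import assert_true


fn apply_causal_padding[width: Int, //](
    q: Int,
    k_start: Int,
    valid_len: Int,
    scores: SIMD[DType.float32, width],
) -> SIMD[DType.float32, width]:
    # Hypothetical fill value for hidden positions; the real mask's value is
    # not documented on this page.
    var masked = Float32(-1.0e30)
    var masked_scores = scores
    for i in range(width):
        # Assumption: lane i corresponds to key position k_start + i.
        var k = k_start + i
        if not (q >= k and k < valid_len):
            masked_scores[i] = masked
    return masked_scores


fn main() raises:
    var scores = SIMD[DType.float32, 4](1.0, 2.0, 3.0, 4.0)
    # q=3, keys 2..5: keys 4 and 5 are in the future, so the causal rule hides them.
    var causal = apply_causal_padding(3, 2, 4, scores)
    assert_true(causal[1] == 2.0)
    assert_true(causal[2] == Float32(-1.0e30))
    # q=10, keys 2..5, valid length 4: keys 4 and 5 are hidden as padding.
    var padded = apply_causal_padding(10, 2, 4, scores)
    assert_true(padded[0] == 1.0)
    assert_true(padded[3] == Float32(-1.0e30))
```

Working lane by lane keeps the sketch close to the per-position rule; a real kernel would more likely build the comparison as a vector select.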
status
status[*, element_type: DType = DType.uint32](self, tile_offset: IndexList[2, element_type=element_type], tile_size: IndexList[2, element_type=element_type]) -> TileMaskStatus
Returns:
TileMaskStatus
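`status` classifies a whole tile at once instead of a single position. Under the documented rule, a tile of query rows [q0, q0 + BM) and key columns [k0, k0 + BN) is fully masked when its first key lies after its last query or at or past the valid length. It is fully visible when its last key is at or before its first query and still inside the valid length. The sketch below is hypothetical: it assumes tile_offset is (q0, k0) and tile_size is (BM, BN), returns strings instead of TileMaskStatus values, and takes a valid length explicitly, because the real `status` signature has no sequence index and this page does not say how padding enters its result.

```mojo
from testing import assert_equal


# Hypothetical tile classifier for the causal-plus-padding rule.
fn tile_status(q0: Int, k0: Int, bm: Int, bn: Int, valid_len: Int) -> String:
    var q_last = q0 + bm - 1
    var k_last = k0 + bn - 1
    # No (q, k) pair can satisfy q >= k and k < valid_len.
    if k0 > q_last or k0 >= valid_len:
        return "full_mask"
    # Every (q, k) pair satisfies q >= k and k < valid_len.
    if k_last <= q0 and k_last < valid_len:
        return "no_mask"
    # Some positions are visible and some are hidden.
    return "partial_mask"


fn main() raises:
    # Diagonal tile: straddles the causal boundary.
    assert_equal(tile_status(0, 0, 4, 4, 100), "partial_mask")
    # Tile entirely below the diagonal and inside the valid length.
    assert_equal(tile_status(8, 0, 4, 4, 100), "no_mask")
    # Tile entirely above the diagonal.
    assert_equal(tile_status(0, 8, 4, 4, 100), "full_mask")
    # Tile whose keys all fall in the padding region.
    assert_equal(tile_status(8, 4, 4, 4, 3), "full_mask")
```

Classifying tiles this way lets a kernel skip fully masked tiles and avoid per-element masking on fully visible ones, which is the usual reason attention masks expose a tile-level status.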
start_column
total_iters
total_iters[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> UInt32
Returns:
UInt32
count_nonfull_sets
last_masked_set_end
last_masked_set_end[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> UInt32
Returns:
UInt32
masked_set_ends
masked_set_ends[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> StaticTuple[UInt32, CausalPaddingMask.count_nonfull_sets(BM, BN)]
Returns:
StaticTuple[UInt32, CausalPaddingMask.count_nonfull_sets(BM, BN)]
nonfull_sets
static nonfull_sets[BM: Int, BN: Int]() -> StaticTuple[TileMaskStatus, CausalPaddingMask.count_nonfull_sets(BM, BN)]
Returns:
StaticTuple[TileMaskStatus, CausalPaddingMask.count_nonfull_sets(BM, BN)]
mask_strategies
static mask_strategies[BM: Int, BN: Int]() -> StaticTuple[MaskStrategy, CausalPaddingMask.count_nonfull_sets(BM, BN)]
Returns:
StaticTuple[MaskStrategy, CausalPaddingMask.count_nonfull_sets(BM, BN)]