Mojo function

peel_mask

peel_mask[num_sets: Int, //, mask_strategies: StaticTuple[MaskStrategy, num_sets], load_fn: def[mask_strategy: MaskStrategy](UInt32) capturing -> Float32](mut mask_iters: StaticTuple[UInt32, num_sets], kv_row: UInt32) -> Float32

Determine which mask strategy applies to the peeled first iteration.

Walks the mask sets in order to find the first one with remaining iterations, calls load_fn with that set's strategy, and decrements that set's counter. Because empty sets are skipped rather than decremented, this prevents UInt32 underflow when the early sets are empty (for example, SlidingWindowCausalMask with num_sets=3 on short sequences).
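The walk described above can be modeled in a few lines. The sketch below is written in Python rather than Mojo so that it runs anywhere. It is a hypothetical model of the documented behavior, not the library source. It makes several assumptions:

- Plain strings stand in for MaskStrategy values.
- load_fn receives the strategy as a runtime argument instead of a compile-time parameter.
- UInt32 wraparound is emulated with a 32-bit mask.
- Raising ValueError when every set is exhausted is a modeling choice, because the documentation does not say what happens in that case.

```python
from typing import Callable, List, Sequence

U32_MAX = 0xFFFFFFFF  # largest UInt32 value; used to emulate wraparound


def naive_decrement(count: int) -> int:
    """What an unguarded UInt32 decrement would do: 0 - 1 wraps to U32_MAX."""
    return (count - 1) & U32_MAX


def peel_mask(
    mask_strategies: Sequence[str],        # hypothetical stand-ins for MaskStrategy values
    load_fn: Callable[[str, int], float],  # strategy passed at runtime (Mojo: compile-time parameter)
    mask_iters: List[int],                 # per-set remaining iteration counts, mutated in place
    kv_row: int,
) -> float:
    """Python model of peel_mask: use the first mask set that still has iterations."""
    if len(mask_strategies) != len(mask_iters):
        raise ValueError("need one iteration counter per mask strategy")
    for i, strategy in enumerate(mask_strategies):
        if mask_iters[i] > 0:
            # Only a non-empty counter is decremented, so it can never wrap below 0.
            mask_iters[i] -= 1
            return load_fn(strategy, kv_row)
    # Assumption for this sketch: the caller guarantees at least one iteration remains.
    raise ValueError("no mask set has remaining iterations")


if __name__ == "__main__":
    strategies = ["set0", "set1", "set2"]  # hypothetical labels for num_sets=3
    iters = [0, 2, 5]                      # first set empty, as with a short sequence
    calls = []

    def load_fn(strategy: str, kv_row: int) -> float:
        calls.append((strategy, kv_row))
        return float(kv_row)

    value = peel_mask(strategies, load_fn, iters, kv_row=7)
    print(value, iters, calls)    # 7.0 [0, 1, 5] [('set1', 7)]
    print(naive_decrement(0))     # 4294967295: the underflow the guard avoids
```

In the usage example, set0 is skipped because its counter is already zero. Only set1's counter changes. Decrementing set0's counter unconditionally would instead leave it at 4294967295.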

Returns:

Float32: the value returned by load_fn for the selected mask strategy.

Was this page helpful?