Mojo function
peel_mask
peel_mask[num_sets: Int, //, mask_strategies: StaticTuple[MaskStrategy, num_sets], load_fn: def[mask_strategy: MaskStrategy](UInt32) capturing -> Float32](mut mask_iters: StaticTuple[UInt32, num_sets], kv_row: UInt32) -> Float32
Determine which mask strategy applies to the peeled first iteration.
Walks through mask sets to find the first with remaining iterations, calls load_fn with the corresponding strategy, and decrements the counter. Prevents UInt32 underflow when early sets are empty (e.g. SlidingWindowCausalMask with num_sets=3 and small sequences).
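The walk described above can be modeled in plain Python. This is a minimal sketch of the described behavior, not the Mojo implementation: `peel_mask_model`, the string strategy names, and the runtime `load_fn(strategy, kv_row)` call shape are hypothetical stand-ins (in Mojo the strategy is a compile-time parameter of `load_fn`). Returning the value of `load_fn` and raising when every set is empty are assumptions; the source does not say what happens in that case.

```python
from typing import Callable, List, Sequence


def peel_mask_model(
    mask_strategies: Sequence[str],
    load_fn: Callable[[str, int], float],
    mask_iters: List[int],
    kv_row: int,
) -> float:
    """Model of the peeled first iteration (hypothetical helper)."""
    for i, strategy in enumerate(mask_strategies):
        # Skip sets with no iterations left. Decrementing a zero counter
        # is what would underflow with an unsigned type such as UInt32.
        if mask_iters[i] > 0:
            value = load_fn(strategy, kv_row)
            mask_iters[i] -= 1  # mutated in place, like `mut mask_iters`
            return value
    # Assumption: the source does not describe the all-empty case.
    raise ValueError("no mask set has remaining iterations")


# Example: the first set is empty, so the second strategy is selected.
iters = [0, 2, 1]
calls = []


def record_load(strategy: str, kv_row: int) -> float:
    calls.append((strategy, kv_row))
    return 0.0


result = peel_mask_model(["set0", "set1", "set2"], record_load, iters, 0)
```

After this call, only the second counter has been decremented and the empty first set is left untouched, which is the underflow case the description mentions.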
Returns:
Float32