Mojo struct
SlidingWindowNonCausalMask
struct SlidingWindowNonCausalMask[window_size: Int]
Non-causal sliding-window attention mask.
A (q, k) pair is visible iff k + window_size > q: each query sees the
window_size - 1 keys immediately before it, its own position, and every
later key. Unlike SlidingWindowCausalMask, there is no causal upper bound,
so future keys (k > q) are always visible: a windowed context plus a
bidirectional block. Used by windowed block-diffusion speculative-decode
drafts (DFlash).
Example with Q_len = K_len = 7, window_size = 3 (upper triangle all 1s,
unlike SlidingWindowCausalMask):
K >   0 1 2 3 4 5 6
Q v x---------------x
0   | 1 1 1 1 1 1 1
1   | 1 1 1 1 1 1 1
2   | 1 1 1 1 1 1 1
3   | 0 1 1 1 1 1 1
4   | 0 0 1 1 1 1 1
5   | 0 0 0 1 1 1 1
6   | 0 0 0 0 1 1 1
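The table above can be reproduced from the visibility rule alone. The following is a minimal plain-Python reference model, not the Mojo implementation; `is_visible` and `mask_matrix` are hypothetical helper names introduced only for illustration.

```python
def is_visible(q: int, k: int, window_size: int) -> bool:
    """Rule from the description: key k is visible to query q iff k + window_size > q."""
    return k + window_size > q


def mask_matrix(q_len: int, k_len: int, window_size: int) -> list[list[int]]:
    """Build the 0/1 visibility matrix, one row per query position."""
    return [
        [int(is_visible(q, k, window_size)) for k in range(k_len)]
        for q in range(q_len)
    ]


if __name__ == "__main__":
    # Print the Q_len = K_len = 7, window_size = 3 example row by row.
    for q, row in enumerate(mask_matrix(7, 7, 3)):
        print(q, "|", " ".join(str(v) for v in row))
```

Because the rule has no upper bound on k, every entry on or above the diagonal is 1 regardless of window_size.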
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
ImplicitlyDeletable,
MHAMask,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
apply_log2e_after_mask
comptime apply_log2e_after_mask = False
check_mask_during_decoding
comptime check_mask_during_decoding = True
device_type
comptime device_type = SlidingWindowNonCausalMask[window_size]
mask_out_of_bound
comptime mask_out_of_bound = True
mask_safe_out_of_bounds
comptime mask_safe_out_of_bounds = True
Methods
get_type_name
name
mask
def mask[dtype: DType, width: SIMDSize, *, element_type: DType = DType.uint32](self, coord: IndexList[Int(4), element_type=element_type], score_vec: SIMD[dtype, width]) -> SIMD[dtype, width]
Returns:
SIMD[dtype, width]
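As a rough mental model of what a per-vector mask application does, the sketch below applies the visibility rule to a short run of scores. It is plain Python, not the Mojo method, and it rests on several assumptions not stated on this page: that the vector lanes correspond to consecutive key positions starting at a given column, and that masked lanes are replaced with negative infinity (the real kernel may use a different fill value). `apply_mask_model` and its parameters are hypothetical.

```python
import math


def apply_mask_model(
    q_idx: int, k_start: int, scores: list[float], window_size: int
) -> list[float]:
    """Hypothetical model: lane i holds the score for key k_start + i.

    Visible lanes keep their score; masked lanes become -inf (assumed fill value).
    """
    return [
        s if (k_start + i) + window_size > q_idx else -math.inf
        for i, s in enumerate(scores)
    ]


if __name__ == "__main__":
    # Query 5 with window_size 3 sees keys 3 and later.
    print(apply_mask_model(5, 0, [1.0, 1.0, 1.0, 1.0], 3))
```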
status
def status[*, element_type: DType = DType.uint32](self, seq_id: UInt32, tile_offset: IndexList[Int(2), element_type=element_type], tile_size: IndexList[Int(2), element_type=element_type]) -> TileMaskStatus
Returns:
TileMaskStatus
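A tile-level status can be derived from the visibility rule by checking only two corner pairs of the tile. The sketch below is a plain-Python model under stated assumptions: the `TileStatus` enum is a hypothetical stand-in for TileMaskStatus (its member names are assumed, not taken from this page), and the tile is described by an assumed (row, column) offset and (rows, columns) size.

```python
from enum import Enum


class TileStatus(Enum):
    """Hypothetical stand-in for TileMaskStatus; member names are assumptions."""
    NO_MASK = 0       # every (q, k) pair in the tile is visible
    PARTIAL_MASK = 1  # some pairs visible, some masked
    FULL_MASK = 2     # no pair in the tile is visible


def tile_status_model(
    q0: int, k0: int, bm: int, bn: int, window_size: int
) -> TileStatus:
    """Classify the tile covering rows [q0, q0 + bm) and columns [k0, k0 + bn)."""
    q_min, q_max = q0, q0 + bm - 1
    k_min, k_max = k0, k0 + bn - 1
    # Hardest pair to satisfy: smallest key against largest query.
    if k_min + window_size > q_max:
        return TileStatus.NO_MASK
    # Easiest pair to satisfy: largest key against smallest query.
    if k_max + window_size <= q_min:
        return TileStatus.FULL_MASK
    return TileStatus.PARTIAL_MASK


if __name__ == "__main__":
    print(tile_status_model(0, 0, 2, 2, 3))
    print(tile_status_model(2, 0, 4, 2, 3))
    print(tile_status_model(4, 0, 2, 1, 3))
```

Checking corners is sufficient because visibility is monotone: it can only improve as k grows or q shrinks.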
start_column
def start_column[BM: Int, BN: Int, page_size: Int](self, seq_id: UInt32, row: UInt32) -> UInt32
Returns:
UInt32
start_column_alignment
total_iters
def total_iters[BM: Int, BN: Int, page_size: Int](self, seq_id: UInt32, row: UInt32, num_cols: UInt32) -> UInt32
Returns:
UInt32
count_nonfull_sets
last_masked_set_end
def last_masked_set_end[BM: Int, BN: Int, page_size: Int](self, seq_id: UInt32, row: UInt32, num_cols: UInt32) -> UInt32
Returns:
UInt32
masked_set_ends
def masked_set_ends[BM: Int, BN: Int, page_size: Int](self, seq_id: UInt32, row: UInt32, num_cols: UInt32) -> StaticTuple[UInt32, SlidingWindowNonCausalMask.count_nonfull_sets(BM, BN)]
Returns:
StaticTuple[UInt32, SlidingWindowNonCausalMask.count_nonfull_sets(BM, BN)]
nonfull_sets
static def nonfull_sets[BM: Int, BN: Int]() -> StaticTuple[TileMaskStatus, SlidingWindowNonCausalMask.count_nonfull_sets(BM, BN)]
Returns:
StaticTuple[TileMaskStatus, SlidingWindowNonCausalMask.count_nonfull_sets(BM, BN)]
mask_strategies
static def mask_strategies[BM: Int, BN: Int]() -> StaticTuple[MaskStrategy, SlidingWindowNonCausalMask.count_nonfull_sets(BM, BN)]
Returns:
StaticTuple[MaskStrategy, SlidingWindowNonCausalMask.count_nonfull_sets(BM, BN)]
mask_bits
def mask_bits(self, seq_id: UInt32, score_row: Int32, col_start: Int32, num_keys: Int32) -> UInt32
Returns:
UInt32
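One plausible reading of a packed row mask is sketched below in plain Python. The bit convention is an assumption, not something this page documents: here bit i is set when key col_start + i is visible to score_row, and at most 32 keys are packed to fit a 32-bit result. The actual method may use the opposite polarity or handle bounds differently. `mask_bits_model` is a hypothetical name.

```python
def mask_bits_model(
    score_row: int, col_start: int, num_keys: int, window_size: int
) -> int:
    """Pack visibility of keys col_start .. col_start + num_keys - 1 into an int.

    Assumed convention: bit i set means key col_start + i is visible.
    """
    bits = 0
    for i in range(min(num_keys, 32)):  # a UInt32 can hold at most 32 flags
        if (col_start + i) + window_size > score_row:
            bits |= 1 << i
    return bits


if __name__ == "__main__":
    # Row 5 of the window_size = 3 example: keys 3..6 are visible.
    print(bin(mask_bits_model(5, 0, 7, 3)))
```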
sliding_window_size