IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

CausalPaddingMask

struct CausalPaddingMask[layout_: Layout, origin_: ImmutOrigin]

Causal mask combined with padding: a position (seq_id, head, q, k) is visible only when q >= k (causal) AND k < valid_lengths[seq_id] (padding).

valid_lengths is a tensor of shape [num_seqs] with one uint32 value per sequence indicating the number of valid (non-padding) tokens.
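The visibility rule above can be checked with a small plain-Python model. This is not the Mojo API. `is_visible` and `visibility_grid` are hypothetical helpers, and a Python list stands in for the `valid_lengths` tensor.

```python
# Plain-Python reference model of the CausalPaddingMask visibility rule.
# Illustration only: these helpers are hypothetical, not part of the Mojo API.

def is_visible(valid_lengths, seq_id, head, q, k):
    """Return True when query q may attend to key k in sequence seq_id.

    `head` mirrors the (seq_id, head, q, k) coordinate but does not affect
    the result, because the rule is the same for every head.
    """
    causal_ok = q >= k                       # no attending to future keys
    padding_ok = k < valid_lengths[seq_id]   # no attending to padding keys
    return causal_ok and padding_ok


def visibility_grid(valid_lengths, seq_id, seq_len):
    """Build a seq_len x seq_len grid of 1 (visible) / 0 (masked)."""
    return [
        [1 if is_visible(valid_lengths, seq_id, 0, q, k) else 0
         for k in range(seq_len)]
        for q in range(seq_len)
    ]


if __name__ == "__main__":
    valid_lengths = [3, 5]  # sequence 0 has 3 real tokens, sequence 1 has 5
    for row in visibility_grid(valid_lengths, seq_id=0, seq_len=5):
        print(row)
```

For sequence 0 this prints a lower-triangular grid whose columns 3 and 4 are all zero, because keys 3 and 4 are padding. The rule as stated constrains only the key index against `valid_lengths`, so query rows at padding positions (q = 3 and q = 4) still see the three valid keys.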

Fields

  • valid_lengths (LayoutTensor[DType.uint32, layout_, origin_]): Number of valid (non-padding) tokens in each sequence, one value per sequence (shape [num_seqs]).

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDeletable, MHAMask, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

apply_log2e_after_mask

comptime apply_log2e_after_mask = False

causal_mask

comptime causal_mask = CausalMask()

check_mask_during_decoding

comptime check_mask_during_decoding = True

device_type

comptime device_type = CausalPaddingMask[layout_, origin_]

mask_out_of_bound

comptime mask_out_of_bound = is_nvidia_gpu()

mask_safe_out_of_bounds

comptime mask_safe_out_of_bounds = True

Methods

__init__

def __init__(valid_lengths: LayoutTensor[DType.uint32, layout_, origin_]) -> Self

get_type_name

static def get_type_name() -> String

Returns:

String

name

static def name() -> String

Returns:

String

mask

def mask[dtype: DType, width: SIMDSize, //, *, element_type: DType = DType.uint32](self, coord: IndexList[Int(4), element_type=element_type], score_vec: SIMD[dtype, width]) -> SIMD[dtype, width]

Returns:

SIMD[dtype, width]
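`mask` applies the same rule to a SIMD vector of attention scores. The plain-Python sketch below models that under two assumptions this reference does not state: `coord` is `(seq_id, head, q, k)`, with the vector covering the `width` consecutive keys starting at `k`, and masked lanes are overwritten with negative infinity so that a later softmax gives them zero weight. `mask_scores` is a hypothetical name, and a Python list stands in for the SIMD vector.

```python
# Hypothetical plain-Python model of CausalPaddingMask.mask.
# Assumptions (not confirmed by the reference):
#   * coord is (seq_id, head, q, k) and lane i of the vector is key k + i
#   * masked lanes become -inf so a subsequent softmax ignores them
MASKED_VALUE = float("-inf")


def mask_scores(valid_lengths, coord, score_vec):
    """Return a copy of score_vec with invisible lanes set to -inf."""
    seq_id, _head, q, k0 = coord
    out = []
    for lane, score in enumerate(score_vec):
        k = k0 + lane
        visible = q >= k and k < valid_lengths[seq_id]  # causal AND padding
        out.append(score if visible else MASKED_VALUE)
    return out


if __name__ == "__main__":
    # Query 2 of a sequence with 3 valid tokens, keys 0..3
    print(mask_scores([3], (0, 0, 2, 0), [0.5, 1.0, 2.0, 3.0]))
```

Here key 3 is masked twice over: it lies in the future of query 2 and it is a padding position.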

status

def status[*, element_type: DType = DType.uint32](self, seq_id: UInt32, tile_offset: IndexList[Int(2), element_type=element_type], tile_size: IndexList[Int(2), element_type=element_type]) -> TileMaskStatus

Returns:

TileMaskStatus
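`status` lets a kernel skip or specialize whole tiles instead of masking score by score. The plain-Python sketch below shows how a causal-plus-padding tile could be classified. It assumes `tile_offset` is `(first query row, first key column)` and `tile_size` is `(rows, columns)`. The `TileStatus` enum and its member names are illustrative stand-ins for `TileMaskStatus`, not its actual definition, and `tile_status` is a hypothetical name.

```python
from enum import Enum


class TileStatus(Enum):
    """Illustrative stand-in for TileMaskStatus (member names assumed)."""
    FULL_MASK = "full"        # every score in the tile is masked
    PARTIAL_MASK = "partial"  # some scores are masked, some are visible
    NO_MASK = "none"          # every score in the tile is visible


def tile_status(valid_lengths, seq_id, tile_offset, tile_size):
    """Classify a (rows x cols) tile of scores under causal + padding."""
    q0, k0 = tile_offset
    bm, bn = tile_size
    q_last = q0 + bm - 1
    k_last = k0 + bn - 1
    valid_len = valid_lengths[seq_id]
    # Bottom-left cell (q_last, k0) is the most permissive cell of the tile:
    # if even it is masked, every cell is masked.
    if k0 > q_last or k0 >= valid_len:
        return TileStatus.FULL_MASK
    # Top-right cell (q0, k_last) is the most restrictive cell of the tile:
    # if even it is visible, every cell is visible.
    if k_last <= q0 and k_last < valid_len:
        return TileStatus.NO_MASK
    return TileStatus.PARTIAL_MASK
```

Only two corner cells need checking. Both conditions (`q >= k` and `k < valid_len`) get easier to satisfy as `q` grows and `k` shrinks, so those two corners bound every other cell in the tile.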

start_column

def start_column[BM: Int, BN: Int, page_size: Int](self, seq_id: UInt32, row: UInt32) -> UInt32

Returns:

UInt32

start_column_alignment

static def start_column_alignment[BM: Int, BN: Int, page_size: Int]() -> Int

Returns:

Int

total_iters

def total_iters[BM: Int, BN: Int, page_size: Int](self, seq_id: UInt32, row: UInt32, num_cols: UInt32) -> UInt32

Returns:

UInt32
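One plausible reading of `total_iters` is the number of `BN`-wide key tiles that a `BM`-row query tile starting at `row` must visit. Under the causal rule, no row in that tile sees keys at or past `row + BM`. Under the padding rule, no row sees keys at or past `valid_lengths[seq_id]`. The plain-Python sketch below computes that bound. It ignores `page_size`, assumes iteration starts at column 0, and `key_tile_iters` is a hypothetical name. Treat it as an assumption about the method's behavior, not its implementation.

```python
def key_tile_iters(valid_lengths, seq_id, row, num_cols, bm, bn):
    """Hypothetical bound on key tiles visited by one query tile.

    Assumes iteration starts at key column 0 and ignores page_size.
    """
    causal_end = row + bm               # keys >= row + bm are future for all rows
    padding_end = valid_lengths[seq_id] # keys >= valid length are padding
    end = min(causal_end, padding_end, num_cols)
    if end <= 0:
        return 0
    return (end + bn - 1) // bn         # ceiling division: a partial tile counts
```

With `valid_lengths = [100]` and `BM = BN = 64`, the tile at `row = 0` needs one key tile. The tiles at `row = 64` and `row = 128` each need two, because padding caps both at key 100.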

count_nonfull_sets

static def count_nonfull_sets(BM: Int, BN: Int) -> Int

Returns:

Int

last_masked_set_end

def last_masked_set_end[BM: Int, BN: Int, page_size: Int](self, seq_id: UInt32, row: UInt32, num_cols: UInt32) -> UInt32

Returns:

UInt32

masked_set_ends

def masked_set_ends[BM: Int, BN: Int, page_size: Int](self, seq_id: UInt32, row: UInt32, num_cols: UInt32) -> StaticTuple[UInt32, CausalPaddingMask.count_nonfull_sets(BM, BN)]

Returns:

StaticTuple[UInt32, CausalPaddingMask.count_nonfull_sets(BM, BN)]

nonfull_sets

static def nonfull_sets[BM: Int, BN: Int]() -> StaticTuple[TileMaskStatus, CausalPaddingMask.count_nonfull_sets(BM, BN)]

Returns:

StaticTuple[TileMaskStatus, CausalPaddingMask.count_nonfull_sets(BM, BN)]

mask_strategies

static def mask_strategies[BM: Int, BN: Int]() -> StaticTuple[MaskStrategy, CausalPaddingMask.count_nonfull_sets(BM, BN)]

Returns:

StaticTuple[MaskStrategy, CausalPaddingMask.count_nonfull_sets(BM, BN)]

mask_bits

def mask_bits(self, seq_id: UInt32, score_row: Int32, col_start: Int32, num_keys: Int32) -> UInt32

Returns:

UInt32
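`mask_bits` returns a `UInt32`, which suggests one bit per key column for a single score row. The plain-Python sketch below rests on three assumptions: bit `i` corresponds to key `col_start + i`, a set bit means the key is visible, and columns at or past `num_keys` count as out of bounds. The reference above confirms none of these conventions, and `mask_bits_model` is a hypothetical name.

```python
def mask_bits_model(valid_lengths, seq_id, score_row, col_start, num_keys):
    """Hypothetical 32-bit visibility mask for one score row.

    Assumptions: bit i <-> key col_start + i, set bit = visible,
    keys >= num_keys are out of bounds.
    """
    bits = 0
    for i in range(32):  # one bit per key in a 32-bit word
        k = col_start + i
        visible = (
            k < num_keys                     # in bounds
            and score_row >= k               # causal
            and k < valid_lengths[seq_id]    # not padding
        )
        if visible:
            bits |= 1 << i
    return bits
```

For example, row 5 of a sequence with 100 valid tokens, starting at column 0, yields `0b111111`: keys 0 through 5 are visible and every later key lies in the future.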

sliding_window_size

static def sliding_window_size() -> Int

Returns:

Int