For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).
Mojo struct
OrMask
struct OrMask[T: MHAMask, S: MHAMask, //, lhs: T, rhs: S]
Mask that is the logical OR of two masks: an element is masked off if either `lhs` or `rhs` masks it off.
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
ImplicitlyDeletable,
MHAMask,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
apply_log2e_after_mask
comptime apply_log2e_after_mask = T.apply_log2e_after_mask or S.apply_log2e_after_mask
check_mask_during_decoding
comptime check_mask_during_decoding = T.check_mask_during_decoding or S.check_mask_during_decoding
device_type
comptime device_type = OrMask[lhs, rhs]
mask_out_of_bound
comptime mask_out_of_bound = T.mask_out_of_bound and S.mask_out_of_bound
mask_safe_out_of_bounds
comptime mask_safe_out_of_bounds = T.mask_safe_out_of_bounds and S.mask_safe_out_of_bounds
Methods
get_type_name
name
mask
def mask[dtype: DType, width: SIMDSize, //, *, element_type: DType = DType.uint32](self, coord: IndexList[Int(4), element_type=element_type], score_vec: SIMD[dtype, width]) -> SIMD[dtype, width]
Returns:
SIMD[dtype, width]
status
def status[*, element_type: DType = DType.uint32](self, seq_id: UInt32, tile_offset: IndexList[Int(2), element_type=element_type], tile_size: IndexList[Int(2), element_type=element_type]) -> TileMaskStatus
Returns:
TileMaskStatus
start_column
def start_column[BM: Int, BN: Int, page_size: Int](self, seq_id: UInt32, row: UInt32) -> UInt32
Returns:
UInt32
start_column_alignment
total_iters
def total_iters[BM: Int, BN: Int, page_size: Int](self, seq_id: UInt32, row: UInt32, num_cols: UInt32) -> UInt32
Returns:
UInt32
count_nonfull_sets
last_masked_set_end
def last_masked_set_end[BM: Int, BN: Int, page_size: Int](self, seq_id: UInt32, row: UInt32, num_cols: UInt32) -> UInt32
Returns:
UInt32
masked_set_ends
def masked_set_ends[BM: Int, BN: Int, page_size: Int](self, seq_id: UInt32, row: UInt32, num_cols: UInt32) -> StaticTuple[UInt32, OrMask.count_nonfull_sets(BM, BN)]
Returns:
StaticTuple[UInt32, OrMask.count_nonfull_sets(BM, BN)]
nonfull_sets
static def nonfull_sets[BM: Int, BN: Int]() -> StaticTuple[TileMaskStatus, OrMask.count_nonfull_sets(BM, BN)]
Returns:
StaticTuple[TileMaskStatus, OrMask.count_nonfull_sets(BM, BN)]
mask_strategies
static def mask_strategies[BM: Int, BN: Int]() -> StaticTuple[MaskStrategy, OrMask.count_nonfull_sets(BM, BN)]
Returns:
StaticTuple[MaskStrategy, OrMask.count_nonfull_sets(BM, BN)]
mask_bits
def mask_bits(self, seq_id: UInt32, score_row: Int32, col_start: Int32, num_keys: Int32) -> UInt32
Returns:
UInt32
sliding_window_size
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!