
Mojo struct

MaterializedMask

struct MaterializedMask[dtype_: DType, layout_: Layout, origin_: ImmutOrigin]

An attention mask (implementing MHAMask) whose values are read from a materialized tensor rather than computed from a formula.

Fields

  • mask_tensor (LayoutTensor[dtype_, layout_, origin_]): The tensor that holds the materialized mask values.
  • start_pos (OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), ImmutAnyOrigin]]): Optional tensor of per-batch start positions.
  • is_multiple_of_2 (Bool)

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDestructible, MHAMask, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

apply_log2e_after_mask

comptime apply_log2e_after_mask = True

check_mask_during_decoding

comptime check_mask_during_decoding = True

device_type

comptime device_type = MaterializedMask[dtype_, layout_, origin_]

mask_out_of_bound

comptime mask_out_of_bound = True

mask_safe_out_of_bounds

comptime mask_safe_out_of_bounds = False

Methods

__init__

__init__(mask_tensor: LayoutTensor[dtype_, layout_, origin_], start_pos: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), ImmutAnyOrigin]] = None) -> Self

get_type_name

static get_type_name() -> String

Returns:

String

name

static name() -> String

Returns:

String

get_start_pos

get_start_pos(self, batch_idx: Int) -> Int

Returns:

Int
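The following is a minimal Python model of the per-batch lookup that get_start_pos implies. It is not the Mojo API. It assumes that start_pos, when present, behaves like a 1-D array with one entry per batch element. This page does not document the fallback used when start_pos is None, so the sketch uses a hypothetical `default_start` parameter instead of guessing.

```python
from typing import Optional, Sequence


def get_start_pos(
    start_pos: Optional[Sequence[int]],
    batch_idx: int,
    default_start: int = 0,  # hypothetical fallback; not taken from the Mojo source
) -> int:
    """Illustrative model of MaterializedMask.get_start_pos."""
    if start_pos is not None:
        # One start position per batch element, indexed by batch_idx.
        return int(start_pos[batch_idx])
    # No per-batch tensor was supplied: use the caller-chosen fallback.
    return default_start
```

For example, `get_start_pos([0, 5, 12], 1)` returns 5. The Optional argument mirrors the `OptionalReg[...] = None` default in `__init__`.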

mask

mask[dtype: DType, width: Int, //, *, element_type: DType = DType.uint32](self, coord: IndexList[4, element_type=element_type], score_vec: SIMD[dtype, width]) -> SIMD[dtype, width]

Returns:

SIMD[dtype, width]
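The next sketch shows, in Python, one plausible reading of `mask(coord, score_vec)`. It rests on several assumptions:

* coord is (batch, head, query, key) for the first score.
* score_vec holds `width` scores for consecutive keys.
* The mask tensor is rank 4, indexed [batch][head][row][col], and stores additive values (0 keeps a score, -inf removes it).
* The query row is shifted by the batch's start position.

None of these details are confirmed by this page. Out-of-bounds handling (see mask_out_of_bound) is deliberately omitted.

```python
from typing import List, Sequence


def apply_materialized_mask(
    mask: Sequence[Sequence[Sequence[Sequence[float]]]],  # [batch][head][row][col]
    coord: Sequence[int],   # (batch, head, query, key) of the first score
    score_vec: List[float],  # `width` scores for consecutive keys
    start_pos: int = 0,      # hypothetical per-batch row offset
) -> List[float]:
    """Illustrative model: add stored mask values to a vector of scores."""
    b, h, q, k = coord
    # Map the absolute query position to a row of the materialized mask.
    row = mask[b][h][q - start_pos]
    # An additive mask: 0.0 leaves a score unchanged, -inf suppresses it.
    return [s + row[k + i] for i, s in enumerate(score_vec)]
```

Returning a new vector of the same width matches the `-> SIMD[dtype, width]` signature, where the masked scores replace the input scores.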

status

status[*, element_type: DType = DType.uint32](self, tile_offset: IndexList[2, element_type=element_type], tile_size: IndexList[2, element_type=element_type]) -> TileMaskStatus

Returns:

TileMaskStatus
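The following Python sketch shows how a tile status could be derived by scanning a materialized 2-D mask. It assumes tile_offset and tile_size are (row, col) pairs and that mask values are additive (0 or -inf). `TileMaskStatus` here is a hypothetical Python stand-in that mirrors the Mojo type's name. This page does not say whether MaterializedMask.status scans the tensor like this or simply returns a conservative status.

```python
from enum import Enum
from typing import Sequence, Tuple


class TileMaskStatus(Enum):  # hypothetical stand-in, not the Mojo type
    NO_MASK = 0
    PARTIAL_MASK = 1
    FULL_MASK = 2


def tile_status(
    mask: Sequence[Sequence[float]],
    tile_offset: Tuple[int, int],
    tile_size: Tuple[int, int],
) -> TileMaskStatus:
    r0, c0 = tile_offset
    rows, cols = tile_size
    # Collect the in-bounds mask values covered by the tile.
    values = [
        mask[r][c]
        for r in range(r0, min(r0 + rows, len(mask)))
        for c in range(c0, min(c0 + cols, len(mask[r])))
    ]
    if values and all(v == float("-inf") for v in values):
        return TileMaskStatus.FULL_MASK
    # An empty (fully out-of-bounds) tile also lands here; the real type may differ.
    if all(v == 0.0 for v in values):
        return TileMaskStatus.NO_MASK
    return TileMaskStatus.PARTIAL_MASK
```

Scanning every value is O(tile area). This is why a kernel might prefer to report PARTIAL_MASK conservatively for data-dependent masks, but that trade-off is a design consideration, not something this page documents.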

start_column

start_column[BM: Int, BN: Int, page_size: Int](self, row: UInt32) -> UInt32

Returns:

UInt32

total_iters

total_iters[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> UInt32

Returns:

UInt32
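Here is a Python sketch of one plausible relationship between start_column and total_iters. It assumes total_iters counts the BN-wide key tiles needed to cover the columns from the start column up to num_cols. How BM and page_size enter the Mojo computation is not described on this page, so the sketch ignores them, and the function name is hypothetical.

```python
def count_key_tiles(start_col: int, num_cols: int, BN: int) -> int:
    """Hypothetical model: number of BN-wide tiles covering [start_col, num_cols)."""
    if num_cols <= start_col:
        return 0
    # Ceiling division: a trailing partial tile still needs one iteration.
    return (num_cols - start_col + BN - 1) // BN
```

For example, with BN = 64, covering 130 columns from column 0 takes three iterations: two full tiles and one partial tile.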

count_nonfull_sets

static count_nonfull_sets(BM: Int, BN: Int) -> Int

Returns:

Int

last_masked_set_end

last_masked_set_end[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> UInt32

Returns:

UInt32

masked_set_ends

masked_set_ends[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> StaticTuple[UInt32, MaterializedMask.count_nonfull_sets(BM, BN)]

Returns:

StaticTuple[UInt32, MaterializedMask.count_nonfull_sets(BM, BN)]

nonfull_sets

static nonfull_sets[BM: Int, BN: Int]() -> StaticTuple[TileMaskStatus, MaterializedMask.count_nonfull_sets(BM, BN)]

Returns:

StaticTuple[TileMaskStatus, MaterializedMask.count_nonfull_sets(BM, BN)]

mask_strategies

static mask_strategies[BM: Int, BN: Int]() -> StaticTuple[MaskStrategy, MaterializedMask.count_nonfull_sets(BM, BN)]

Returns:

StaticTuple[MaskStrategy, MaterializedMask.count_nonfull_sets(BM, BN)]
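The signatures above show that masked_set_ends, nonfull_sets, and mask_strategies all return tuples of length count_nonfull_sets(BM, BN), so they can be read in lockstep. The Python sketch below illustrates that pattern. It assumes, hypothetically, that each set covers the iterations between the previous set's end and its own end. That reading is not confirmed here, and the status and strategy strings are placeholders.

```python
from typing import List, Sequence, Tuple


def pair_sets(
    statuses: Sequence[str],    # stands in for nonfull_sets
    strategies: Sequence[str],  # stands in for mask_strategies
    set_ends: Sequence[int],    # stands in for masked_set_ends
) -> List[Tuple[int, int, str, str]]:
    """Hypothetical model: walk the parallel tuples as consecutive iteration ranges."""
    # All three tuples share the length count_nonfull_sets(BM, BN).
    assert len(statuses) == len(strategies) == len(set_ends)
    spans = []
    start = 0
    for status, strategy, end in zip(statuses, strategies, set_ends):
        spans.append((start, end, status, strategy))
        start = end  # assumed: the next set begins where this one ends
    return spans
```

Sizing all three tuples by the same compile-time count lets a kernel iterate them together without bounds checks. In Python, that guarantee is modeled by the length assertion.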