Mojo struct
MaterializedMask
struct MaterializedMask[dtype_: DType, layout_: Layout, origin_: ImmutOrigin]
Mask that's backed by a materialized tensor.
Fields
- mask_tensor (LayoutTensor[dtype_, layout_, origin_])
- start_pos (OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), ImmutAnyOrigin]])
- is_multiple_of_2 (Bool)
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
ImplicitlyDestructible,
MHAMask,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
apply_log2e_after_mask
comptime apply_log2e_after_mask = True
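The flag name suggests that the log2(e) factor is applied to the scores after the mask values are added, so that an exp2-based softmax operates on already-masked scores. Assuming that reading, the following identity shows why the reordering is exact. Here s are the raw (scaled) attention scores for one query row, m are the mask values, and t is what the exp2 path consumes. These symbols are introduced for illustration.

```latex
t_j = (s_j + m_j)\,\log_2 e
\quad\Longrightarrow\quad
2^{t_j} = e^{\,s_j + m_j},
\qquad
\operatorname{softmax}(s + m)_j
  = \frac{2^{\,t_j - \max_i t_i}}{\sum_{k} 2^{\,t_k - \max_i t_i}}
```

Under this ordering, mask entries stay in natural-log units (for example 0 to keep a score and -\infty to drop it). If log2(e) were applied first, any finite mask bias would need the same scaling. -\infty is unaffected either way.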
check_mask_during_decoding
comptime check_mask_during_decoding = True
device_type
comptime device_type = MaterializedMask[dtype_, layout_, origin_]
mask_out_of_bound
comptime mask_out_of_bound = True
mask_safe_out_of_bounds
comptime mask_safe_out_of_bounds = False
Methods
__init__
__init__(mask_tensor: LayoutTensor[dtype_, layout_, origin_], start_pos: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), ImmutAnyOrigin]] = None) -> Self
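The constructor stores the mask tensor and an optional per-batch start_pos tensor. The third field, is_multiple_of_2, is not passed in, so it is presumably derived from the mask. The Python model below is a sketch, not the Mojo API. It assumes the mask is rank 3 ([batch][query][key]) or rank 4 ([batch][head][query][key]), and it assumes is_multiple_of_2 records whether the innermost (key) dimension is even. That second point is a guess from the field name and is not documented on this page. MaterializedMaskModel is a hypothetical name.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Sequence


@dataclass
class MaterializedMaskModel:
    """Hypothetical Python stand-in for MaterializedMask; not the Mojo API."""

    # Rank-3 [batch][query][key] or rank-4 [batch][head][query][key] nested lists.
    mask_tensor: List
    # Optional per-batch offsets, mirroring the optional uint32 start_pos tensor.
    start_pos: Optional[Sequence[int]] = None
    # Derived in __post_init__, not supplied by the caller.
    is_multiple_of_2: bool = field(init=False)

    def __post_init__(self) -> None:
        # Walk down to the innermost (key) dimension.
        innermost = self.mask_tensor
        while isinstance(innermost[0], list):
            innermost = innermost[0]
        # Assumption: the flag records whether the key dimension is even.
        self.is_multiple_of_2 = len(innermost) % 2 == 0
```

An even innermost dimension is the kind of property a kernel might check before issuing paired or wider vector loads. Whether MaterializedMask uses the flag that way is not stated here.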
get_type_name
name
get_start_pos
mask
mask[dtype: DType, width: Int, //, *, element_type: DType = DType.uint32](self, coord: IndexList[4, element_type=element_type], score_vec: SIMD[dtype, width]) -> SIMD[dtype, width]
Returns:
SIMD[dtype, width]
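The mask method takes a 4-D coordinate and a SIMD vector of width scores, and returns a vector of the same type. The Python model below sketches one plausible reading. It rests on several assumptions that this page does not confirm:

- coord is ordered (batch, head, query, key).
- The vector covers width consecutive key positions starting at coord[3].
- The stored values are additive (0.0 keeps a score, -inf removes it).
- When start_pos is present, the query index is shifted by start_pos[batch] before the lookup.
- A rank-3 mask has no head dimension.

apply_materialized_mask is a hypothetical name.

```python
from typing import List, Optional, Sequence


def apply_materialized_mask(
    mask_tensor,                         # rank-3 [B][Q][K] or rank-4 [B][H][Q][K] nested lists
    coord: Sequence[int],                # assumed order: (batch, head, query, key)
    score_vec: List[float],              # `width` consecutive scores along the key axis
    start_pos: Optional[Sequence[int]] = None,
    rank: int = 4,
) -> List[float]:
    batch, head, q, k = coord
    # Assumption: start_pos moves the query index into the mask's own row range.
    offset = start_pos[batch] if start_pos is not None else 0
    row_q = q - offset
    # A rank-3 mask is shared across heads, so the head index is ignored.
    row = mask_tensor[batch][row_q] if rank == 3 else mask_tensor[batch][head][row_q]
    width = len(score_vec)
    # Assumption: additive mask, applied element-wise to the score vector.
    return [s + m for s, m in zip(score_vec, row[k:k + width])]


# Usage: one batch, one head, two queries, two keys.
mask = [[[[0.0, float("-inf")], [0.0, 0.0]]]]
print(apply_materialized_mask(mask, (0, 0, 0, 0), [1.0, 2.0]))
```

An additive mask lets the same tensor express hard masking (-inf) and soft biases (finite values) with a single add per element.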
status
status[*, element_type: DType = DType.uint32](self, tile_offset: IndexList[2, element_type=element_type], tile_size: IndexList[2, element_type=element_type]) -> TileMaskStatus
Returns:
TileMaskStatus
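status classifies a tile, given by tile_offset and tile_size, as a TileMaskStatus. The Python sketch below illustrates what the three statuses mean for a 2-D slice of an additive mask. It assumes the enum values are named NO_MASK, PARTIAL_MASK and FULL_MASK. TileStatus and classify_tile are hypothetical names, and scanning the tile is an illustration of the categories, not a claim about how MaterializedMask computes its result.

```python
import math
from enum import Enum
from typing import List, Tuple


class TileStatus(Enum):
    # Hypothetical mirror of TileMaskStatus; value names are assumed.
    NO_MASK = 0       # every entry leaves its score unchanged (0.0)
    PARTIAL_MASK = 1  # anything else: entries must be applied one by one
    FULL_MASK = 2     # every entry removes its score (-inf)


def classify_tile(mask_2d: List[List[float]],
                  tile_offset: Tuple[int, int],
                  tile_size: Tuple[int, int]) -> TileStatus:
    r0, c0 = tile_offset
    rows, cols = tile_size
    values = [mask_2d[r][c]
              for r in range(r0, r0 + rows)
              for c in range(c0, c0 + cols)]
    if all(v == -math.inf for v in values):
        return TileStatus.FULL_MASK
    if all(v == 0.0 for v in values):
        return TileStatus.NO_MASK
    return TileStatus.PARTIAL_MASK


NEG = -math.inf
mask_2d = [[0.0, 0.0, NEG, NEG],
           [0.0, 0.0, NEG, NEG]]
print(classify_tile(mask_2d, (0, 1), (2, 2)))  # TileStatus.PARTIAL_MASK
```

Scanning costs a pass over the tile. Returning PARTIAL_MASK unconditionally is the cheap alternative and is always correct, because a partial tile is masked element by element. This page does not say which approach MaterializedMask takes.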
start_column
total_iters
total_iters[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> UInt32
Returns:
UInt32
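total_iters appears to return a per-row count of key-tile iterations. A materialized mask can hold arbitrary values, so one safe reading is that every BN-wide tile covering the num_cols keys must be visited. The sketch below assumes exactly that and ignores BM, page_size and row. The real method may use those parameters (for example, to align with page boundaries), which this page does not specify. total_iters_model is a hypothetical name.

```python
def total_iters_model(BM: int, BN: int, page_size: int, row: int, num_cols: int) -> int:
    # Assumption: one iteration per BN-wide key tile, with a partial last tile.
    # BM, page_size and row are accepted only to mirror the documented signature.
    return (num_cols + BN - 1) // BN  # ceiling division


print(total_iters_model(64, 128, 128, 0, 300))
```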
count_nonfull_sets
last_masked_set_end
last_masked_set_end[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> UInt32
Returns:
UInt32
masked_set_ends
masked_set_ends[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> StaticTuple[UInt32, MaterializedMask.count_nonfull_sets(BM, BN)]
Returns:
StaticTuple[UInt32, MaterializedMask.count_nonfull_sets(BM, BN)]
nonfull_sets
static nonfull_sets[BM: Int, BN: Int]() -> StaticTuple[TileMaskStatus, MaterializedMask.count_nonfull_sets(BM, BN)]
Returns:
StaticTuple[TileMaskStatus, MaterializedMask.count_nonfull_sets(BM, BN)]
mask_strategies
static mask_strategies[BM: Int, BN: Int]() -> StaticTuple[MaskStrategy, MaterializedMask.count_nonfull_sets(BM, BN)]
Returns:
StaticTuple[MaskStrategy, MaterializedMask.count_nonfull_sets(BM, BN)]