Mojo function

apply_mask

apply_mask[
    BN: Int,
    MaskType: MHAMask,
    ScoreModType: ScoreModTrait,
    //,
    *,
    use_score_mod: Bool,
    mask_strategy: MaskStrategy,
    skip_scale: Bool = False
](
    srow: LayoutTensor[DType.float32, Layout.row_major(BN), MutAnyOrigin, address_space=AddressSpace.LOCAL],
    mask: MaskType,
    score_mod: ScoreModType,
    scale_log2e: Float32,
    *,
    prompt_idx: UInt32,
    q_head_idx: UInt32,
    kv_tile_start_row: Int32,
    max_seq_len: UInt32,
    num_keys: Int32,
    score_row: Int32
)