
Mojo function

dispatch_materialized_mask_and_score_mod

dispatch_materialized_mask_and_score_mod[score_mod_type: String, callback_fn: fn[mask_t: MHAMask, score_mod_t: ScoreModTrait](mask: mask_t, score_mod: score_mod_t) raises capturing -> None, num_heads: Int = -1](mask_nd: NDBuffer[dtype, rank, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], start_pos_nd: OptionalReg[NDBuffer[DType.uint32, 1, MutableAnyOrigin]] = None)
