Mojo function
apply_min_p_mask_kernel
apply_min_p_mask_kernel[dtype: DType, block_size: Int](probs: UnsafePointer[Scalar[dtype], MutExternalOrigin], min_p_arr: UnsafePointer[Float32, ImmutExternalOrigin], d: Int)
Zero out probabilities below the per-row min_p threshold.
Each block processes one batch row. Threads cooperatively find the
row-wise max probability via a block reduction, compute the threshold
as min_p * max_prob, and then zero any element below it.
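To make the per-row rule concrete, here is a minimal CPU reference sketch in Mojo. It is not the library kernel. The GPU version launches one block per row and finds the maximum with a block reduction across block_size threads, while this sketch uses a plain serial loop per row. The name min_p_mask_reference is hypothetical. The sketch also assumes two things: that probs is laid out row-major with row b starting at offset b * d, and that "below" means strictly less than the threshold.

```mojo
# Hypothetical CPU reference of the min_p masking rule (not the library kernel).
# For each row b: threshold = min_p_arr[b] * max(row b); elements strictly
# below the threshold are set to zero in place. Assumes row-major layout.
fn min_p_mask_reference(mut probs: List[Float32], min_p_arr: List[Float32], d: Int):
    var batch_size = len(probs) // d
    for row in range(batch_size):
        var base = row * d
        # Row-wise max (the GPU kernel computes this with a block reduction).
        var max_prob = probs[base]
        for i in range(1, d):
            max_prob = max(max_prob, probs[base + i])
        # Per-row threshold.
        var threshold = min_p_arr[row] * max_prob
        # Zero every element strictly below the threshold.
        for i in range(d):
            if probs[base + i] < threshold:
                probs[base + i] = 0.0


def main():
    # Two rows of length d = 4, each with its own min_p value.
    var probs = List[Float32](0.5, 0.3, 0.15, 0.05, 0.4, 0.4, 0.1, 0.1)
    var min_p = List[Float32](0.2, 0.5)
    min_p_mask_reference(probs, min_p, 4)
    for i in range(len(probs)):
        print(probs[i])
```

In this input, row 0 has threshold 0.2 * 0.5 = 0.1, so only 0.05 is zeroed. Row 1 has threshold 0.5 * 0.4 = 0.2, so both 0.1 entries are zeroed.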
Args:
- probs (UnsafePointer): Probability buffer [batch_size * d], modified in place.
- min_p_arr (UnsafePointer): Per-row min_p values [batch_size].
- d (Int): Vocabulary size (row length).
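As a worked example of the threshold for a single row b, the rule can be written out in math form. This assumes row b occupies probs[b * d] through probs[b * d + d - 1] and that "below" means strictly less than. The row values are illustrative only. The description does not say whether the surviving values are renormalized, so the example leaves them unchanged.

```latex
\begin{aligned}
\tau_b &= \text{min\_p\_arr}[b] \cdot \max_{0 \le j < d} \text{probs}[b \cdot d + j] \\
\text{probs}[b \cdot d + j] &\leftarrow
  \begin{cases}
    0 & \text{if } \text{probs}[b \cdot d + j] < \tau_b \\
    \text{probs}[b \cdot d + j] & \text{otherwise}
  \end{cases} \\[4pt]
d = 4,\ \text{row}_b = (0.5,\ 0.3,\ 0.15,\ 0.05),\ \text{min\_p\_arr}[b] = 0.2
  &\;\Rightarrow\; \tau_b = 0.2 \cdot 0.5 = 0.1 \\
&\;\Rightarrow\; \text{row}_b \to (0.5,\ 0.3,\ 0.15,\ 0)
\end{aligned}
```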