
Mojo function

apply_min_p_mask_kernel

apply_min_p_mask_kernel[dtype: DType, block_size: Int](probs: UnsafePointer[Scalar[dtype], MutExternalOrigin], min_p_arr: UnsafePointer[Float32, ImmutExternalOrigin], d: Int)

Zero out probabilities below the per-row min_p threshold.

Each block processes one batch row. Threads cooperatively find the row-wise max probability via a block reduction, compute the threshold as min_p * max_prob, and then zero any element below it.

Args:

  • probs (UnsafePointer): Probability buffer [batch_size * d], modified in-place.
  • min_p_arr (UnsafePointer): Per-row min_p values [batch_size].
  • d (Int): Vocabulary size (row length).
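
The masking rule described above can be illustrated with a small CPU reference. This is a minimal sketch in plain Python, not the Mojo kernel itself: the function name `apply_min_p_mask_reference` is hypothetical, the loop over rows stands in for the one-block-per-row launch, and Python's `max` stands in for the block reduction. It assumes `d > 0` and that "below" means strictly less than the threshold, so an element equal to the threshold is kept.

```python
def apply_min_p_mask_reference(probs, min_p_arr, d):
    """Hypothetical CPU reference: zero entries below min_p * row_max, in place.

    probs     -- flat list of length batch_size * d (row-major)
    min_p_arr -- list of per-row min_p values, length batch_size
    d         -- row length (vocabulary size), assumed > 0
    """
    batch_size = len(min_p_arr)
    assert len(probs) == batch_size * d, "probs must hold batch_size * d values"

    for row in range(batch_size):  # one GPU block per row in the real kernel
        start = row * d
        # Row-wise maximum (the kernel obtains this via a block reduction).
        row_max = max(probs[start:start + d])
        threshold = min_p_arr[row] * row_max
        # Assumption: strict comparison, values equal to the threshold survive.
        for i in range(start, start + d):
            if probs[i] < threshold:
                probs[i] = 0.0
    return probs


# Two rows of length 4 with different min_p values.
probs = [0.5, 0.25, 0.125, 0.125,
         0.75, 0.125, 0.0625, 0.0625]
min_p_arr = [0.5, 0.1]
apply_min_p_mask_reference(probs, min_p_arr, d=4)
```

In the first row the maximum is 0.5, so the threshold is 0.5 * 0.5 = 0.25 and the two 0.125 entries are zeroed. In the second row the threshold is about 0.1 * 0.75 = 0.075, so only the 0.0625 entries are zeroed. Note that the surviving values are not renormalized here; the description above mentions only zeroing.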
