Mojo function

apply_min_p_mask_kernel

apply_min_p_mask_kernel[dtype: DType, block_size: Int](probs: UnsafePointer[Scalar[dtype], MutExternalOrigin], min_p_arr: UnsafePointer[Float32, ImmutExternalOrigin], d: Int)

Zero out probabilities below the per-row min_p threshold.

Each block processes one batch row. Threads cooperatively find the row-wise max probability via a block reduction, compute the threshold as min_p * max_prob, and then zero any element below it.
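The masking rule described above can be illustrated with a small CPU reference in plain Python. This is only a sketch of the semantics, not the GPU kernel itself: the function name `apply_min_p_mask_reference` is hypothetical, and it assumes that `probs` is a flat, row-major buffer holding one row of `d` probabilities per entry in `min_p_arr`, and that "below" means strictly less than the threshold. The sequential loop over rows stands in for the per-block work, and `max()` stands in for the block reduction.

```python
def apply_min_p_mask_reference(probs, min_p_arr, d):
    """Sequential sketch of min_p masking, modifying `probs` in place.

    Assumptions (not confirmed by the source):
      - probs is a flat list of len(min_p_arr) * d values, row-major
      - elements strictly below the threshold are zeroed; ties are kept
    """
    batch = len(min_p_arr)
    for row in range(batch):  # on the GPU, one block would handle each row
        start = row * d
        # Row-wise maximum (the kernel does this with a block reduction).
        max_prob = max(probs[start:start + d])
        # Per-row threshold, as described: min_p * max_prob.
        threshold = min_p_arr[row] * max_prob
        for i in range(start, start + d):
            if probs[i] < threshold:
                probs[i] = 0.0
    return probs


# Two rows of four probabilities each, with a different min_p per row.
example = [0.5, 0.3, 0.15, 0.05,
           0.7, 0.2, 0.06, 0.04]
masked = apply_min_p_mask_reference(example, [0.2, 0.1], 4)
print(masked)
```

In the first row the threshold is 0.2 * 0.5 = 0.1, so only 0.05 is zeroed. In the second row the threshold is about 0.07, so both 0.06 and 0.04 are zeroed. The surviving values are not renormalized here, since the description only mentions zeroing.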

Args: