Mojo function
device_sampling_from_prob
device_sampling_from_prob[vec_size: Int, block_size: Int, dtype: DType, deterministic: Bool = False](i: Int, d: Int, low: Float32, u: Float32, prob_vec: SIMD[DType.float32, vec_size], aggregate: Float32, sampled_id_sram: UnsafePointer[Int, sampled_id_sram.origin, address_space=AddressSpace.SHARED]) -> Tuple[Float32, Int]
Device-level sampling from a probability distribution using atomic operations.
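The page does not document the individual parameters, so the following is only a sequential Python simulation of what one thread block might compute for a single chunk. It rests on these assumptions: `i` is the chunk index, and `d` is the number of valid entries, such as the vocabulary size. Probabilities `<= low` are ignored. `u` is the target cumulative mass, and `aggregate` is the mass accumulated over earlier chunks. `sampled_id_sram` receives, through an atomic min, the smallest index whose running cumulative probability exceeds `u`. Each thread is assumed to own `vec_size` consecutive elements starting at `(i * block_size + t) * vec_size`. The name `sample_chunk`, the `-1` sentinel, and initializing the slot to `d` as "not found" are hypothetical. `deterministic` presumably selects a reproducible prefix-sum order, and a sequential simulation is deterministic by construction.

```python
from typing import List, Tuple


def sample_chunk(
    i: int,
    d: int,
    low: float,
    u: float,
    probs: List[float],
    aggregate: float,
    block_size: int,
    vec_size: int,
    sampled_id: List[int],
) -> Tuple[float, List[int]]:
    """Sequentially simulate one thread block processing chunk i (hypothetical).

    Returns (new_aggregate, per-thread max valid index). sampled_id is a
    one-element list standing in for the shared-memory slot.
    """
    running = aggregate
    # -1 means the thread saw no valid index (assumed sentinel).
    thread_max = [-1] * block_size
    for t in range(block_size):
        base = (i * block_size + t) * vec_size  # assumed per-thread layout
        for j in range(vec_size):
            idx = base + j
            if idx >= d or probs[idx] <= low:
                continue  # out of range or filtered: adds no probability mass
            running += probs[idx]  # sequential stand-in for the block prefix sum
            thread_max[t] = max(thread_max[t], idx)
            if running > u:
                # Stand-in for an atomic min: keep the smallest crossing index.
                sampled_id[0] = min(sampled_id[0], idx)
    return running, thread_max
```

Calling this for chunks `0, 1, 2, ...` in order, passing each returned aggregate forward and sharing one `sampled_id` slot, amounts to a sequential inverse-CDF draw over the filtered distribution.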
Returns:
Tuple: A tuple of (new_aggregate, thread_local_max_valid_idx).
The caller is responsible for reducing thread_local_max_valid_idx across the block
after all chunks are processed.
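This is a hedged Python sketch of the caller-side reduction that the note describes, simulated sequentially. It assumes the following: each thread owns `vec_size` consecutive elements at `(i * block_size + t) * vec_size` in chunk `i`. An index is "valid" when it is below `d` and its probability is greater than `low`. `-1` marks a thread that saw no valid index. Each thread keeps a running max across chunks, and the block-wide max is taken once, after the loop. The helper names `thread_local_max_valid_idx` and `block_max_valid_idx` are hypothetical.

```python
from typing import List


def thread_local_max_valid_idx(
    t: int, num_chunks: int, d: int, low: float,
    probs: List[float], block_size: int, vec_size: int,
) -> int:
    """Running max, across all chunks, of the in-range indices owned by
    thread t whose probability passes the filter (hypothetical helper)."""
    best = -1  # assumed sentinel: no valid index seen yet
    for i in range(num_chunks):
        base = (i * block_size + t) * vec_size  # assumed per-thread layout
        for j in range(vec_size):
            idx = base + j
            if idx < d and probs[idx] > low:
                best = max(best, idx)
    return best


def block_max_valid_idx(
    d: int, low: float, probs: List[float], block_size: int, vec_size: int
) -> int:
    """Block-wide max of the thread-local results, done once after the chunk loop."""
    chunk = block_size * vec_size
    num_chunks = (d + chunk - 1) // chunk  # ceil(d / chunk)
    per_thread = [
        thread_local_max_valid_idx(t, num_chunks, d, low, probs, block_size, vec_size)
        for t in range(block_size)
    ]
    # On a GPU this would be a block reduction (for example through shared
    # memory); a plain max over the simulated threads stands in for it here.
    return max(per_thread)
```

Deferring the reduction until the end presumably avoids a block-wide synchronization on every chunk. Each thread only needs a scalar max while the loop runs.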