Mojo function

device_sampling_from_prob

device_sampling_from_prob[vec_size: Int, block_size: Int, dtype: DType, deterministic: Bool = False](i: Int, d: Int, low: Float32, u: Float32, prob_vec: SIMD[DType.float32, vec_size], aggregate: Float32, sampled_id_sram: UnsafePointer[Int, sampled_id_sram.origin, address_space=AddressSpace.SHARED]) -> Tuple[Float32, Int]

Device-level sampling from a probability distribution using atomic operations.

Returns:

Tuple: A tuple of (new_aggregate, thread_local_max_valid_idx). The caller is responsible for reducing thread_local_max_valid_idx across the block after all chunks have been processed.
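The following sequential Python model may help make the two return values concrete. It is a sketch, not the Mojo implementation. It assumes the chunked inverse-CDF scheme that the signature suggests. Each call covers one chunk of `block_size * vec_size` consecutive indices starting at chunk `i`. Only probabilities strictly greater than `low` count toward the running mass. `u` is a uniform draw already scaled to that mass. The shared-memory slot `sampled_id_sram` starts at `d` ("not found") and is lowered with an atomic min. The helper names `sample_chunk` and `sample`, and the one-element list that stands in for shared memory, are hypothetical.

```python
from typing import List, Tuple


def sample_chunk(i: int, d: int, low: float, u: float, chunk: List[float],
                 aggregate: float, sampled_id: List[int]) -> Tuple[float, int]:
    """Hypothetical sequential model of one call (one chunk).

    `chunk` holds the probabilities for global indices i*len(chunk) onward.
    `sampled_id` is a one-element list standing in for the shared-memory slot;
    taking min() on it mimics the atomic-min performed by the threads.
    Returns (new_aggregate, max_valid_idx) for this chunk.
    """
    base = i * len(chunk)
    # Assumption: an entry is valid if it exceeds `low` and lies inside [0, d).
    valid = [p > low and base + j < d for j, p in enumerate(chunk)]
    masked = [p if v else 0.0 for p, v in zip(chunk, valid)]
    chunk_sum = sum(masked)

    # Only scan the chunk if the running mass crosses u inside it.
    if aggregate + chunk_sum > u:
        cdf = aggregate
        for j, p in enumerate(masked):
            cdf += p  # inclusive prefix sum offset by the previous aggregate
            if valid[j] and cdf > u:
                sampled_id[0] = min(sampled_id[0], base + j)  # "atomic min"
                break

    # Largest valid index seen in this chunk (-1 if none).
    max_valid_idx = max((base + j for j, v in enumerate(valid) if v), default=-1)
    return aggregate + chunk_sum, max_valid_idx


def sample(probs: List[float], low: float, u: float, chunk_size: int) -> int:
    """Hypothetical caller: loops over chunks and reduces max_valid_idx."""
    d = len(probs)
    sampled_id = [d]  # sentinel: nothing sampled yet
    aggregate = 0.0
    max_valid = -1
    for i in range((d + chunk_size - 1) // chunk_size):
        chunk = probs[i * chunk_size:(i + 1) * chunk_size]
        chunk += [0.0] * (chunk_size - len(chunk))  # pad the last chunk
        aggregate, local_max = sample_chunk(i, d, low, u, chunk, aggregate, sampled_id)
        max_valid = max(max_valid, local_max)  # caller-side reduction
        if aggregate > u:
            break  # the sample has been found in this chunk
    # Fall back to the last valid index if no index was selected
    # (for example, because of floating-point rounding in the running sum).
    return sampled_id[0] if sampled_id[0] < d else max_valid


print(sample([0.1, 0.2, 0.3, 0.4], low=0.0, u=0.25, chunk_size=2))
```

In this model, the loop in `sample` does the work the Returns note assigns to the caller. It threads `new_aggregate` into the next chunk and reduces `thread_local_max_valid_idx` with `max`. That reduced index then serves as a fallback when the running sum never crosses `u`.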