Mojo function
gumbel_sampling_gpu
gumbel_sampling_gpu[dtype: DType, out_idx_type: DType, //, TemperatureLayoutType: TensorLayout = Layout[RuntimeInt[DType.int64], ComptimeInt[1]], SeedLayoutType: TensorLayout = Layout[RuntimeInt[DType.int64], ComptimeInt[1]]](ctx: DeviceContext, input: TileTensor[dtype, input.LayoutType, input.origin, address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_shape_types=input.element_shape_types], out_idxs: TileTensor[out_idx_type, out_idxs.LayoutType, out_idxs.origin, address_space=out_idxs.address_space, linear_idx_type=out_idxs.linear_idx_type, element_shape_types=out_idxs.element_shape_types], temperature: Optional[TileTensor[DType.float32, TemperatureLayoutType, ImmutAnyOrigin]] = None, seed: Optional[TileTensor[DType.uint64, SeedLayoutType, ImmutAnyOrigin]] = None)
Gumbel sampling using the Gumbel-max trick for categorical distributions.
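Adds independent Gumbel(0,1) noise to each temperature-scaled input logit, then selects the argmax of each row. By the Gumbel-max trick, the selected index is distributed exactly as a sample from softmax(logits/temperature). The softmax itself (the exponentials and the normalizing sum over the vocabulary) never has to be computed.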
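The equivalence claimed above can be checked empirically with a small CPU sketch. It is not the GPU kernel. It assumes the standard formulation argmax_i(logits[i] / temperature + g_i) with g_i ~ Gumbel(0, 1), draws the noise through the inverse CDF -log(-log(u)), and compares sample frequencies against an explicit softmax. The helper names `gumbel_noise`, `gumbel_max_sample`, and `softmax` are hypothetical and exist only for this illustration.

```python
import math
import random

def gumbel_noise(rng: random.Random) -> float:
    """Draw one Gumbel(0, 1) sample via the inverse CDF: -log(-log(u))."""
    u = rng.random()
    while u == 0.0:  # random() is in [0, 1); log(0) is undefined, so resample
        u = rng.random()
    return -math.log(-math.log(u))

def gumbel_max_sample(logits, temperature=1.0, rng=None):
    """Return argmax_i(logits[i] / temperature + g_i) with g_i ~ Gumbel(0, 1)."""
    if rng is None:
        rng = random.Random()
    best_idx, best_val = 0, -math.inf
    for i, logit in enumerate(logits):
        val = logit / temperature + gumbel_noise(rng)
        if val > best_val:
            best_idx, best_val = i, val
    return best_idx

def softmax(logits, temperature=1.0):
    """Explicit softmax(logits / temperature), used here only for comparison."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Draw many samples and compare empirical frequencies with the softmax.
rng = random.Random(0)
logits = [2.0, 1.0, 0.1]
n = 20000
counts = [0] * len(logits)
for _ in range(n):
    counts[gumbel_max_sample(logits, rng=rng)] += 1
empirical = [c / n for c in counts]
probs = softmax(logits)
print("empirical:", [round(p, 3) for p in empirical])
print("softmax:  ", [round(p, 3) for p in probs])
```

The sampler only needs one noise draw, one add, and one comparison per logit. This is why the trick suits a GPU argmax reduction: no second pass over the vocabulary is needed to normalize.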
Args:
- ctx (DeviceContext): Device context for GPU operations.
- input (TileTensor): Input logits tensor [batch, vocab_size].
- out_idxs (TileTensor): Output tensor for the sampled indices [batch, 1].
- temperature (Optional[TileTensor] of float32): Optional per-token temperature scaling, one value per batch row [batch].
- seed (Optional[TileTensor] of uint64): Optional per-token random seeds, one per batch row [batch], for reproducibility.
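The argument shapes can be illustrated with a batched CPU reference sketch. Nested Python lists stand in for the TileTensors: `input` is [batch][vocab_size] and the result is [batch][1]. The function name `gumbel_sampling_reference` is hypothetical. The defaults are assumptions, because the page does not state them: a missing temperature is treated as 1.0 and a missing seed as unseeded randomness. Per-row seeding with Python's `random.Random` only demonstrates reproducibility. It will not reproduce the GPU kernel's actual random stream.

```python
import math
import random
from typing import List, Optional

def gumbel_noise(rng: random.Random) -> float:
    """Draw one Gumbel(0, 1) sample via the inverse CDF: -log(-log(u))."""
    u = rng.random()
    while u == 0.0:  # random() is in [0, 1); avoid log(0)
        u = rng.random()
    return -math.log(-math.log(u))

def gumbel_sampling_reference(
    logits: List[List[float]],
    temperature: Optional[List[float]] = None,
    seed: Optional[List[int]] = None,
) -> List[List[int]]:
    """Hypothetical CPU reference: logits is [batch][vocab], result is [batch][1]."""
    out_idxs = []
    for row, row_logits in enumerate(logits):
        # Assumption: no temperature tensor means temperature 1.0 for every row.
        t = 1.0 if temperature is None else temperature[row]
        if t <= 0.0:
            raise ValueError("temperature must be positive in this sketch")
        # Assumption: no seed tensor means unseeded (system-entropy) randomness.
        rng = random.Random(None if seed is None else seed[row])
        scores = [x / t + gumbel_noise(rng) for x in row_logits]
        # Argmax over the vocabulary; wrap it so each row has shape [1].
        out_idxs.append([max(range(len(scores)), key=scores.__getitem__)])
    return out_idxs

logits = [[0.0, 5.0, 0.0, 0.0], [3.0, 0.0, 0.0, 0.0]]
a = gumbel_sampling_reference(logits, temperature=[1.0, 0.7], seed=[42, 7])
b = gumbel_sampling_reference(logits, temperature=[1.0, 0.7], seed=[42, 7])
print(a == b)  # True: identical per-row seeds reproduce the same samples
```

Each row draws its noise from its own seed, so one row's result does not depend on how many other rows share the batch. Lowering a row's temperature sharpens that row's distribution toward its largest logit.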