Mojo function

gumbel_sampling_gpu

gumbel_sampling_gpu[dtype: DType, out_idx_type: DType, //, TemperatureLayoutType: TensorLayout = Layout[*?, *?], SeedLayoutType: TensorLayout = Layout[*?, *?]](ctx: DeviceContext, input: TileTensor[dtype, address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_size=input.element_size], out_idxs: TileTensor[out_idx_type, address_space=out_idxs.address_space, linear_idx_type=out_idxs.linear_idx_type, element_size=out_idxs.element_size], temperature: Optional[TileTensor[DType.float32, TemperatureLayoutType, ImmutAnyOrigin]] = None, seed: Optional[TileTensor[DType.uint64, SeedLayoutType, ImmutAnyOrigin]] = None)

Gumbel sampling using the Gumbel-max trick for categorical distributions.

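Adds independent Gumbel(0,1) noise to the input logits (scaled by 1/temperature), then selects the argmax. The selected index is distributed exactly as a sample from softmax(logits/temperature), but the softmax itself (exponentials, normalizing sum, and cumulative-distribution search) never has to be computed.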
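The Gumbel-max trick can be illustrated on the CPU in plain Python. This is a minimal sketch, not the kernel's implementation. It assumes a single row of logits, a scalar temperature greater than 0, and Python's `random` module in place of whatever seeded RNG the GPU kernel uses. `gumbel_max_sample` and `softmax` are hypothetical helper names introduced only for this example.

```python
import math
import random


def gumbel_max_sample(logits, temperature=1.0, rng=random):
    """Return argmax_i(logits[i] / temperature + g_i), with g_i ~ Gumbel(0, 1).

    Hypothetical helper; assumes temperature > 0 and a single row of logits.
    """
    best_idx, best_val = 0, -math.inf
    for i, logit in enumerate(logits):
        # rng.random() is uniform on [0, 1); clamp away from 0 so log(u) is defined.
        u = max(rng.random(), 1e-300)
        # Inverse-CDF transform: -log(-log(u)) is a Gumbel(0, 1) sample.
        g = -math.log(-math.log(u))
        val = logit / temperature + g
        if val > best_val:
            best_idx, best_val = i, val
    return best_idx


def softmax(logits, temperature=1.0):
    """Reference distribution softmax(logits / temperature), computed stably."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    rng = random.Random(0)  # seeded for reproducibility
    logits = [2.0, 1.0, 0.1]
    temperature = 0.7
    n = 100_000
    counts = [0] * len(logits)
    for _ in range(n):
        counts[gumbel_max_sample(logits, temperature, rng)] += 1
    print("empirical:", [c / n for c in counts])
    print("softmax:  ", softmax(logits, temperature))
```

The main block compares empirical sample frequencies with softmax(logits/temperature); with enough draws the two agree up to sampling noise. Each element still needs one uniform draw and two logarithms, but sampling reduces to a single argmax per row, with no normalizing sum or cumulative-distribution search.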

Args: