Mojo function

gumbel_sampling_gpu

gumbel_sampling_gpu[dtype: DType, out_idx_type: DType, input_layout: Layout, //](ctx: DeviceContext, input: LayoutTensor[dtype, input_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], out_idxs: LayoutTensor[out_idx_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], temperature: OptionalReg[LayoutTensor[DType.float32, Layout.row_major(-1), MutableAnyOrigin]] = None, seed: OptionalReg[LayoutTensor[DType.uint64, Layout.row_major(-1), MutableAnyOrigin]] = None)

Gumbel sampling using the Gumbel-max trick for categorical distributions.

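Adds independent Gumbel(0,1) noise to the input logits (scaled by the temperature, when one is given), then selects the argmax of each row. By the Gumbel-max trick, the selected index is distributed exactly as a sample from softmax(logits/temperature), so the softmax normalization never has to be computed.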
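
The trick described above can be illustrated on the CPU with a few lines of plain Python. This is only a sketch of the technique, not the GPU kernel: the function name `gumbel_max_sample` and its `rng` parameter are hypothetical, and Python's built-in generator stands in for whatever random source the kernel uses.

```python
import math
import random


def gumbel_max_sample(logits, temperature=1.0, rng=random):
    """Draw one index from softmax(logits / temperature) via the Gumbel-max trick."""
    best_idx, best_score = -1, -math.inf
    for i, logit in enumerate(logits):
        # Gumbel(0,1) noise is -log(-log(U)) with U uniform on (0, 1).
        u = rng.random()
        while u == 0.0:  # random() can return 0.0; log(0) is undefined
            u = rng.random()
        gumbel = -math.log(-math.log(u))
        # Perturb the temperature-scaled logit and keep the running argmax.
        score = logit / temperature + gumbel
        if score > best_score:
            best_idx, best_score = i, score
    return best_idx


# Empirical check: sampled frequencies should approach the softmax probabilities.
probs = [0.7, 0.2, 0.1]
logits = [math.log(p) for p in probs]  # softmax of these logits recovers probs
rng = random.Random(0)
n = 20000
counts = [0] * len(probs)
for _ in range(n):
    counts[gumbel_max_sample(logits, rng=rng)] += 1
print([round(c / n, 3) for c in counts])
```

The printed frequencies should land close to 0.7, 0.2, and 0.1, even though no exponentials are summed and no normalization is performed.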

Args:

  • ctx (DeviceContext): Device context for GPU operations.
  • input (LayoutTensor): Input logits tensor of shape [batch, vocab_size].
  • out_idxs (LayoutTensor): Output tensor that receives the sampled indices, shape [batch, 1].
  • temperature (OptionalReg): Optional per-token temperature (float32), shape [batch], one value per batch row.
  • seed (OptionalReg): Optional per-token random seeds (uint64), shape [batch], one value per batch row, for reproducibility.
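
To make the argument shapes concrete, here is a small CPU reference in plain Python that mirrors them: logits of shape [batch, vocab_size], one temperature and one seed per row, and an output of shape [batch, 1]. It is a sketch under assumptions, not the kernel itself. The name `gumbel_sample_batch` is hypothetical, and the behavior when `temperature` or `seed` is omitted (temperature 1.0, unseeded generator) is assumed for illustration, since this page does not specify it.

```python
import math
import random


def gumbel_sample_batch(logits, temperature=None, seed=None):
    """CPU reference for shapes only: logits [batch, vocab_size] -> indices [batch, 1]."""
    out_idxs = []
    for row, row_logits in enumerate(logits):
        # Assumption: a missing temperature behaves like 1.0 for every row.
        t = 1.0 if temperature is None else temperature[row]
        # Assumption: a missing seed means an unseeded, non-reproducible generator.
        rng = random.Random() if seed is None else random.Random(seed[row])
        best_idx, best_score = 0, -math.inf
        for i, logit in enumerate(row_logits):
            u = max(rng.random(), 1e-300)  # keep u > 0 so log(u) is defined
            score = logit / t - math.log(-math.log(u))  # scaled logit + Gumbel noise
            if score > best_score:
                best_idx, best_score = i, score
        out_idxs.append([best_idx])  # one sampled index per batch row
    return out_idxs


logits = [[2.0, 0.5, -1.0, 0.0], [0.0, 0.0, 3.0, 1.0]]  # [batch=2, vocab_size=4]
temperature = [0.7, 1.3]  # [batch]
seed = [1234, 5678]  # [batch]

first = gumbel_sample_batch(logits, temperature, seed)
second = gumbel_sample_batch(logits, temperature, seed)
print(first, first == second)
```

In this sketch, reusing the same per-row seeds reproduces the same indices, and lowering a row's temperature concentrates its samples on the largest logit.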
