Mojo function
softmax_with_temperature
softmax_with_temperature[dtype: DType, temp_dtype: DType = DType.float32, TempLayoutType: TensorLayout = Layout[RuntimeInt[DType.int64], ComptimeInt[1]]](ctx: DeviceContext, input: TileTensor[dtype, input.LayoutType, input.origin, address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_size=input.element_size], output: TileTensor[dtype, output.LayoutType, output.origin, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], temperature: Scalar[temp_dtype] = 1, temperature_arr: Optional[TileTensor[temp_dtype, TempLayoutType, ImmutAnyOrigin]] = None)
GPU softmax with per-row temperature scaling.
Computes softmax(logits / T), where T is either a scalar or a per-row array.
When temperature_arr is provided, each row uses its own temperature value.
Falls back to the scalar temperature for rows without an array entry.
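The per-row semantics above can be checked against a small CPU reference. The sketch below is plain Python, not the Mojo kernel. The name `softmax_with_temperature_ref` is hypothetical. The fallback rule reflects one reading of the description: row `i` uses `temperature_arr[i]` when that entry exists, and the scalar `temperature` otherwise. Subtracting the row maximum is a standard stability trick added here as an assumption; the page does not say whether the GPU kernel does this.

```python
import math
from typing import Optional, Sequence


def softmax_with_temperature_ref(
    logits: Sequence[Sequence[float]],
    temperature: float = 1.0,
    temperature_arr: Optional[Sequence[float]] = None,
) -> list[list[float]]:
    """Hypothetical CPU reference: softmax(logits / T), one row at a time."""
    out = []
    for i, row in enumerate(logits):
        # Use the per-row temperature if the array has an entry for this row;
        # otherwise fall back to the scalar temperature (assumed reading).
        if temperature_arr is not None and i < len(temperature_arr):
            t = temperature_arr[i]
        else:
            t = temperature
        scaled = [x / t for x in row]
        # Subtracting the row max before exp() avoids overflow and leaves
        # the resulting probabilities unchanged.
        m = max(scaled)
        exps = [math.exp(x - m) for x in scaled]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


# Row 0 uses T = 0.5 from the array (a sharper distribution);
# row 1 has no array entry and falls back to the scalar T = 1.0.
probs = softmax_with_temperature_ref(
    [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]],
    temperature=1.0,
    temperature_arr=[0.5],
)
```

A temperature below 1 sharpens each row toward its largest logit, and a temperature above 1 flattens it. Every output row still sums to 1.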
Args:
- ctx (DeviceContext): Device context for kernel execution.
- input (TileTensor): Input logits tensor [batch_size, vocab_size].
- output (TileTensor): Output probability tensor (same shape as input).
- temperature (Scalar): Scalar temperature fallback (default 1.0).
- temperature_arr (Optional): Optional per-row temperature values [batch_size].
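The shape contract in the argument list can be summarized as a host-side check. The sketch below is plain Python operating on shape tuples and is not part of the Mojo API. The helper `check_softmax_args` is hypothetical, and it assumes that temperatures must be positive. That requirement is a common convention for `logits / T`, but this page does not state it.

```python
from typing import Optional, Sequence


def check_softmax_args(
    input_shape: Sequence[int],
    output_shape: Sequence[int],
    temperature: float = 1.0,
    temperature_arr_shape: Optional[Sequence[int]] = None,
) -> None:
    """Hypothetical helper: raise ValueError if the documented shapes are violated."""
    # input is documented as [batch_size, vocab_size].
    if len(input_shape) != 2:
        raise ValueError("input must be [batch_size, vocab_size]")
    batch_size = input_shape[0]
    # output must match input exactly.
    if tuple(output_shape) != tuple(input_shape):
        raise ValueError("output must have the same shape as input")
    # temperature_arr, when given, is documented as [batch_size].
    if temperature_arr_shape is not None and tuple(temperature_arr_shape) != (batch_size,):
        raise ValueError("temperature_arr must be [batch_size]")
    # Assumption: a non-positive temperature is rejected (not stated in the docs).
    if temperature <= 0:
        raise ValueError("temperature must be positive")


# A batch of 4 rows over a 32000-token vocabulary with one temperature per row.
check_softmax_args((4, 32000), (4, 32000), temperature=0.7, temperature_arr_shape=(4,))
```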