Mojo function
softmax_with_temperature
softmax_with_temperature[dtype: DType, temp_dtype: DType = DType.float32, TempLayoutType: TensorLayout = Layout[*?, *?]](ctx: DeviceContext, input: TileTensor[dtype, input.LayoutType, input.origin, address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_size=input.element_size], output: TileTensor[dtype, output.LayoutType, output.origin, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], temperature: Scalar[temp_dtype] = 1, temperature_arr: Optional[TileTensor[temp_dtype, TempLayoutType, ImmutAnyOrigin]] = None)
GPU softmax with per-row temperature scaling.
Computes softmax(logits / T), where T is either a scalar or a per-row array.
When temperature_arr is provided, each row uses its own temperature value.
Falls back to the scalar temperature for rows without an array entry.
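Written out per element, the computation looks like the formula below. The notation is introduced here for illustration: $x_{b,i}$ is the input logit at row $b$ and column $i$, $V$ is `vocab_size`, and $T_b$ is the temperature used for row $b$.

```latex
p_{b,i} = \frac{\exp\left(x_{b,i} / T_b\right)}{\sum_{j=1}^{V} \exp\left(x_{b,j} / T_b\right)},
\qquad
T_b =
\begin{cases}
\texttt{temperature\_arr}[b] & \text{if the array provides an entry for row } b \\
\texttt{temperature} & \text{otherwise}
\end{cases}
```

With $T_b = 1$ this reduces to the ordinary softmax. Smaller temperatures sharpen the distribution toward the largest logit, and larger temperatures flatten it.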
Parameters:
- dtype (DType): The data type of the input and output tensors.
- temp_dtype (DType): The data type of the temperature values (default: float32).
- TempLayoutType (TensorLayout): The layout type of the optional temperature array.
Args:
- ctx (DeviceContext): Device context for kernel execution.
- input (TileTensor): Input logits tensor [batch_size, vocab_size].
- output (TileTensor): Output probability tensor (same shape as input).
- temperature (Scalar): Scalar temperature fallback (default 1.0).
- temperature_arr (Optional): Optional per-row temperature values [batch_size].
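The plain-Python sketch below models how `temperature` and `temperature_arr` combine, which can be handy as a CPU reference when checking the kernel's output. It is not the Mojo API and does not use `DeviceContext` or `TileTensor`. The function name `softmax_with_temperature_ref` is hypothetical, and treating "rows without an array entry" as rows whose index lies past the end of `temperature_arr` is an assumption.

```python
import math
from typing import List, Optional, Sequence


def softmax_with_temperature_ref(
    logits: Sequence[Sequence[float]],
    temperature: float = 1.0,
    temperature_arr: Optional[Sequence[float]] = None,
) -> List[List[float]]:
    """Hypothetical CPU reference for softmax(logits / T), row by row."""
    out: List[List[float]] = []
    for row_idx, row in enumerate(logits):
        # Assumption: a row "has an array entry" when its index falls inside
        # temperature_arr; every other row uses the scalar fallback.
        if temperature_arr is not None and row_idx < len(temperature_arr):
            t = temperature_arr[row_idx]
        else:
            t = temperature
        scaled = [x / t for x in row]
        # Subtracting the row maximum keeps exp() from overflowing and
        # leaves the softmax result mathematically unchanged.
        row_max = max(scaled)
        exps = [math.exp(x - row_max) for x in scaled]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


# Usage: three identical rows, but only two per-row temperatures.
probs = softmax_with_temperature_ref(
    [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]],
    temperature=1.0,
    temperature_arr=[0.5, 2.0],  # row 2 has no entry, so it uses temperature=1.0
)
```

Because the rows are identical, the differences in `probs` come only from the temperatures: row 0 (T = 0.5) puts the most mass on the largest logit, row 1 (T = 2.0) is the flattest, and row 2 falls back to the scalar T = 1.0.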