Mojo function
softmax_with_temperature
softmax_with_temperature[dtype: DType, temp_dtype: DType = DType.float32, TempLayoutType: TensorLayout = Layout[*?, *?]](ctx: DeviceContext, input: TileTensor[dtype, address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_size=input.element_size], output: TileTensor[dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], temperature: Scalar[temp_dtype] = 1, temperature_arr: Optional[TileTensor[temp_dtype, TempLayoutType, ImmutAnyOrigin]] = None)
GPU softmax with per-row temperature scaling.
Computes softmax(logits / T) over each row of input, where T is either a single scalar or a per-row value.
When temperature_arr is provided, each row uses its own temperature value.
Rows without an entry in temperature_arr, and all rows when temperature_arr is not provided, fall back to the scalar temperature.
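Written out per row, the computation is as follows. The notation is ours: x_{b,i} is the logit in row b and column i, V is vocab_size, and T_b is the temperature used for row b. With T_b = 1 this is plain softmax. T_b > 1 flattens the distribution and T_b < 1 sharpens it.

```latex
p_{b,i} = \frac{\exp\left(x_{b,i} / T_b\right)}{\sum_{j=1}^{V} \exp\left(x_{b,j} / T_b\right)},
\qquad
T_b =
\begin{cases}
  \texttt{temperature\_arr}[b] & \text{if an entry exists for row } b,\\
  \texttt{temperature} & \text{otherwise.}
\end{cases}
```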
Parameters:
- dtype (DType): The data type of the input and output tensors.
- temp_dtype (DType): The data type for temperature values (default float32).
- TempLayoutType (TensorLayout): The layout type for the optional temperature array.
Args:
- ctx (DeviceContext): Device context for kernel execution.
- input (TileTensor[dtype, address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_size=input.element_size]): Input logits tensor of shape [batch_size, vocab_size].
- output (TileTensor[dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size]): Output probability tensor, with the same shape as input.
- temperature (Scalar[temp_dtype]): Scalar temperature fallback (default 1.0).
- temperature_arr (Optional[TileTensor[temp_dtype, TempLayoutType, ImmutAnyOrigin]]): Optional per-row temperature values, of shape [batch_size].
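The sketch below is a small pure-Python CPU reference for the same math. It could be used to sanity-check GPU output on tiny inputs. It is not the Mojo API. The function name softmax_with_temperature_ref and its list-of-lists inputs are hypothetical. It also assumes that "rows without an array entry" means row indices at or beyond the length of temperature_arr. Subtracting the row maximum before exponentiating is a standard stabilization that leaves the result unchanged. We do not claim the GPU kernel is implemented this way.

```python
import math
from typing import List, Optional, Sequence


def softmax_with_temperature_ref(
    logits: Sequence[Sequence[float]],            # shape [batch_size, vocab_size]
    temperature: float = 1.0,                     # scalar fallback
    temperature_arr: Optional[Sequence[float]] = None,  # shape [batch_size]
) -> List[List[float]]:
    """Hypothetical CPU reference: softmax(logits / T) with one T per row."""
    probs: List[List[float]] = []
    for b, row in enumerate(logits):
        # Use the per-row temperature when the array has an entry for row b;
        # otherwise fall back to the scalar temperature (assumed semantics).
        if temperature_arr is not None and b < len(temperature_arr):
            t = temperature_arr[b]
        else:
            t = temperature
        scaled = [x / t for x in row]
        # Subtract the row max so exp() cannot overflow; the ratio is unchanged.
        m = max(scaled)
        exps = [math.exp(s - m) for s in scaled]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs


if __name__ == "__main__":
    logits = [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
    # Row 0 uses T = 1.0, row 1 uses T = 0.5 (sharper distribution).
    for row in softmax_with_temperature_ref(logits, temperature_arr=[1.0, 0.5]):
        print([round(p, 4) for p in row])
```

Because both rows share the same logits, the output shows the effect of temperature directly. The T = 0.5 row puts noticeably more mass on the largest logit than the T = 1.0 row.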