
Mojo function

softmax_with_temperature

softmax_with_temperature[dtype: DType, temp_dtype: DType = DType.float32, TempLayoutType: TensorLayout = Layout[*?, *?]](ctx: DeviceContext, input: TileTensor[dtype, address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_size=input.element_size], output: TileTensor[dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], temperature: Scalar[temp_dtype] = 1, temperature_arr: Optional[TileTensor[temp_dtype, TempLayoutType, ImmutAnyOrigin]] = None)

GPU softmax with per-row temperature scaling.

Computes softmax(logits / T), where T is either a scalar or a per-row array. When temperature_arr is provided, each row uses its own temperature value; rows without an entry in the array fall back to the scalar temperature.
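The per-row semantics described above can be sketched as a plain CPU reference in Python. This is not the Mojo GPU kernel or its API: the function name softmax_with_temperature_ref is hypothetical, and the fallback rule (row i uses temperature_arr[i] when that entry exists, otherwise the scalar temperature) is an assumption based on the description. Subtracting the row maximum before exponentiating is a standard overflow guard; whether the Mojo kernel does the same is not stated here.

```python
import math
from typing import List, Optional, Sequence


def softmax_with_temperature_ref(
    logits: Sequence[Sequence[float]],
    temperature: float = 1.0,
    temperature_arr: Optional[Sequence[float]] = None,
) -> List[List[float]]:
    """Hypothetical CPU reference: row-wise softmax(logits / T)."""
    out: List[List[float]] = []
    for i, row in enumerate(logits):
        # Assumption: use the per-row temperature when an entry exists,
        # otherwise fall back to the scalar temperature.
        if temperature_arr is not None and i < len(temperature_arr):
            t = temperature_arr[i]
        else:
            t = temperature
        scaled = [x / t for x in row]
        # Subtract the row maximum so exp() cannot overflow; this leaves
        # the normalized result mathematically unchanged.
        m = max(scaled)
        exps = [math.exp(x - m) for x in scaled]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


logits = [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
# Row 0 uses T = 1.0; row 1 uses T = 0.5, which sharpens its distribution.
probs = softmax_with_temperature_ref(logits, temperature_arr=[1.0, 0.5])
```

Dividing by T before the softmax means T < 1 concentrates probability on the largest logits and T > 1 flattens the distribution, so identical input rows can yield different outputs when their temperatures differ.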

Parameters:

  • dtype (DType): The data type of the input and output tensors.
  • temp_dtype (DType): The data type for temperature values (default: DType.float32).
  • TempLayoutType (TensorLayout): The layout type for the optional temperature array.

Args: