Mojo function

softmax_with_temperature

softmax_with_temperature[dtype: DType, temp_dtype: DType = DType.float32, TempLayoutType: TensorLayout = Layout[RuntimeInt[DType.int64], ComptimeInt[1]]](ctx: DeviceContext, input: TileTensor[dtype, input.LayoutType, input.origin, address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_size=input.element_size], output: TileTensor[dtype, output.LayoutType, output.origin, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], temperature: Scalar[temp_dtype] = 1, temperature_arr: Optional[TileTensor[temp_dtype, TempLayoutType, ImmutAnyOrigin]] = None)

GPU softmax with per-row temperature scaling.

Computes softmax(logits / T) row by row, where T is either a single scalar or a per-row value. When temperature_arr is provided, each row uses its own temperature value; the scalar temperature is used as the fallback for rows without an array entry.
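Written out per element, the computation described above is the following, where x is the input, p the output, V is vocab_size, and T_i is the temperature selected for row i. The notation is ours, not the library's.

```latex
p_{i,j} = \frac{\exp(x_{i,j} / T_i)}{\sum_{k=1}^{V} \exp(x_{i,k} / T_i)},
\qquad
T_i =
\begin{cases}
  \texttt{temperature\_arr}[i] & \text{if the array provides an entry for row } i \\
  \texttt{temperature}         & \text{otherwise}
\end{cases}
```

With T_i = 1 this reduces to the ordinary softmax. Smaller temperatures sharpen the distribution toward the largest logit, and larger temperatures flatten it toward uniform.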

Args:

  • ctx (DeviceContext): Device context for kernel execution.
  • input (TileTensor): Input logits tensor [batch_size, vocab_size].
  • output (TileTensor): Output probability tensor (same shape as input).
  • temperature (Scalar): Scalar temperature, used for any row that has no per-row value (default 1.0).
  • temperature_arr (Optional): Optional per-row temperature values [batch_size].
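To make the argument semantics concrete, here is a minimal CPU reference sketch in plain Python. It is not the Mojo kernel and does not use its API. The name `softmax_with_temperature_ref` is hypothetical. It reads "rows without an array entry" as rows whose index is past the end of `temperature_arr`, which is an assumption. Subtracting the row maximum is a standard numerical-stability trick and is not confirmed to match what the GPU kernel does.

```python
import math
from typing import List, Optional, Sequence


def softmax_with_temperature_ref(
    logits: Sequence[Sequence[float]],
    temperature: float = 1.0,
    temperature_arr: Optional[Sequence[float]] = None,
) -> List[List[float]]:
    """Hypothetical CPU reference for softmax(logits / T), one T per row."""
    out: List[List[float]] = []
    for i, row in enumerate(logits):
        # Assumption: use the per-row temperature when the array has an
        # entry for this row; otherwise fall back to the scalar temperature.
        if temperature_arr is not None and i < len(temperature_arr):
            t = temperature_arr[i]
        else:
            t = temperature
        scaled = [x / t for x in row]
        # Subtract the row max before exponentiating to avoid overflow;
        # this leaves the mathematical result unchanged.
        m = max(scaled)
        exps = [math.exp(s - m) for s in scaled]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


if __name__ == "__main__":
    # Two identical rows; the second uses a lower temperature (0.5),
    # so its distribution is sharper.
    probs = softmax_with_temperature_ref(
        [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], temperature_arr=[1.0, 0.5]
    )
    for row in probs:
        print([round(p, 4) for p in row])
```

Each output row sums to 1, and the row scaled with T = 0.5 puts more mass on the largest logit than the row scaled with T = 1.0.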

Was this page helpful?