Mojo function

softmax_with_temperature

def softmax_with_temperature[dtype: DType, temp_dtype: DType = DType.float32, TempLayoutType: TensorLayout = Layout[*?, *?]](ctx: DeviceContext, input: TileTensor[dtype, Storage=input.Storage, address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_size=input.element_size], output: TileTensor[dtype, Storage=output.Storage, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], temperature: Scalar[temp_dtype] = 1, temperature_arr: Optional[TileTensor[temp_dtype, TempLayoutType, ImmutAnyOrigin]] = None)

GPU softmax with per-row temperature scaling.

Computes softmax(logits / T) row by row, where the logits are read from input and T is either a single scalar temperature or a per-row value. When temperature_arr is provided, each row uses its own temperature; rows without a corresponding entry in temperature_arr fall back to the scalar temperature.
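For reference, the row-wise semantics can be sketched on the CPU in plain Python. This is a minimal, non-authoritative sketch, not the Mojo GPU implementation: the function name softmax_with_temperature_ref is hypothetical, nested lists stand in for TileTensor rows, and it assumes that "rows without an array entry" means row indices at or beyond the length of temperature_arr. For each row it computes softmax(x / T)_j = exp(x_j / T) / sum_k exp(x_k / T), subtracting the row maximum before exponentiating for numerical stability (a common technique; this page does not say whether the kernel does the same).

```python
import math
from typing import List, Optional, Sequence


def softmax_with_temperature_ref(
    logits: Sequence[Sequence[float]],
    temperature: float = 1.0,
    temperature_arr: Optional[Sequence[float]] = None,
) -> List[List[float]]:
    """Hypothetical CPU reference: row-wise softmax(logits / T)."""
    out: List[List[float]] = []
    for row_idx, row in enumerate(logits):
        # Use the per-row temperature when an entry exists; otherwise fall back
        # to the scalar temperature (assumed meaning of "no array entry").
        if temperature_arr is not None and row_idx < len(temperature_arr):
            t = temperature_arr[row_idx]
        else:
            t = temperature
        scaled = [x / t for x in row]
        # Subtract the row max so exp() cannot overflow; the result is unchanged.
        row_max = max(scaled)
        exps = [math.exp(x - row_max) for x in scaled]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


if __name__ == "__main__":
    probs = softmax_with_temperature_ref(
        [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]],
        temperature=1.0,
        temperature_arr=[0.5],
    )
    print(probs)
```

In the usage example, the two rows are identical, but row 0 is scaled by its array entry T = 0.5 while row 1 has no entry and falls back to the scalar T = 1.0. A lower temperature sharpens the distribution, so row 0 assigns a higher probability to its largest logit.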

Parameters:

  • dtype (DType): The data type of the input and output tensors.
  • temp_dtype (DType): The data type for temperature values (default float32).
  • TempLayoutType (TensorLayout): The layout type for the optional temperature array.

Args: