
Mojo function

softmax_2_pass

def softmax_2_pass[simd_width: Int, dtype: DType](output: TileTensor[dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], input: TileTensor[dtype, address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_size=input.element_size])

Performs an unbatched softmax on an input tensor using the two-pass online algorithm.

The unbatched two-pass online softmax is described in "Online normalizer calculation for softmax" (https://arxiv.org/abs/1805.02867) and "A full-stack search technique for domain optimized deep learning accelerators" (https://dl.acm.org/doi/abs/10.1145/3503222.3507767) and is defined as:

procedure SoftmaxUnbatched(Input)
  runningMax = -∞
  runningSum = 0
  STAGE 1:
  for i = 0 to N - 1 do
    newMax = max(runningMax, Input[i])
    runningSum = runningSum*exp(runningMax-newMax) + exp(Input[i]-newMax)
    runningMax = newMax
  end for
  STAGE 2:
  for i = 0 to N - 1 do
    Output[i] = exp(Input[i] - runningMax) / runningSum
  end for
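
The procedure above can be sketched as a scalar reference in plain Python. This illustrates the algorithm only and is not the Mojo kernel: the name softmax_two_pass is hypothetical, the sketch ignores simd_width and TileTensor entirely, and it assumes the inputs are finite floats.

```python
import math


def softmax_two_pass(values):
    """Two-pass online softmax over a list of finite floats (illustrative sketch)."""
    running_max = float("-inf")
    running_sum = 0.0

    # Stage 1: one sweep tracks the maximum and the normalizer together.
    for x in values:
        new_max = max(running_max, x)
        # Rescale the sum accumulated so far to the new maximum,
        # then add the current element's contribution.
        running_sum = running_sum * math.exp(running_max - new_max) + math.exp(x - new_max)
        running_max = new_max

    # Stage 2: normalize every element with the final maximum and sum.
    return [math.exp(x - running_max) / running_sum for x in values]
```

After each iteration of stage 1, running_sum equals the sum of exp(x - running_max) over the elements seen so far, which is why no separate pass is needed to find the maximum first. Because every exponent is at most zero, an input such as [1000.0, 1000.0] yields two equal probabilities, whereas evaluating math.exp(1000.0) directly raises OverflowError in Python.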

Parameters:

  • simd_width (Int): The SIMD vector width to use for vectorization.
  • dtype (DType): The data type of the input and output tensors.

Args: