
Mojo function

softmax_3_pass

def softmax_3_pass[simd_width: Int, dtype: DType, origins: OriginSet, input_fn_1d: def[_simd_width: Int](Int) capturing -> SIMD[dtype, _simd_width], logsoftmax: Bool = False](output: TileTensor[dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size])

Performs an unbatched softmax on an input tensor using the three-pass algorithm.
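
For reference, the quantity computed for an N-element input x is shown below, with m = max_j x_j. Subtracting m leaves every ratio unchanged because the common factor e^{-m} cancels between numerator and denominator. It also keeps every exponent at or below 0, so exp() cannot overflow, and that is why the algorithm spends a whole pass finding m. The second line is the log-softmax form, which is mathematically equal to applying log() to the softmax output (the logsoftmax parameter).

```latex
\begin{aligned}
\mathrm{softmax}(x)_i &= \frac{e^{x_i}}{\sum_{j=0}^{N-1} e^{x_j}}
  = \frac{e^{x_i - m}}{\sum_{j=0}^{N-1} e^{x_j - m}},
  \qquad m = \max_{0 \le j < N} x_j \\
\log \mathrm{softmax}(x)_i &= (x_i - m) - \log \sum_{j=0}^{N-1} e^{x_j - m}
\end{aligned}
```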

The unbatched three-pass softmax is defined as:

procedure SoftmaxUnbatched(Input)
  N = length(Input)
  maxVal = -∞
  denom = 0
  STEP 1: find the maximum input value
  for i = 0 to N - 1 do
    maxVal = max(maxVal, Input[i])
  end for
  STEP 2: compute the shifted exponentials and their sum
  for i = 0 to N - 1 do
    Output[i] = exp(Input[i] - maxVal)
    denom += Output[i]
  end for
  STEP 3: normalize by the sum
  for i = 0 to N - 1 do
    Output[i] /= denom
  end for
  return Output
end procedure
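
The pseudocode translates almost line for line into the following Python sketch. It is written in Python rather than Mojo for illustration only, and `softmax_3_pass` here is a hypothetical plain-list function, not the Mojo API documented on this page. It reads the input three times: once for the maximum, once for the exponentials and their sum, and once to normalize.

```python
import math
from typing import List, Sequence


def softmax_3_pass(x: Sequence[float]) -> List[float]:
    """Unbatched softmax of x, computed in three passes over the data."""
    if not x:
        return []

    # Pass 1: find the maximum input value.
    max_val = -math.inf
    for v in x:
        max_val = max(max_val, v)

    # Pass 2: exponentiate the shifted inputs and accumulate their sum.
    out: List[float] = []
    denom = 0.0
    for v in x:
        e = math.exp(v - max_val)  # exponent is <= 0, so this cannot overflow
        out.append(e)
        denom += e

    # Pass 3: normalize so the outputs sum to 1.
    for i in range(len(out)):
        out[i] /= denom
    return out


print(softmax_3_pass([1.0, 2.0, 3.0]))
# math.exp(1000.0) on its own raises OverflowError, but after the shift
# both exponents are exp(0.0) == 1.0.
print(softmax_3_pass([1000.0, 1000.0]))  # [0.5, 0.5]
```

The second call shows why the first pass exists: without subtracting the maximum, the exponentials of large inputs overflow a double.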

Parameters:

  • simd_width (Int): The SIMD vector width used when vectorizing the passes over the input.
  • dtype (DType): The element type of the input and output buffers.
  • origins (OriginSet): The OriginSet of the arguments captured by input_fn_1d.
  • input_fn_1d (def[_simd_width: Int](Int) capturing -> SIMD[dtype, _simd_width]): The elementwise input lambda. Given an index, it returns the _simd_width input elements starting at that index as a SIMD vector.
  • logsoftmax (Bool): If True, compute log-softmax, that is, apply an elementwise log() to the softmax outputs. Defaults to False.
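
The following Python sketch models how these parameters could fit into the three passes, under stated assumptions. `input_fn` is a hypothetical stand-in for input_fn_1d that takes the width as an ordinary argument (in Mojo it is a compile-time parameter). The tail left over when N is not a multiple of simd_width is read one element at a time. Log-softmax is evaluated as (x_i - max) - log(denom), which equals log() of the softmax output but avoids log(0) when an output underflows. The sketch does not claim to reproduce how the Mojo kernel handles either detail.

```python
import math
from typing import Callable, Iterator, List, Tuple

# Hypothetical stand-in for input_fn_1d: (width, start_index) -> `width`
# consecutive input elements beginning at start_index.
InputFn = Callable[[int, int], List[float]]


def softmax_3_pass_chunked(n: int, simd_width: int, input_fn: InputFn,
                           logsoftmax: bool = False) -> List[float]:
    """Three-pass softmax over n elements read through input_fn in chunks."""
    if n == 0:
        return []
    main = n - n % simd_width  # end of the last full simd_width chunk

    def chunks() -> Iterator[Tuple[int, int]]:
        # Full-width chunks first, then the scalar tail.
        for start in range(0, main, simd_width):
            yield start, simd_width
        for start in range(main, n):
            yield start, 1

    # Pass 1: maximum over all chunks.
    max_val = -math.inf
    for start, width in chunks():
        max_val = max(max_val, max(input_fn(width, start)))

    # Pass 2: shifted exponentials and their sum.
    out = [0.0] * n
    denom = 0.0
    for start, width in chunks():
        for k, v in enumerate(input_fn(width, start)):
            e = math.exp(v - max_val)
            out[start + k] = e
            denom += e

    # Pass 3: normalize, or for log-softmax emit (x - max) - log(denom).
    if logsoftmax:
        log_denom = math.log(denom)
        for start, width in chunks():
            for k, v in enumerate(input_fn(width, start)):
                out[start + k] = (v - max_val) - log_denom
    else:
        for i in range(n):
            out[i] /= denom
    return out


data = [0.5, 1.5, -2.0, 3.0, 0.0]


def load(width: int, idx: int) -> List[float]:
    # Reads `width` elements starting at idx from the list `data`.
    return data[idx:idx + width]


probs = softmax_3_pass_chunked(len(data), 4, load)  # one 4-wide chunk + 1 tail
logp = softmax_3_pass_chunked(len(data), 4, load, logsoftmax=True)
print(probs)
print(logp)
```

Because the tail is handled separately, the result in this sketch does not depend on simd_width. Only the number of elements read per call changes.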

Args: