
Mojo function

softmax_3_pass

softmax_3_pass[simd_width: Int, buffer_size: Dim, type: DType, origins: origin.set, input_fn_1d: fn[Int](Int) capturing -> SIMD[type, $0]](output: NDBuffer[type, 1, origin, __init__[::Intable](buffer_size)])

Computes the softmax of a single (unbatched) 1-D input, read element-wise through input_fn_1d, using the three-pass algorithm, and stores the result in output.
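For reference, here is the quantity the three passes compute, written as a formula (notation chosen here for illustration: x is the input, N the number of elements, m the maximum). Subtracting m does not change the result because the factor e^{-m} cancels between numerator and denominator. It does keep every exponent at or below 0, so exp cannot overflow.

```latex
\operatorname{softmax}(x)_i
  = \frac{e^{x_i - m}}{\sum_{j=0}^{N-1} e^{x_j - m}}
  = \frac{e^{-m}\, e^{x_i}}{e^{-m} \sum_{j=0}^{N-1} e^{x_j}}
  = \frac{e^{x_i}}{\sum_{j=0}^{N-1} e^{x_j}},
\qquad m = \max_{0 \le j < N} x_j
```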

The unbatched three-pass softmax is defined as follows, where N is the number of elements (buffer_size):

    procedure SoftmaxUnbatched(Input)
        maxVal = -∞
        denom = 0
        STEP 1: find the maximum input value
        for i = 0 to N-1 do
            maxVal = max(maxVal, Input[i])
        end for
        STEP 2: compute the shifted exponentials and their sum
        for i = 0 to N-1 do
            Output[i] = exp(Input[i] - maxVal)
            denom += Output[i]
        end for
        STEP 3: normalize
        for i = 0 to N-1 do
            Output[i] /= denom
        end for
    end procedure
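The procedure above translates almost line for line into a plain Python reference sketch. It is useful for checking results, but it is not the Mojo implementation. It works on an ordinary sequence of floats instead of input_fn_1d and NDBuffer, and the function name is reused here only for illustration.

```python
import math
from typing import List, Sequence


def softmax_3_pass(x: Sequence[float]) -> List[float]:
    """Three-pass softmax over a 1-D sequence (reference sketch, not the Mojo API)."""
    # Pass 1: find the maximum so that every exponent below is <= 0.
    max_val = -math.inf
    for v in x:
        max_val = max(max_val, v)

    # Pass 2: store exp(x[i] - max_val) in the output and accumulate their sum.
    out = [0.0] * len(x)
    denom = 0.0
    for i, v in enumerate(x):
        out[i] = math.exp(v - max_val)
        denom += out[i]

    # Pass 3: normalize every element by the sum.
    for i in range(len(out)):
        out[i] /= denom
    return out


print(softmax_3_pass([1000.0, 1000.0]))  # [0.5, 0.5]
```

A naive exp(1000.0) would raise OverflowError in Python. Because of the max subtraction in pass 1, the example above evaluates exp(0.0) instead.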

Parameters:

  • simd_width (Int): The SIMD vector width to use when vectorizing the loops.
  • buffer_size (Dim): The size of the input and output buffers.
  • type (DType): The type of the input and output buffers.
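  • origins (origin.set): The OriginSet of the arguments captured by input_fn_1d.
  • input_fn_1d (fn[Int](Int) capturing -> SIMD[type, $0]): The element-wise input lambda. Its parameter is a SIMD width and its argument is an element index; it returns that many consecutive input values starting at the index.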
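The sketch below shows one plausible way simd_width and input_fn_1d could work together. It rests on three assumptions, none confirmed by this page: the input is read in full simd_width chunks, any leftover tail is read one element at a time, and input_fn_1d is called in both pass 1 and pass 2. In the Python model, `input_fn(width, idx)` stands in for `input_fn_1d[width](idx)`. The helpers `make_input_fn` and `softmax_3_pass_chunked` are hypothetical names used only for illustration.

```python
import math
from typing import Callable, Iterator, List, Sequence, Tuple

# Hypothetical model of input_fn_1d: (width, idx) -> `width` consecutive values.
InputFn = Callable[[int, int], List[float]]


def make_input_fn(data: Sequence[float]) -> InputFn:
    """Hypothetical helper: wrap a sequence as an element-wise input lambda."""
    def input_fn(width: int, idx: int) -> List[float]:
        return list(data[idx:idx + width])
    return input_fn


def softmax_3_pass_chunked(simd_width: int, size: int,
                           input_fn: InputFn) -> List[float]:
    """Three-pass softmax that reads its input in simd_width-sized chunks (sketch)."""
    def chunks() -> Iterator[Tuple[int, int]]:
        # Full-width chunks first, then a scalar tail (assumed strategy).
        main = size - size % simd_width
        for i in range(0, main, simd_width):
            yield i, simd_width
        for i in range(main, size):
            yield i, 1

    # Pass 1: running maximum over each chunk returned by the lambda.
    max_val = -math.inf
    for i, w in chunks():
        max_val = max(max_val, max(input_fn(w, i)))

    # Pass 2: read the input again, store shifted exponentials, sum them.
    out = [0.0] * size
    denom = 0.0
    for i, w in chunks():
        for j, v in enumerate(input_fn(w, i)):
            out[i + j] = math.exp(v - max_val)
            denom += out[i + j]

    # Pass 3: normalize (element by element here, for simplicity).
    for i in range(size):
        out[i] /= denom
    return out


print(softmax_3_pass_chunked(4, 6, make_input_fn([0.0] * 6)))
```

With simd_width = 4 and six elements, the chunks are (0, 4), (4, 1), and (5, 1). The lambda therefore sees both the vector width and the scalar width 1, which is why its return width is a parameter rather than fixed.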

Args:

  • output (NDBuffer[type, 1, origin, __init__[::Intable](buffer_size)]): The output buffer in which to store the softmax values.
