
Mojo function

softmax_3_pass

softmax_3_pass[simd_width: Int, dtype: DType, origins: OriginSet, input_fn_1d: fn[_simd_width: Int](Int) capturing -> SIMD[dtype, _simd_width], logsoftmax: Bool = False](output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])

Performs an unbatched softmax on an input tensor using the three-pass algorithm.
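In formula form, write x for the input vector of length N, m for the maximum, and d for the sum of exponentials. These symbols are introduced here for illustration and correspond to maxVal and denom in the pseudocode that follows. Subtracting m leaves the result unchanged, because the factor e^{-m} cancels between numerator and denominator. It also keeps every exponent at or below 0, so no term can overflow. The last line is the value the softmax yields once log() is applied, which is what enabling logsoftmax computes.

```latex
\begin{aligned}
m &= \max_{0 \le j < N} x_j \\
d &= \sum_{j=0}^{N-1} e^{\,x_j - m} \\
\operatorname{softmax}(x)_i &= \frac{e^{\,x_i - m}}{d} = \frac{e^{\,x_i}}{\sum_{j=0}^{N-1} e^{\,x_j}} \\
\log \operatorname{softmax}(x)_i &= (x_i - m) - \log d
\end{aligned}
```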

The unbatched three-pass softmax is defined as follows, where N is the number of input elements:

procedure SoftmaxUnbatched(Input)
  maxVal = -∞
  denom = 0
  STEP 1: find the maximum input value
  for i = 0 to N - 1 do
    maxVal = max(maxVal, Input[i])
  end for
  STEP 2: exponentiate the shifted inputs and accumulate their sum
  for i = 0 to N - 1 do
    Output[i] = exp(Input[i] - maxVal)
    denom += Output[i]
  end for
  STEP 3: normalize the outputs
  for i = 0 to N - 1 do
    Output[i] /= denom
  end for
end procedure
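The three steps can be checked against a small scalar reference. The sketch below is plain Python, not the Mojo kernel. softmax_3_pass_ref is a hypothetical name. It operates on a Python list rather than a LayoutTensor and ignores SIMD vectorization entirely. For the logsoftmax path it computes (x - maxVal) - log(denom) instead of calling log() on the normalized outputs. The two are mathematically equal, but the former never evaluates log(0) when an output underflows. Whether the Mojo kernel does the same is not stated on this page.

```python
import math


def softmax_3_pass_ref(values, logsoftmax=False):
    """Hypothetical scalar reference for the unbatched three-pass softmax."""
    n = len(values)
    if n == 0:
        return []
    output = [0.0] * n

    # Pass 1: find the maximum input value.
    max_val = -math.inf
    for i in range(n):
        max_val = max(max_val, values[i])

    # Pass 2: exponentiate the shifted inputs and accumulate the denominator.
    denom = 0.0
    for i in range(n):
        output[i] = math.exp(values[i] - max_val)
        denom += output[i]

    # Pass 3: normalize, or produce log-softmax directly.
    # denom >= 1 here because the maximum element contributes exp(0) = 1.
    log_denom = math.log(denom)
    for i in range(n):
        if logsoftmax:
            # Equal to log(output[i] / denom), but never evaluates log(0).
            output[i] = (values[i] - max_val) - log_denom
        else:
            output[i] /= denom
    return output


if __name__ == "__main__":
    print(softmax_3_pass_ref([1.0, 2.0, 3.0]))
    print(softmax_3_pass_ref([1.0, 2.0, 3.0], logsoftmax=True))
```

Because step 1 subtracts the maximum, inputs such as [1000.0, 1001.0] produce finite results. By contrast, calling math.exp(1000.0) directly raises OverflowError in Python.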

Parameters:

  • simd_width (Int): The SIMD width to use for vectorization.
  • dtype (DType): The dtype of the input and output buffers.
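  • origins (OriginSet): The OriginSet of the arguments captured by input_fn_1d.
  • input_fn_1d (fn[_simd_width: Int](Int) capturing -> SIMD[dtype, _simd_width]): The elementwise input lambda. Given an element index, it returns a SIMD vector of _simd_width input values starting at that index.
  • logsoftmax (Bool): If True, applies an elementwise log() to the outputs after the softmax, so the function computes log-softmax.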
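The signature takes only an output argument, so the kernel reads its input exclusively by calling input_fn_1d with a compile-time width and a runtime index. The Python sketch below models that contract for step 1 under stated assumptions. make_input_fn and vectorized_max are hypothetical names, and a "SIMD vector" is modeled as a Python list of length width. The sketch also assumes full chunks of simd_width elements are read first and any remainder is read one element at a time. The actual tail handling inside softmax_3_pass is not documented on this page.

```python
import math


def make_input_fn(data):
    """Hypothetical stand-in for input_fn_1d: returns `width` values starting at `idx`."""
    def input_fn_1d(width, idx):
        # The closure captures `data`, playing the role of a capturing lambda.
        return data[idx:idx + width]
    return input_fn_1d


def vectorized_max(input_fn_1d, n, simd_width):
    """Step 1 of the softmax, reading the input only through input_fn_1d."""
    # One running maximum per SIMD lane.
    lanes = [-math.inf] * simd_width
    main_end = n - n % simd_width
    for idx in range(0, main_end, simd_width):
        chunk = input_fn_1d(simd_width, idx)
        lanes = [max(a, b) for a, b in zip(lanes, chunk)]
    # Reduce across lanes, then process the remainder with width 1.
    result = max(lanes)
    for idx in range(main_end, n):
        result = max(result, input_fn_1d(1, idx)[0])
    return result


if __name__ == "__main__":
    data = [3.0, -1.0, 7.5, 2.0, 0.0, 9.25, 4.0]
    print(vectorized_max(make_input_fn(data), len(data), 4))
```

One plausible motivation for reading input through a lambda is that a caller can fuse an upstream elementwise operation, such as a scale or bias, into the softmax without materializing an intermediate buffer.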

Args:

  • output (LayoutTensor): The output buffer in which to store the softmax values.
