Mojo function
softmax_3_pass
softmax_3_pass[simd_width: Int, dtype: DType, origins: OriginSet, input_fn_1d: fn[_simd_width: Int](Int) capturing -> SIMD[dtype, _simd_width], logsoftmax: Bool = False](output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])
Performs an unbatched softmax on an input tensor using the three-pass algorithm.
The unbatched three-pass softmax over an input of N elements is defined as:
procedure SoftmaxUnbatched(Input)
    maxVal = -∞
    denom = 0
    STEP 1: find the maximum input value
    for i = 0 to N - 1 do
        maxVal = max(maxVal, Input[i])
    end for
    STEP 2: compute the shifted exponentials and their sum
    for i = 0 to N - 1 do
        Output[i] = exp(Input[i] - maxVal)
        denom += Output[i]
    end for
    STEP 3: normalize
    for i = 0 to N - 1 do
        Output[i] /= denom
    end for
end procedure
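The procedure maps directly onto the plain-Python reference below, which can be handy for checking results on small inputs. It is a scalar sketch of the algorithm, not the Mojo API: it takes a list instead of input_fn_1d and returns a list instead of writing into a LayoutTensor. The logsoftmax flag applies log() after normalization, as the parameter description states.

```python
import math


def softmax_3_pass(values, logsoftmax=False):
    """Scalar reference for the three-pass softmax over a 1-D list of floats."""
    n = len(values)
    output = [0.0] * n

    # STEP 1: find the maximum value (used to keep exp() from overflowing).
    max_val = -math.inf
    for x in values:
        max_val = max(max_val, x)

    # STEP 2: exponentiate the shifted inputs and accumulate the denominator.
    denom = 0.0
    for i in range(n):
        output[i] = math.exp(values[i] - max_val)
        denom += output[i]

    # STEP 3: normalize so the outputs sum to 1.
    for i in range(n):
        output[i] /= denom
        if logsoftmax:
            # Elementwise log() applied after the softmax.
            output[i] = math.log(output[i])
    return output


probs = softmax_3_pass([1.0, 2.0, 3.0])
print([round(p, 4) for p in probs])  # [0.09, 0.2447, 0.6652]
```

Subtracting maxVal in STEP 2 is what makes large inputs safe: softmax_3_pass([1000.0, 1000.0]) returns [0.5, 0.5] instead of overflowing.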
Parameters:
- simd_width (Int): The SIMD width to use in vectorization.
- dtype (DType): The dtype of the input and output buffers.
- origins (OriginSet): The OriginSet of the arguments captured by input_fn_1d.
- input_fn_1d (fn[_simd_width: Int](Int) capturing -> SIMD[dtype, _simd_width]): The elementwise input lambda.
- logsoftmax (Bool): If True, apply an elementwise log() to the outputs after the softmax.
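For reference, the three passes compute the max-shifted softmax below, writing x_i for Input[i] and using maxVal and denom from the procedure. Subtracting maxVal leaves the result unchanged, because the factor e^{-maxVal} cancels between numerator and denominator, but it keeps every exponent at most 0 so exp() cannot overflow. The last line is the mathematical value stored when logsoftmax is enabled; this page does not say whether the implementation evaluates it through this identity or by taking log() of the normalized output.

```latex
\begin{aligned}
\mathrm{maxVal} &= \max_{0 \le j < N} x_j, \\
\mathrm{denom} &= \sum_{j=0}^{N-1} e^{x_j - \mathrm{maxVal}}, \\
\mathrm{Output}_i &= \frac{e^{x_i - \mathrm{maxVal}}}{\mathrm{denom}}, \\
\log \mathrm{Output}_i &= x_i - \mathrm{maxVal} - \log \mathrm{denom}.
\end{aligned}
```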
Args:
- output (LayoutTensor): The output buffer in which to store the softmax values.
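To illustrate how simd_width and input_fn_1d can fit together, here is a plain-Python analogy of STEP 1 reading its input through a callback in chunks of simd_width elements. It rests on several assumptions: read_input and max_pass are hypothetical names, the width is passed as an ordinary argument where the Mojo signature passes it as the compile-time _simd_width parameter, and the "full vectors, then a one-element tail" loop is a common vectorization pattern rather than something this page specifies about the implementation.

```python
import math
from typing import Callable, List

# Hypothetical stand-in for input_fn_1d: (width, index) -> `width` elements
# starting at `index`.
InputFn = Callable[[int, int], List[float]]


def max_pass(input_fn: InputFn, n: int, simd_width: int) -> float:
    """STEP 1 of the softmax, reading the input in simd_width-wide chunks."""
    max_val = -math.inf
    i = 0
    # Main loop: as many full chunks of simd_width elements as fit.
    while i + simd_width <= n:
        chunk = input_fn(simd_width, i)
        max_val = max(max_val, max(chunk))
        i += simd_width
    # Tail: any remaining elements, read with a width of 1.
    while i < n:
        max_val = max(max_val, input_fn(1, i)[0])
        i += 1
    return max_val


data = [0.5, -2.0, 3.25, 1.0, 7.5, -1.0, 2.0]


def read_input(width: int, idx: int) -> List[float]:
    # A closure over `data`, loosely analogous to input_fn_1d capturing
    # the arguments that `origins` describes.
    return data[idx : idx + width]


print(max_pass(read_input, len(data), simd_width=4))  # 7.5
```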