Mojo struct
Softmax
struct Softmax[dtype: DType, num_m_mmas: Int, num_n_mmas: Int, num_warps_m: Int, num_warps_n: Int, mma_m: Int, use_exp2: Bool = False]
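Judging from the parameters and the method names below (calculate_qk_max, exp, calculate_correction, update_sum, update_output), this struct appears to implement the online (streaming) softmax used in flash-attention-style kernels: score tiles arrive one MMA block at a time, and a running row max and row sum are corrected as each tile lands. A minimal NumPy sketch of that general pattern (illustrative only; names and shapes here are not this Mojo API):

```python
import numpy as np

def online_softmax(score_blocks):
    """Softmax over one row streamed in blocks: one pass, O(block) memory."""
    row_max = -np.inf            # running max (cf. rowmax_tensor)
    row_sum = 0.0                # running sum (cf. rowsum_tensor)
    outputs = []
    for block in score_blocks:
        new_max = max(row_max, block.max())          # cf. calculate_qk_max / update_max
        correction = np.exp(row_max - new_max)       # cf. calculate_correction
        p = np.exp(block - new_max)                  # stable exp: subtract the max
        row_sum = row_sum * correction + p.sum()     # rescale old sum, add new one
        outputs = [o * correction for o in outputs]  # cf. update_output
        outputs.append(p)
        row_max = new_max
    return np.concatenate(outputs) / row_sum         # final normalization

x = np.random.randn(64).astype(np.float32)
ref = np.exp(x - x.max()); ref /= ref.sum()
assert np.allclose(online_softmax(np.split(x, 4)), ref, atol=1e-6)
```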
Fields
- rowmax_tensor (Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].RowMaxTensorType)
- rowsum_tensor (Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].RowSumTensorType)
- score_frag_rowmax (Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].ScoreFragTensorType)
- score_frag_rowsum (Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].ScoreFragTensorType)
- correction (Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].ScoreFragTensorType)
Implemented traits
AnyType, ImplicitlyDestructible
comptime members
exp_function
comptime exp_function = _exp2_concrete[?, ?] if use_exp2 else _exp_concrete[?, ?]
frag_is_row_vector
comptime frag_is_row_vector = True
frag_num_rows
comptime frag_num_rows = ComptimeInt[1].static_value
frag_size
comptime frag_size = Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].FragmentLayoutT.static_product
FragmentLayoutT
comptime FragmentLayoutT = Layout[*?, *?]
num_colwise_lanes
comptime num_colwise_lanes = SIMD(Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].warp_rows)
num_colwise_tiles
comptime num_colwise_tiles = num_m_mmas
num_colwise_warps
comptime num_colwise_warps = num_warps_m
num_rows_per_thread
comptime num_rows_per_thread = (Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].num_colwise_tiles * Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].frag_num_rows)
num_rowwise_lanes
comptime num_rowwise_lanes = SIMD(Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].warp_cols)
num_rowwise_tiles
comptime num_rowwise_tiles = num_n_mmas
num_rowwise_warps
comptime num_rowwise_warps = num_warps_n
num_shuffles_per_row
comptime num_shuffles_per_row = log2_floor(Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].warp_cols)
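The log2_floor(warp_cols) here is the usual shuffle count for a butterfly reduction across the lanes of a row. A pure-Python stand-in (hypothetical; real kernel code would use warp shuffle intrinsics) showing why reducing warp_cols lanes takes that many steps:

```python
def row_reduce_max(lanes):
    """Butterfly max-reduction over a power-of-two lane count."""
    n = len(lanes)                  # warp_cols, a power of two
    steps = n.bit_length() - 1      # log2_floor(warp_cols) shuffle rounds
    vals = list(lanes)
    for s in range(steps):
        off = 1 << s                # xor-shuffle partner distance this round
        vals = [max(vals[i], vals[i ^ off]) for i in range(n)]
    return vals                     # every lane now holds the row max

assert row_reduce_max([3, 7, 1, 5]) == [7, 7, 7, 7]
```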
row_layout
comptime row_layout = row_major[num_m_mmas, Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].frag_num_rows]()
RowMaxTensorType
comptime RowMaxTensorType = TileTensor[dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
RowSumTensorType
comptime RowSumTensorType = Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].RowMaxTensorType
rowwise_lanes_stride
comptime rowwise_lanes_stride = SIMD(ComptimeInt[(ComptimeInt[32 if (mma_m == 32) else 16].static_value * ComptimeInt[1].static_value)].static_value)
score_frag_layout
comptime score_frag_layout = row_major[Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].num_colwise_tiles, Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].frag_num_rows]()
ScoreFragTensorType
comptime ScoreFragTensorType = TileTensor[dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
warp_cols
comptime warp_cols = ComptimeInt[2 if (mma_m == 32) else 4].static_value
warp_rows
comptime warp_rows = ComptimeInt[32 if (mma_m == 32) else 16].static_value
WarpLayoutT
comptime WarpLayoutT = Layout[*?, *?]
Methods
__init__
__init__(out self)
calculate_qk_max
calculate_qk_max(self, score: TileTensor[dtype, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], warp_scratch: TileTensor[dtype, address_space=warp_scratch.address_space, linear_idx_type=warp_scratch.linear_idx_type, element_size=warp_scratch.element_size])
calculate_qk_sum
calculate_qk_sum(self, score: TileTensor[dtype, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], warp_scratch: TileTensor[dtype, address_space=warp_scratch.address_space, linear_idx_type=warp_scratch.linear_idx_type, element_size=warp_scratch.element_size])
exp
exp[start: Int = 0, stride: Int = 1](self, score: TileTensor[dtype, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size])
scale_rowmax
scale_rowmax(self, scale: Scalar[dtype])
Scale score_frag_rowmax by the scale factor (e.g. scale * log2e).
Must be called after exp_scaled so that score_frag_rowmax is in the same units as rowmax_tensor for calculate_correction.
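A small plain-Python sketch (hypothetical values) of this unit-matching requirement: with use_exp2, the correction is exp2 of a difference of scaled maxima, so an unscaled score_frag_rowmax would produce the wrong factor.

```python
import math

LOG2E = math.log2(math.e)     # the log2e from "scale * log2e" above
scale = 0.125                 # hypothetical softmax scale, e.g. 1/sqrt(head_dim)
s = scale * LOG2E             # combined exp2-units scale factor

old_max, new_max = 7.0, 9.0   # unscaled running max vs. new block max

good = 2.0 ** (old_max * s - new_max * s)   # both maxima in scaled units
bad = 2.0 ** (old_max * s - new_max)        # mixed units: wrong correction

# In exp2 units the correct factor equals exp(scale * (old_max - new_max)):
assert math.isclose(good, math.exp(scale * (old_max - new_max)))
assert not math.isclose(good, bad)
```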
exp_scaled
exp_scaled[start: Int = 0, stride: Int = 1](self, score: TileTensor[dtype, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], scale: Scalar[dtype])
Numerically stable scaled exp: exp2((score - max) * scale).
Subtracts the unscaled max before scaling, so the subtraction is exact for the maximum element (IEEE 754 guarantees a - a == 0). This avoids the precision gap in exp_fma, where fma(score, scale, -scaled_max) can produce nonzero results when score == max, because scaled_max is rounded independently of the exact product computed inside the fma.
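To make that precision gap concrete: a NumPy sketch with arbitrary values, emulating the float32 fma in float64 (a float64 represents the exact product of any two float32s, so this mirrors fma's single final rounding).

```python
import numpy as np

score = np.float32(11.3)   # pick the row-max element itself: score == max
maxv = score
scale = np.float32(0.1)    # hypothetical softmax scale

# exp_scaled path: subtract the unscaled max first. IEEE 754 gives
# score - maxv == 0 exactly, so the exp2 argument is exactly zero.
stable_arg = (score - maxv) * scale
assert stable_arg == 0.0                     # exp2(0) == 1.0 exactly

# exp_fma path: scaled_max is rounded on its own ...
scaled_max = maxv * scale
# ... while fma(score, scale, -scaled_max) uses the exact product, leaving
# just the rounding error of scaled_max as the argument:
fma_arg = np.float32(np.float64(score) * np.float64(scale) - np.float64(scaled_max))
assert fma_arg != 0.0                        # ~ -6e-8 for these values
# exp2 of that tiny nonzero argument need not round back to exactly 1.0.
```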
calculate_correction
calculate_correction(self)
update_output
update_output(self, output: TileTensor[dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size])
update_sum
update_sum(self)
apply_sum_correction
apply_sum_correction(self)
Apply rowsum *= correction (deferred sum rescale pattern).
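A plain-Python sketch (hypothetical values, exp units; the exp2 case is identical) of why the running sum must be rescaled when the row max grows:

```python
import math

x_old = [1.0, 2.0]            # already-seen scores
old_max, new_max = 2.0, 5.0   # a new block raised the max from 2.0 to 5.0

rowsum = sum(math.exp(v - old_max) for v in x_old)    # sum in old-max units

correction = math.exp(old_max - new_max)
rowsum *= correction                                  # apply_sum_correction

true_sum = sum(math.exp(v - new_max) for v in x_old)  # recompute in new units
assert math.isclose(rowsum, true_sum)
```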
update_sum_additive
update_sum_additive(self)
Additive rowsum update: rowsum += new_sum (no correction).
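The two deferred steps compose to the standard fused online-softmax update; a tiny sketch with hypothetical values:

```python
rowsum, new_sum, correction = 4.0, 2.5, 0.25   # arbitrary illustrative values

fused = rowsum * correction + new_sum   # one-step online-softmax update

rowsum *= correction    # apply_sum_correction
rowsum += new_sum       # update_sum_additive
assert rowsum == fused  # exact here: same float ops in the same order
```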
update_max
update_max(self)
full
full(self, output: TileTensor[dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], score: TileTensor[dtype, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], warp_scratch: TileTensor[dtype, address_space=warp_scratch.address_space, linear_idx_type=warp_scratch.linear_idx_type, element_size=warp_scratch.element_size])