Mojo struct
Softmax
struct Softmax[dtype: DType, num_m_mmas: Int, num_n_mmas: Int, num_warps_m: Int, num_warps_n: Int, mma_m: Int, use_exp2: Bool = False]
Fields
- rowmax_tensor (Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].RowMaxTensorType)
- rowsum_tensor (Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].RowSumTensorType)
- score_frag_rowmax (Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].ScoreFragTensorType)
- score_frag_rowsum (Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].ScoreFragTensorType)
- correction (Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].ScoreFragTensorType)
Implemented traits
comptime members
exp_function
comptime exp_function = _exp2_concrete[_, _] if use_exp2 else _exp_concrete[_, _]
frag_is_row_vector
comptime frag_is_row_vector = True
frag_num_rows
comptime frag_num_rows = ComptimeInt[Int(1)].static_value
frag_size
comptime frag_size = Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].FragmentLayoutT.static_product
FragmentLayoutT
comptime FragmentLayoutT = Layout[*?, *?]
num_colwise_lanes
comptime num_colwise_lanes = SIMD(Int(32) if (eq mma_m, 32) else Int(16))
num_colwise_tiles
comptime num_colwise_tiles = num_m_mmas
num_colwise_warps
comptime num_colwise_warps = num_warps_m
num_rows_per_thread
comptime num_rows_per_thread = (Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].num_colwise_tiles * Int(1))
num_rowwise_lanes
comptime num_rowwise_lanes = SIMD(Int(2) if (eq mma_m, 32) else Int(4))
num_rowwise_tiles
comptime num_rowwise_tiles = num_n_mmas
num_rowwise_warps
comptime num_rowwise_warps = num_warps_n
num_shuffles_per_row
comptime num_shuffles_per_row = log2_floor(Int(2) if (eq mma_m, 32) else Int(4))
row_layout
comptime row_layout = row_major[num_m_mmas, Int(1)]()
RowMaxTensorType
comptime RowMaxTensorType = TileTensor[dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]
RowSumTensorType
comptime RowSumTensorType = Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].RowMaxTensorType
rowwise_lanes_stride
comptime rowwise_lanes_stride = SIMD(Int(32) if (eq mma_m, 32) else Int(16))
score_frag_layout
comptime score_frag_layout = row_major[Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].num_colwise_tiles, Int(1)]()
ScoreFragTensorType
comptime ScoreFragTensorType = TileTensor[dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]
warp_cols
comptime warp_cols = ComptimeInt[Int(2) if (eq mma_m, 32) else Int(4)].static_value
warp_rows
comptime warp_rows = ComptimeInt[Int(32) if (eq mma_m, 32) else Int(16)].static_value
WarpLayoutT
comptime WarpLayoutT = Layout[*?, *?]
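For mma_m == 16, the lane constants above describe a 16 × 4 = 64-lane arrangement: 16 column-wise lanes, 4 row-wise lanes that sit 16 lanes apart, and log2_floor(4) = 2 shuffles per row. The Python sketch below models one way a row-wise max could be reduced in that many steps. The XOR-butterfly exchange pattern and the helper name `butterfly_row_max` are assumptions for illustration; this page does not state how the struct actually performs the reduction.

```python
def butterfly_row_max(lane_values, num_rowwise_lanes, rowwise_lanes_stride):
    """Model a XOR-butterfly max across the lanes that share a row.

    Hypothetical helper: lane `l` exchanges with lane `l ^ offset`, where the
    offset starts at `rowwise_lanes_stride` and doubles each step.
    """
    values = list(lane_values)
    num_shuffles = num_rowwise_lanes.bit_length() - 1  # log2_floor
    offset = rowwise_lanes_stride
    for _ in range(num_shuffles):
        # Every lane reads its partner's value from the previous step.
        values = [max(v, values[lane ^ offset]) for lane, v in enumerate(values)]
        offset *= 2
    return values, num_shuffles


# mma_m == 16 case: 16 column-wise lanes x 4 row-wise lanes, stride 16.
lanes = [float((7 * lane) % 64) for lane in range(64)]
reduced, steps = butterfly_row_max(lanes, num_rowwise_lanes=4, rowwise_lanes_stride=16)
```

After the two steps, every lane holds the max over the four lanes that share its index modulo 16. With the mma_m == 32 constants (2 row-wise lanes, stride 32) the same helper performs a single step.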
Methods
__init__
def __init__(out self)
calculate_qk_max
def calculate_qk_max(self, score: TileTensor[dtype, Storage=score.Storage, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], warp_scratch: TileTensor[dtype, Storage=warp_scratch.Storage, address_space=warp_scratch.address_space, linear_idx_type=warp_scratch.linear_idx_type, element_size=warp_scratch.element_size])
calculate_qk_sum
def calculate_qk_sum(self, score: TileTensor[dtype, Storage=score.Storage, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], warp_scratch: TileTensor[dtype, Storage=warp_scratch.Storage, address_space=warp_scratch.address_space, linear_idx_type=warp_scratch.linear_idx_type, element_size=warp_scratch.element_size])
exp
def exp[start: Int = Int(0), stride: Int = Int(1)](self, score: TileTensor[dtype, Storage=score.Storage, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size])
scale_rowmax
def scale_rowmax(self, scale: Scalar[dtype])
Scale score_frag_rowmax by a scale factor (e.g. scale * log2e).
Must be called after exp_scaled so that score_frag_rowmax is in the same units as rowmax_tensor for calculate_correction.
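The example factor scale * log2e follows from the identity exp(x · s) = exp2(x · s · log2e), which lets a kernel use exp2 in place of exp. The Python sketch below checks that identity and shows why a max taken on unscaled scores has to be multiplied by the same factor before it is combined with a running max stored in exp2 units. All variable names are illustrative, and the correction formula exp2(old_max - new_max) is an assumption; this page does not define it.

```python
import math

LOG2E = math.log2(math.e)

softmax_scale = 0.125          # hypothetical attention scale, e.g. 1/sqrt(64)
scale = softmax_scale * LOG2E  # the "scale * log2e" factor from the description

scores = [3.0, -1.5, 7.25, 0.5]
raw_max = max(scores)          # row max in unscaled score units

# exp2 on log2e-scaled differences reproduces the natural exponential.
via_exp2 = [2.0 ** ((s - raw_max) * scale) for s in scores]
via_exp = [math.exp((s - raw_max) * softmax_scale) for s in scores]

# Scaling the max puts it in the same units as a running max stored in the
# exp2 domain, so the difference of the two is a valid exp2 exponent.
scaled_max = raw_max * scale
prev_scaled_max = 5.0 * scale  # hypothetical earlier running max, already scaled
correction = 2.0 ** (prev_scaled_max - scaled_max)  # assumed correction formula
```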
exp_scaled
def exp_scaled[start: Int = Int(0), stride: Int = Int(1)](self, score: TileTensor[dtype, Storage=score.Storage, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], scale: Scalar[dtype])
Numerically stable scaled exp: exp2((score - max) * scale).
Subtracts the unscaled max before scaling, so the subtraction is exact for the maximum element (IEEE 754 guarantees a - a == 0). This avoids the precision gap in exp_fma where fma(score, scale, -scaled_max) can produce nonzero results when score == max due to independent rounding of scaled_max.
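A minimal Python model of the formula above, using doubles rather than the kernel's dtype. The function name `exp_scaled_row` is hypothetical and the code only mirrors the arithmetic, not the TileTensor-based method.

```python
def exp_scaled_row(scores, scale):
    """Return exp2((score - max) * scale) for one row of scores."""
    row_max = max(scores)
    return [2.0 ** ((s - row_max) * scale) for s in scores]


row = [0.1, -2.0, 0.05]
out = exp_scaled_row(row, scale=0.3)
# The maximum element maps to exactly 1.0 because (max - max) is exactly 0,
# whatever the value of scale.
```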
exp_pkfma
def exp_pkfma[start: Int = Int(0), stride: Int = Int(1)](self, score: TileTensor[dtype, Storage=score.Storage, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], scale: Scalar[dtype])
Scaled exp using fused score * scale + (-max * scale) form so the compiler emits v_pk_fma_f32 instead of separate add+mul.
Mirrors exp_scaled but pre-multiplies -max by scale once per row (1 packed mul per row) and uses fma(score, scale, neg_scaled_max) for every score pair (matches aiter's softmax inner loop).
Precision: score == max produces fma(max, scale, -scaled_max) where scaled_max is pre-rounded; the FMA forms max*scale without intermediate rounding, so the product can differ slightly from the pre-rounded value, yielding a tiny epsilon instead of exactly 0. exp2(epsilon) ≈ 1 + epsilon·ln(2), well within the FP8 softmax tolerance budget (the row-sum normalization absorbs it).
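The precision note can be reproduced with a small Python model. The sketch emulates a fused multiply-add with exact rational arithmetic and uses Python doubles instead of f32, so the size of the residual differs from the GPU case, but the mechanism is the same. The `fma` helper and the input values are assumptions chosen for illustration.

```python
from fractions import Fraction


def fma(a, b, c):
    """Emulate a fused multiply-add: a * b + c with a single final rounding."""
    return float(Fraction(a) * Fraction(b) + Fraction(c))


row_max, scale = 0.1, 0.3
neg_scaled_max = -(row_max * scale)  # product rounded once, ahead of the loop

# Exponent for the element equal to the max, in the fused form.
eps = fma(row_max, scale, neg_scaled_max)
# Same element in the subtract-then-scale form used by exp_scaled.
exact = (row_max - row_max) * scale

fused_result = 2.0 ** eps     # within rounding distance of 1.0
stable_result = 2.0 ** exact  # exactly 1.0
```

Here `eps` is the rounding error of `row_max * scale`, which is at most half a unit in the last place of the product, while `exact` is 0.0.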
calculate_correction
def calculate_correction(self)
update_output
def update_output(self, output: TileTensor[dtype, Storage=output.Storage, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size])
update_sum
def update_sum(self)
apply_sum_correction
def apply_sum_correction(self)
Apply rowsum *= correction (deferred sum rescale pattern).
update_sum_additive
def update_sum_additive(self)
Additive rowsum update: rowsum += new_sum (no correction).
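The two one-line descriptions above match the usual online-softmax bookkeeping, in which the running row sum is first rescaled to the new max and the new block's sum is then added. The Python sketch below models that split on a single row. It assumes correction = exp(old_max - new_max); that definition and all variable names are assumptions, since this page does not say how `correction` is computed.

```python
import math

blocks = [[1.0, 3.0, -2.0], [4.0, 0.5, 2.5]]  # two score blocks of one row

running_max = -math.inf
rowsum = 0.0
for block in blocks:
    new_max = max(running_max, max(block))
    # Assumed correction factor; it is 0.0 for the first block, where rowsum is 0.
    correction = math.exp(running_max - new_max)
    new_sum = sum(math.exp(s - new_max) for s in block)

    rowsum *= correction  # apply_sum_correction-style step
    rowsum += new_sum     # update_sum_additive-style step
    running_max = new_max

# Reference: softmax denominator computed over the whole row at once.
all_scores = [s for block in blocks for s in block]
reference = sum(math.exp(s - max(all_scores)) for s in all_scores)
```

Splitting the update into two steps lets a kernel apply the rescale at a different point in its schedule than the addition, which is one reading of the "deferred sum rescale pattern" mentioned above.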
update_max
def update_max(self)
full
def full(self, output: TileTensor[dtype, Storage=output.Storage, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], score: TileTensor[dtype, Storage=score.Storage, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], warp_scratch: TileTensor[dtype, Storage=warp_scratch.Storage, address_space=warp_scratch.address_space, linear_idx_type=warp_scratch.linear_idx_type, element_size=warp_scratch.element_size])