
Mojo struct

Softmax

struct Softmax[dtype: DType, num_m_mmas: Int, num_n_mmas: Int, num_warps_m: Int, num_warps_n: Int, mma_m: Int, use_exp2: Bool = False]

Fields

  • rowmax_tensor (Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].RowMaxTensorType)
  • rowsum_tensor (Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].RowSumTensorType)
  • score_frag_rowmax (Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].ScoreFragTensorType)
  • score_frag_rowsum (Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].ScoreFragTensorType)
  • correction (Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].ScoreFragTensorType)

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

exp_function

comptime exp_function = _exp2_concrete[_, _] if use_exp2 else _exp_concrete[_, _]
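The exp_function member switches between a natural and a base-2 exponential. The Python sketch below is a scalar model of the math, not the Mojo implementation. It shows the identity that makes the exp2 path interchangeable, exp(x) = exp2(x · log2(e)), assuming the caller folds log2(e) into the scale as the "scale * log2e" example in the scale_rowmax docstring suggests. LOG2E and scaled_exp are hypothetical names.

```python
import math

# log2(e): multiplying by this converts a natural-exponent argument to base 2.
LOG2E = 1.0 / math.log(2.0)


def scaled_exp(x: float, scale: float, use_exp2: bool) -> float:
    """Compute exp(x * scale), optionally via a base-2 exponential."""
    if use_exp2:
        # Fold log2(e) into the scale once; GPUs typically expose a native
        # base-2 exponential instruction, so exp2 is the cheaper primitive.
        return 2.0 ** (x * (scale * LOG2E))
    return math.exp(x * scale)


for x in (-3.0, -0.5, 0.0, 1.25):
    a = scaled_exp(x, 0.125, use_exp2=False)
    b = scaled_exp(x, 0.125, use_exp2=True)
    print(f"x={x:+.2f}  exp={a:.12f}  exp2={b:.12f}")
```

Both columns agree to within floating-point rounding. This is why a kernel can select either helper at compile time without changing the softmax result, provided the scale passed in matches the chosen base.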

frag_is_row_vector

comptime frag_is_row_vector = True

frag_num_rows

comptime frag_num_rows = ComptimeInt[Int(1)].static_value

frag_size

comptime frag_size = Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].FragmentLayoutT.static_product

FragmentLayoutT

comptime FragmentLayoutT = Layout[*?, *?]

num_colwise_lanes

comptime num_colwise_lanes = SIMD(Int(32) if mma_m == 32 else Int(16))

num_colwise_tiles

comptime num_colwise_tiles = num_m_mmas

num_colwise_warps

comptime num_colwise_warps = num_warps_m

num_rows_per_thread

comptime num_rows_per_thread = (Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].num_colwise_tiles * Int(1))

num_rowwise_lanes

comptime num_rowwise_lanes = SIMD(Int(2) if mma_m == 32 else Int(4))

num_rowwise_tiles

comptime num_rowwise_tiles = num_n_mmas

num_rowwise_warps

comptime num_rowwise_warps = num_warps_n

num_shuffles_per_row

comptime num_shuffles_per_row = log2_floor(Int(2) if mma_m == 32 else Int(4))

row_layout

comptime row_layout = row_major[num_m_mmas, Int(1)]()

RowMaxTensorType

comptime RowMaxTensorType = TileTensor[dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

RowSumTensorType

comptime RowSumTensorType = Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].RowMaxTensorType

rowwise_lanes_stride

comptime rowwise_lanes_stride = SIMD(Int(32) if mma_m == 32 else Int(16))

score_frag_layout

comptime score_frag_layout = row_major[Softmax[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, mma_m, use_exp2].num_colwise_tiles, Int(1)]()

ScoreFragTensorType

comptime ScoreFragTensorType = TileTensor[dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

warp_cols

comptime warp_cols = ComptimeInt[Int(2) if mma_m == 32 else Int(4)].static_value

warp_rows

comptime warp_rows = ComptimeInt[Int(32) if mma_m == 32 else Int(16)].static_value

WarpLayoutT

comptime WarpLayoutT = Layout[*?, *?]
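Several of the members above depend only on mma_m. The Python sketch below recomputes them for mma_m of 16 or 32 and simulates an XOR-shuffle row-max reduction using num_rowwise_lanes, rowwise_lanes_stride, and num_shuffles_per_row. It rests on two assumptions that this page does not document: a 64-lane wavefront, and that lanes holding pieces of the same row sit rowwise_lanes_stride apart. The function names are hypothetical, and this is not the Mojo kernel's actual shuffle code.

```python
import math


def softmax_layout_constants(mma_m: int) -> dict:
    """Mirror the mma_m-dependent comptime members (mma_m is 16 or 32)."""
    colwise = 32 if mma_m == 32 else 16   # num_colwise_lanes / warp_rows
    rowwise = 2 if mma_m == 32 else 4     # num_rowwise_lanes / warp_cols
    return {
        "num_colwise_lanes": colwise,
        "num_rowwise_lanes": rowwise,
        "rowwise_lanes_stride": colwise,
        # log2_floor; exact here because rowwise is a power of 2
        "num_shuffles_per_row": int(math.log2(rowwise)),
    }


def butterfly_row_max(lane_vals: list, stride: int, num_shuffles: int) -> list:
    """Simulate an XOR-shuffle max reduction across all lanes at once.

    ASSUMPTION: lanes sharing a row are `stride` apart, so the k-th
    shuffle exchanges values with lane ^ (stride << k).
    """
    vals = list(lane_vals)
    for k in range(num_shuffles):
        offset = stride << k
        # Every lane reads its partner's value from the previous step.
        vals = [max(vals[lane], vals[lane ^ offset]) for lane in range(len(vals))]
    return vals


WAVE = 64  # ASSUMPTION: 64-lane wavefront (num_colwise_lanes * num_rowwise_lanes)
for mma_m in (16, 32):
    c = softmax_layout_constants(mma_m)
    assert c["num_colwise_lanes"] * c["num_rowwise_lanes"] == WAVE
    vals = [float((lane * 37) % 101) for lane in range(WAVE)]
    out = butterfly_row_max(vals, c["rowwise_lanes_stride"], c["num_shuffles_per_row"])
    # After the shuffles, each lane holds the max over the lanes sharing its row.
    for lane in range(WAVE):
        row = lane % c["rowwise_lanes_stride"]
        assert out[lane] == max(vals[row::c["rowwise_lanes_stride"]])
    print(mma_m, c)
```

With mma_m == 32 a single shuffle (offset 32) suffices. With mma_m == 16, two shuffles (offsets 16 and 32) cover all four lanes of each row, which matches num_shuffles_per_row = log2_floor(num_rowwise_lanes).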

Methods

__init__

def __init__(out self)

calculate_qk_max

def calculate_qk_max(self, score: TileTensor[dtype, Storage=score.Storage, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], warp_scratch: TileTensor[dtype, Storage=warp_scratch.Storage, address_space=warp_scratch.address_space, linear_idx_type=warp_scratch.linear_idx_type, element_size=warp_scratch.element_size])

calculate_qk_sum

def calculate_qk_sum(self, score: TileTensor[dtype, Storage=score.Storage, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], warp_scratch: TileTensor[dtype, Storage=warp_scratch.Storage, address_space=warp_scratch.address_space, linear_idx_type=warp_scratch.linear_idx_type, element_size=warp_scratch.element_size])

exp

def exp[start: Int = Int(0), stride: Int = Int(1)](self, score: TileTensor[dtype, Storage=score.Storage, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size])

scale_rowmax

def scale_rowmax(self, scale: Scalar[dtype])

Scale score_frag_rowmax by a scale factor (e.g. scale * log2e).

Must be called after exp_scaled so that score_frag_rowmax is in the same units as rowmax_tensor for calculate_correction.
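The ordering constraint can be illustrated with scalar values. This Python sketch assumes, based on the docstrings, three things: exp_scaled consumes the unscaled tile max; rowmax_tensor holds maxima already multiplied by scale (with log2(e) folded in); and the correction factor is exp2(old_max - new_max) in those scaled units. The correction formula and the function name are assumptions for illustration, not documented behavior.

```python
import math

LOG2E = 1.0 / math.log(2.0)


def correction_scaled_units(prev_scaled_max: float, tile_max: float, scale: float) -> float:
    """Correction factor when both maxima are compared in scaled (base-2) units.

    ASSUMPTION: scale already includes log2(e), so exp2 of a scaled
    difference equals exp of the unscaled difference times softmax_scale.
    """
    tile_scaled_max = tile_max * scale          # the role scale_rowmax plays
    new_scaled_max = max(prev_scaled_max, tile_scaled_max)
    return 2.0 ** (prev_scaled_max - new_scaled_max)


softmax_scale = 0.125
scale = softmax_scale * LOG2E
prev_max, tile_max = 2.0, 6.0

good = correction_scaled_units(prev_max * scale, tile_max, scale)
reference = math.exp((prev_max - max(prev_max, tile_max)) * softmax_scale)
# Skipping the rescale compares a scaled running max with an unscaled tile max.
bad = 2.0 ** (prev_max * scale - max(prev_max * scale, tile_max))
print(good, reference, bad)
```

Here good matches reference (both are exp(-0.5)), while bad is far too small. That mismatch is the unit error the docstring guards against.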

exp_scaled

def exp_scaled[start: Int = Int(0), stride: Int = Int(1)](self, score: TileTensor[dtype, Storage=score.Storage, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], scale: Scalar[dtype])

Numerically stable scaled exp: exp2((score - max) * scale).

Subtracts the unscaled max before scaling, so the subtraction is exact for the maximum element (IEEE 754 guarantees a - a == 0 for finite a). This avoids the precision gap in exp_fma, where fma(score, scale, -scaled_max) can produce a nonzero result when score == max because scaled_max is rounded independently.
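The exactness claim can be checked with a scalar Python model of the exp_scaled formula. This is a sketch of the arithmetic only. The helper name exp_scaled_row is hypothetical, and it applies the formula to a plain list rather than a TileTensor fragment.

```python
import math


def exp_scaled_row(scores: list, scale: float) -> list:
    """exp2((score - max) * scale), subtracting the unscaled max first."""
    m = max(scores)
    return [2.0 ** ((s - m) * scale) for s in scores]


scale = 0.125 / math.log(2.0)   # e.g. softmax_scale * log2(e)
row = [0.1, 0.7, 0.3, 0.7]
probs = exp_scaled_row(row, scale)
# For each maximum element, (m - m) is exactly 0.0, so exp2 returns exactly 1.0.
assert probs[1] == 1.0 and probs[3] == 1.0
print(probs)
```

Every output lies in (0, 1], and the maximum maps to exactly 1.0 regardless of how scale rounds.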

exp_pkfma

def exp_pkfma[start: Int = Int(0), stride: Int = Int(1)](self, score: TileTensor[dtype, Storage=score.Storage, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], scale: Scalar[dtype])

Scaled exp in the fused form score * scale + (-max * scale), so the compiler emits v_pk_fma_f32 instead of a separate add and multiply.

Mirrors exp_scaled, but pre-multiplies -max by scale once per row (one packed multiply per row) and applies fma(score, scale, neg_scaled_max) to every score pair (matching aiter's softmax inner loop).

Precision: when score == max, the result is fma(max, scale, -scaled_max), where scaled_max has already been rounded. The FMA uses the exact product max * scale, so the result is the rounding error of scaled_max: a tiny epsilon instead of exactly 0. Since exp2(epsilon) ≈ 1 + epsilon·ln(2), this stays well within the FP8 softmax tolerance budget (the row-sum normalization absorbs it).
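The precision note can be reproduced in Python by emulating a fused multiply-add with exact rational arithmetic and a single final rounding. For simplicity this sketch works in float64 rather than the kernel's float32; the same effect occurs at either precision. The fma helper is a hypothetical emulation (Python 3.13+ also provides math.fma).

```python
from fractions import Fraction


def fma(a: float, b: float, c: float) -> float:
    """Emulated fused multiply-add: exact a*b + c, rounded once to float64."""
    return float(Fraction(a) * Fraction(b) + Fraction(c))


score_max, scale = 0.1, 0.3
scaled_max = score_max * scale           # rounded once, like the per-row pre-multiply

# exp_pkfma form at score == max: fma(max, scale, -scaled_max)
eps_fma = fma(score_max, scale, -scaled_max)
# exp_scaled form at score == max: (max - max) * scale
eps_sub = (score_max - score_max) * scale

print(eps_fma, eps_sub)                  # eps_fma is tiny but nonzero; eps_sub is 0.0
print(2.0 ** eps_fma)                    # indistinguishable from 1.0 at this magnitude
```

The FMA form leaves the rounding error of scaled_max behind, whereas the subtract-first form cancels exactly. At float32 precision with large maxima the epsilon is proportionally larger, which is why the docstring frames it against the FP8 tolerance budget.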

calculate_correction

def calculate_correction(self)

update_output

def update_output(self, output: TileTensor[dtype, Storage=output.Storage, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size])

update_sum

def update_sum(self)

apply_sum_correction

def apply_sum_correction(self)

Apply rowsum *= correction (deferred sum rescale pattern).

update_sum_additive

def update_sum_additive(self)

Additive rowsum update: rowsum += new_sum (no correction).
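The deferred pattern splits one rescale-and-accumulate step into two. The scalar Python sketch below assumes two things not stated on this page: new_sum is already computed relative to the updated row max, and update_sum performs the fused form rowsum * correction + new_sum. The function names are hypothetical.

```python
def fused_update(rowsum: float, correction: float, new_sum: float) -> float:
    """Single-step form: rowsum = rowsum * correction + new_sum."""
    return rowsum * correction + new_sum


def deferred_update(rowsum: float, correction: float, new_sum: float) -> float:
    """Split form: rescale and accumulate as separate steps."""
    rowsum = rowsum * correction   # like apply_sum_correction
    rowsum = rowsum + new_sum      # like update_sum_additive
    return rowsum


rowsum, correction, new_sum = 3.0, 0.5, 2.0
assert fused_update(rowsum, correction, new_sum) == deferred_update(rowsum, correction, new_sum) == 3.5
# Reversing the steps also rescales the new tile's sum, which is wrong.
wrong = (rowsum + new_sum) * correction
print(deferred_update(rowsum, correction, new_sum), wrong)
```

The split form gives the same result as the fused one as long as the correction is applied before the new tile's sum is added. This lets a kernel schedule the multiply at a different point than the addition.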

update_max

def update_max(self)

full

def full(self, output: TileTensor[dtype, Storage=output.Storage, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], score: TileTensor[dtype, Storage=score.Storage, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], warp_scratch: TileTensor[dtype, Storage=warp_scratch.Storage, address_space=warp_scratch.address_space, linear_idx_type=warp_scratch.linear_idx_type, element_size=warp_scratch.element_size])
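Taken together, the methods appear to decompose an online (flash-attention-style) softmax. Below is a scalar Python reference model for one query row, where each loop iteration stands in for processing one score tile. The step-to-method mapping in the comments is an assumption, not documented behavior. Scalar "values" stand in for the V tile, and all function names are hypothetical.

```python
import math

LOG2E = 1.0 / math.log(2.0)


def online_softmax_attention(scores_tiles, values_tiles, softmax_scale):
    """Process one row tile by tile, keeping a running max, sum, and output."""
    scale = softmax_scale * LOG2E
    rowmax = -math.inf   # running max in scaled units (cf. rowmax_tensor)
    rowsum = 0.0         # running sum (cf. rowsum_tensor)
    acc = 0.0            # unnormalized output accumulator
    for scores, values in zip(scores_tiles, values_tiles):
        tile_max = max(scores)                               # ~ calculate_qk_max
        new_max = max(rowmax, tile_max * scale)              # ~ update_max
        correction = 2.0 ** (rowmax - new_max)               # ~ calculate_correction (0.0 on the first tile)
        p = [2.0 ** (s * scale - new_max) for s in scores]   # ~ exp
        rowsum = rowsum * correction + sum(p)                # ~ calculate_qk_sum + update_sum
        acc = acc * correction + sum(pi * v for pi, v in zip(p, values))  # ~ update_output
        rowmax = new_max
    return acc / rowsum


def direct_softmax_attention(scores, values, softmax_scale):
    """Non-tiled reference: softmax over the whole row, then a weighted sum."""
    m = max(scores)
    w = [math.exp((s - m) * softmax_scale) for s in scores]
    return sum(wi * v for wi, v in zip(w, values)) / sum(w)


tiles = [[1.0, 4.0], [2.5, 9.0], [-3.0, 0.5]]
vals = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
flat_s = [s for t in tiles for s in t]
flat_v = [v for t in vals for v in t]
print(online_softmax_attention(tiles, vals, 0.5),
      direct_softmax_attention(flat_s, flat_v, 0.5))
```

The tiled and direct results agree to within rounding. This is the invariant the running max, correction, and row sum exist to preserve.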