Mojo struct
Softmax
struct Softmax[dtype: DType, score_layout_by_mma_unit: Layout, block_layout_by_warp: Layout, warp_layout: Layout, fragment_layout: Layout, use_exp2: Bool = False]
Fields
- rowmax_tensor (Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].RowMaxTensorType)
- rowsum_tensor (Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].RowSumTensorType)
- score_frag_rowmax (Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].ScoreFragTensorType)
- score_frag_rowsum (Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].ScoreFragTensorType)
- correction (Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].ScoreFragTensorType)
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
exp_function
comptime exp_function = _exp2_concrete[?, ?] if use_exp2 else _exp_concrete[?, ?]
frag_is_row_vector
comptime frag_is_row_vector = (Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].frag_num_rows == 1)
frag_num_cols
comptime frag_num_cols = fragment_layout.shape[1].value()
frag_num_rows
comptime frag_num_rows = fragment_layout.shape[0].value()
frag_size
comptime frag_size = (Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].frag_num_rows * Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].frag_num_cols)
num_colwise_lanes
comptime num_colwise_lanes = SIMD(warp_layout.shape[0].value())
num_colwise_tiles
comptime num_colwise_tiles = score_layout_by_mma_unit.shape[0].value()
num_colwise_warps
comptime num_colwise_warps = block_layout_by_warp.shape[0].value()
num_m_mmas
comptime num_m_mmas = score_layout_by_mma_unit.shape[0].value()
num_rows_per_thread
comptime num_rows_per_thread = (Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].num_colwise_tiles * Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].frag_num_rows)
num_rowwise_lanes
comptime num_rowwise_lanes = SIMD(warp_layout.shape[1].value())
num_rowwise_tiles
comptime num_rowwise_tiles = score_layout_by_mma_unit.shape[1].value()
num_rowwise_warps
comptime num_rowwise_warps = block_layout_by_warp.shape[1].value()
num_shuffles_per_row
comptime num_shuffles_per_row = log2_floor(warp_layout.shape[1].value())
row_layout
comptime row_layout = row_major[Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].num_m_mmas, fragment_layout.shape[0].value()]()
RowMaxTensorType
comptime RowMaxTensorType = TileTensor[dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
RowSumTensorType
comptime RowSumTensorType = Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].RowMaxTensorType
rowwise_lanes_stride
comptime rowwise_lanes_stride = SIMD(warp_layout.stride[1].value())
score_frag_layout
comptime score_frag_layout = row_major[Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].num_colwise_tiles, Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].frag_num_rows]()
ScoreFragTensorType
comptime ScoreFragTensorType = TileTensor[dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
Methods
__init__
__init__(out self)
calculate_qk_max
calculate_qk_max(self, score: TileTensor[dtype, score.LayoutType, score.origin, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], warp_scratch: TileTensor[dtype, warp_scratch.LayoutType, warp_scratch.origin, address_space=warp_scratch.address_space, linear_idx_type=warp_scratch.linear_idx_type, element_size=warp_scratch.element_size])
calculate_qk_sum
calculate_qk_sum(self, score: TileTensor[dtype, score.LayoutType, score.origin, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], warp_scratch: TileTensor[dtype, warp_scratch.LayoutType, warp_scratch.origin, address_space=warp_scratch.address_space, linear_idx_type=warp_scratch.linear_idx_type, element_size=warp_scratch.element_size])
exp
exp[start: Int = 0, stride: Int = 1](self, score: TileTensor[dtype, score.LayoutType, score.origin, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size])
scale_rowmax
scale_rowmax(self, scale: Scalar[dtype])
Scale score_frag_rowmax by the scale factor (e.g. scale * log2e).
Must be called after exp_scaled, so that score_frag_rowmax is in the same units as rowmax_tensor when calculate_correction runs.
exp_scaled
exp_scaled[start: Int = 0, stride: Int = 1](self, score: TileTensor[dtype, score.LayoutType, score.origin, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], scale: Scalar[dtype])
Numerically stable scaled exp: exp2((score - max) * scale).
Subtracts the unscaled max before scaling, so the subtraction is exact for the maximum element (IEEE 754 guarantees a - a == 0). This avoids the precision gap in exp_fma where fma(score, scale, -scaled_max) can produce nonzero results when score == max due to independent rounding of scaled_max.
calculate_correction
calculate_correction(self)
update_output
update_output(self, output: TileTensor[dtype, output.LayoutType, output.origin, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size])
update_sum
update_sum(self)
apply_sum_correction
apply_sum_correction(self)
Apply rowsum *= correction (deferred sum rescale pattern).
update_sum_additive
update_sum_additive(self)
Additive rowsum update: rowsum += new_sum (no correction).
update_max
update_max(self)
full
full(self, output: TileTensor[dtype, output.LayoutType, output.origin, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], score: TileTensor[dtype, score.LayoutType, score.origin, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], warp_scratch: TileTensor[dtype, warp_scratch.LayoutType, warp_scratch.origin, address_space=warp_scratch.address_space, linear_idx_type=warp_scratch.linear_idx_type, element_size=warp_scratch.element_size])
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!