Mojo struct
SoftmaxRDNA
struct SoftmaxRDNA[dtype: DType, num_m_mmas: Int, num_n_mmas: Int, num_warps_m: Int, num_warps_n: Int, use_exp2: Bool = False]
Fields
- rowmax_tensor (SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].RowMaxTensorType)
- rowsum_tensor (SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].RowSumTensorType)
- score_frag_rowmax (SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].ScoreFragTensorType)
- score_frag_rowsum (SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].ScoreFragTensorType)
- correction (SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].ScoreFragTensorType)
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
exp_function
comptime exp_function = _exp2_concrete[?, ?] if use_exp2 else _exp_concrete[?, ?]
frag_is_row_vector
comptime frag_is_row_vector = True
frag_num_rows
comptime frag_num_rows = ComptimeInt[1].static_value
frag_size
comptime frag_size = SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].FragmentLayoutT.static_product
FragmentLayoutT
comptime FragmentLayoutT = Layout[*?, *?]
num_colwise_lanes
comptime num_colwise_lanes = SIMD(SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].warp_rows)
num_colwise_tiles
comptime num_colwise_tiles = num_m_mmas
num_colwise_warps
comptime num_colwise_warps = num_warps_m
num_rowwise_lanes
comptime num_rowwise_lanes = SIMD(SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].warp_cols)
num_rowwise_tiles
comptime num_rowwise_tiles = num_n_mmas
num_rowwise_warps
comptime num_rowwise_warps = num_warps_n
row_layout
comptime row_layout = row_major[num_m_mmas, SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].frag_num_rows]()
RowMaxTensorType
comptime RowMaxTensorType = TileTensor[dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
RowSumTensorType
comptime RowSumTensorType = SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].RowMaxTensorType
rowwise_lanes_stride
comptime rowwise_lanes_stride = SIMD(ComptimeInt[16].static_value)
score_frag_layout
comptime score_frag_layout = row_major[SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].num_colwise_tiles, SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].frag_num_rows]()
ScoreFragTensorType
comptime ScoreFragTensorType = TileTensor[dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
warp_cols
comptime warp_cols = ComptimeInt[2].static_value
warp_rows
comptime warp_rows = ComptimeInt[16].static_value
WarpLayoutT
comptime WarpLayoutT = Layout[*?, *?]
Methods
__init__
__init__(out self)
calculate_qk_max
calculate_qk_max(self, score: TileTensor[dtype, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], warp_scratch: TileTensor[dtype, address_space=warp_scratch.address_space, linear_idx_type=warp_scratch.linear_idx_type, element_size=warp_scratch.element_size])
calculate_qk_sum
calculate_qk_sum(self, score: TileTensor[dtype, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], warp_scratch: TileTensor[dtype, address_space=warp_scratch.address_space, linear_idx_type=warp_scratch.linear_idx_type, element_size=warp_scratch.element_size])
exp
exp(self, score: TileTensor[dtype, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size])
calculate_correction
calculate_correction(self)
update_output
update_output(self, output: TileTensor[dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size])
update_sum
update_sum(self)
update_max
update_max(self)
full
full(self, output: TileTensor[dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], score: TileTensor[dtype, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], warp_scratch: TileTensor[dtype, address_space=warp_scratch.address_space, linear_idx_type=warp_scratch.linear_idx_type, element_size=warp_scratch.element_size])
Single-pass online softmax iteration: max -> exp -> sum -> correction -> update output -> update max/sum.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!