
Mojo struct

SoftmaxRDNA

struct SoftmaxRDNA[dtype: DType, num_m_mmas: Int, num_n_mmas: Int, num_warps_m: Int, num_warps_n: Int, use_exp2: Bool = False]

Fields

  • rowmax_tensor (SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].RowMaxTensorType)
  • rowsum_tensor (SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].RowSumTensorType)
  • score_frag_rowmax (SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].ScoreFragTensorType)
  • score_frag_rowsum (SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].ScoreFragTensorType)
  • correction (SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].ScoreFragTensorType)

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

exp_function

comptime exp_function = _exp2_concrete[?, ?] if use_exp2 else _exp_concrete[?, ?]
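`exp_function` selects a base-2 or a natural exponential at compile time, depending on `use_exp2`. The two paths give the same softmax only if the inputs are rescaled, because e^x = 2^(x · log2 e). The Python sketch below checks that identity. It assumes, without confirmation from this page, that when `use_exp2` is True the caller has already folded the log2(e) factor into the scores (for example, into the softmax scale).

```python
import math

LOG2_E = math.log2(math.e)  # log2(e), roughly 1.4427

def exp_natural(x: float) -> float:
    """Natural exponential, analogous to the use_exp2=False path."""
    return math.exp(x)

def exp_base2(x: float) -> float:
    """Base-2 exponential, analogous to the use_exp2=True path."""
    return 2.0 ** x

def softmax_numerators(scores, use_exp2=False):
    """Return exp(s - max) for each score.

    Assumption: with use_exp2=True the scores must already be
    multiplied by log2(e), so that 2**s' == e**s.
    """
    m = max(scores)
    f = exp_base2 if use_exp2 else exp_natural
    return [f(s - m) for s in scores]

scores = [0.5, -1.25, 2.0]
natural = softmax_numerators(scores)
# Pre-scale by log2(e) before taking the base-2 path.
base2 = softmax_numerators([s * LOG2_E for s in scores], use_exp2=True)
print(all(math.isclose(a, b) for a, b in zip(natural, base2)))  # True
```

GPUs usually provide a fast native base-2 exponential, which is the usual motivation for offering an exp2 path.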

frag_is_row_vector

comptime frag_is_row_vector = True

frag_num_rows

comptime frag_num_rows = ComptimeInt[1].static_value

frag_size

comptime frag_size = SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].FragmentLayoutT.static_product

FragmentLayoutT

comptime FragmentLayoutT = Layout[*?, *?]

num_colwise_lanes

comptime num_colwise_lanes = SIMD(SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].warp_rows)

num_colwise_tiles

comptime num_colwise_tiles = num_m_mmas

num_colwise_warps

comptime num_colwise_warps = num_warps_m

num_rowwise_lanes

comptime num_rowwise_lanes = SIMD(SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].warp_cols)

num_rowwise_tiles

comptime num_rowwise_tiles = num_n_mmas

num_rowwise_warps

comptime num_rowwise_warps = num_warps_n

row_layout

comptime row_layout = row_major[num_m_mmas, SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].frag_num_rows]()

RowMaxTensorType

comptime RowMaxTensorType = TileTensor[dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

RowSumTensorType

comptime RowSumTensorType = SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].RowMaxTensorType

rowwise_lanes_stride

comptime rowwise_lanes_stride = SIMD(ComptimeInt[16].static_value)

score_frag_layout

comptime score_frag_layout = row_major[SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].num_colwise_tiles, SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].frag_num_rows]()

ScoreFragTensorType

comptime ScoreFragTensorType = TileTensor[dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

warp_cols

comptime warp_cols = ComptimeInt[2].static_value

warp_rows

comptime warp_rows = ComptimeInt[16].static_value

WarpLayoutT

comptime WarpLayoutT = Layout[*?, *?]
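The lane constants describe a 16 × 2 lane grid. `warp_rows` = 16 is exposed as `num_colwise_lanes`, and `warp_cols` = 2 is exposed as `num_rowwise_lanes`, giving 32 lanes in total. That count is consistent with a 32-lane (wave32) RDNA wavefront. `rowwise_lanes_stride` = 16 suggests that the two lanes sharing a row are 16 lane IDs apart. The Python sketch below assumes the column-major mapping lane = row + 16 * col implied by that stride. The concrete `WarpLayoutT` is elided on this page, so treat the mapping as hypothetical.

```python
WARP_ROWS = 16             # mirrors warp_rows / num_colwise_lanes
WARP_COLS = 2              # mirrors warp_cols / num_rowwise_lanes
ROWWISE_LANES_STRIDE = 16  # mirrors rowwise_lanes_stride

def lane_to_coord(lane: int) -> tuple[int, int]:
    """Map a lane ID to (row, col), assuming lane = row + stride * col."""
    return lane % ROWWISE_LANES_STRIDE, lane // ROWWISE_LANES_STRIDE

def row_partner_lanes(lane: int) -> list[int]:
    """All lanes that hold part of the same row as `lane`."""
    row, _ = lane_to_coord(lane)
    return [row + ROWWISE_LANES_STRIDE * c for c in range(WARP_COLS)]

num_lanes = WARP_ROWS * WARP_COLS
print(num_lanes)              # 32
print(lane_to_coord(21))      # (5, 1)
print(row_partner_lanes(21))  # [5, 21]
```

Under this mapping, a row-wise reduction only has to combine each lane with the lane 16 IDs away.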

Methods

__init__

__init__(out self)

calculate_qk_max

calculate_qk_max(self, score: TileTensor[dtype, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], warp_scratch: TileTensor[dtype, address_space=warp_scratch.address_space, linear_idx_type=warp_scratch.linear_idx_type, element_size=warp_scratch.element_size])

calculate_qk_sum

calculate_qk_sum(self, score: TileTensor[dtype, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], warp_scratch: TileTensor[dtype, address_space=warp_scratch.address_space, linear_idx_type=warp_scratch.linear_idx_type, element_size=warp_scratch.element_size])
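Both reductions take a `warp_scratch` tensor. The row-wise constants (`num_rowwise_lanes` = 2, `num_rowwise_warps` = `num_warps_n`) indicate that one row of scores is spread across lanes within a warp and also across warps. A plausible shape for this is a two-level reduction:

1. Combine the partial values held by the lanes of a row.
2. Write one value per warp to scratch.
3. Combine the scratch values across warps.

The Python sketch below simulates that scheme with plain lists. The two-level scheme and the scratch layout are assumptions, not taken from the implementation. Passing `max` gives the row maximum, and passing `operator.add` gives the row sum.

```python
import operator
from functools import reduce

def reduce_row(partials, op, num_lanes, num_warps):
    """Reduce one row whose values are split across lanes and warps.

    partials[w][l] is the partial result held by lane l of warp w,
    assuming each lane has already reduced its own fragment.
    """
    assert len(partials) == num_warps
    # Level 1: combine the lanes that share this row inside each warp,
    # leaving one value per warp (the hypothetical scratch contents).
    warp_scratch = [reduce(op, partials[w][:num_lanes]) for w in range(num_warps)]
    # Level 2: combine the per-warp values read back from scratch.
    return reduce(op, warp_scratch)

# One row: 2 lanes per warp (num_rowwise_lanes), 2 warps (num_warps_n = 2).
partials = [[1.0, 4.0], [3.5, -2.0]]
print(reduce_row(partials, max, num_lanes=2, num_warps=2))           # 4.0
print(reduce_row(partials, operator.add, num_lanes=2, num_warps=2))  # 6.5
```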

exp

exp(self, score: TileTensor[dtype, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size])

calculate_correction

calculate_correction(self)

update_output

update_output(self, output: TileTensor[dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size])

update_sum

update_sum(self)

update_max

update_max(self)
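The per-step methods line up with the standard online-softmax recurrences. The math below uses the following symbols:

  • m: the running row max (rowmax_tensor)
  • ℓ: the running row sum (rowsum_tensor)
  • S: the current score tile
  • c: correction
  • O: the output tile

Which quantity each field holds at each step is inferred from the names and is not documented on this page. With `use_exp2`, exp is replaced by a base-2 exponential.

```latex
\begin{aligned}
m^{\text{new}} &= \max\bigl(m^{\text{old}},\ \operatorname{rowmax}(S)\bigr) \\
P &= \exp\bigl(S - m^{\text{new}}\bigr) \\
c &= \exp\bigl(m^{\text{old}} - m^{\text{new}}\bigr) \\
\ell^{\text{new}} &= c\,\ell^{\text{old}} + \operatorname{rowsum}(P) \\
O &\leftarrow c\,O
\end{aligned}
```

Because update_output takes only the output tile, adding P·V to O and the final division by ℓ presumably happen outside these methods.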

full

full(self, output: TileTensor[dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], score: TileTensor[dtype, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], warp_scratch: TileTensor[dtype, address_space=warp_scratch.address_space, linear_idx_type=warp_scratch.linear_idx_type, element_size=warp_scratch.element_size])

Single-pass online softmax iteration: max -> exp -> sum -> correction -> update output -> update max/sum.
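The step order listed for `full` is the online-softmax scheme (as used in FlashAttention-style kernels). The Python sketch below runs that order over one row of scores split into tiles, and checks the result against an ordinary two-pass softmax. It uses a scalar value per key to stand in for the V tile. The tiling, the `values` input, and the final normalization are illustrative assumptions, since this page does not say where the multiplication by V or the division by the row sum happens.

```python
import math

def online_softmax_row(score_tiles, value_tiles):
    """Compute sum_j softmax(s)_j * v_j one tile at a time.

    Follows the step order of `full`: max -> exp -> sum -> correction
    -> update output -> update max/sum. The P*V accumulation is an
    assumption about what the caller does between iterations.
    """
    row_max = -math.inf  # analogue of rowmax_tensor
    row_sum = 0.0        # analogue of rowsum_tensor
    output = 0.0         # analogue of the output accumulator
    for scores, values in zip(score_tiles, value_tiles):
        tile_max = max(row_max, max(scores))              # max (incl. running max)
        probs = [math.exp(s - tile_max) for s in scores]  # exp
        tile_sum = sum(probs)                             # sum
        correction = math.exp(row_max - tile_max)         # correction (0.0 on first tile)
        output = output * correction                      # update output
        output += sum(p * v for p, v in zip(probs, values))  # P @ V (caller side)
        row_sum = row_sum * correction + tile_sum         # update sum
        row_max = tile_max                                # update max
    return output / row_sum

def reference_softmax_row(scores, values):
    """Two-pass softmax-weighted sum for comparison."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return sum(w * v for w, v in zip(weights, values)) / total

scores = [0.3, 2.1, -1.0, 4.2, 0.0, 1.5]
values = [1.0, -2.0, 0.5, 3.0, 2.5, -1.0]
tiled = online_softmax_row([scores[:2], scores[2:4], scores[4:]],
                           [values[:2], values[2:4], values[4:]])
print(math.isclose(tiled, reference_softmax_row(scores, values)))  # True
```

Only the running max, the running sum, and the rescaled accumulator persist between tiles, which is why the struct keeps exactly those as fields.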