
Mojo struct

SoftmaxRDNA

struct SoftmaxRDNA[dtype: DType, num_m_mmas: Int, num_n_mmas: Int, num_warps_m: Int, num_warps_n: Int, use_exp2: Bool = False]

Fields​

  • rowmax_tensor (SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].RowMaxTensorType)
  • rowsum_tensor (SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].RowSumTensorType)
  • score_frag_rowmax (SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].ScoreFragTensorType)
  • score_frag_rowsum (SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].ScoreFragTensorType)
  • correction (SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].ScoreFragTensorType)

Implemented traits​

AnyType, ImplicitlyDeletable

comptime members​

exp_function​

comptime exp_function = _exp2_concrete[_, _] if use_exp2 else _exp_concrete[_, _]
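
The `use_exp2` parameter selects a base-2 exponential in place of the natural exponential. As an illustration only, the plain Python sketch below (not Mojo, and not the library's implementation) shows the identity that makes the two interchangeable: exp(x) = 2^(x * log2(e)). The name `exp_via_exp2` is hypothetical, and whether callers of `SoftmaxRDNA` are expected to pre-scale scores by log2(e) when `use_exp2` is set is an assumption that this page does not confirm.

```python
import math

# log2(e): scaling factor that converts a natural-exponent argument
# into a base-2 exponent argument.
LOG2_E = math.log2(math.e)


def exp_via_exp2(x: float) -> float:
    """Compute e**x with a base-2 exponential applied to a pre-scaled input.

    Hypothetical helper for illustration; not part of SoftmaxRDNA.
    """
    return 2.0 ** (x * LOG2_E)


if __name__ == "__main__":
    for x in (-3.0, 0.0, 1.5):
        print(x, exp_via_exp2(x), math.exp(x))
```

Softmax results are unchanged as long as every exponential in the same row uses the same base and the same scaling.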

frag_is_row_vector​

comptime frag_is_row_vector = True

frag_num_rows​

comptime frag_num_rows = ComptimeInt[1].static_value

frag_size​

comptime frag_size = SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].FragmentLayoutT.static_product

FragmentLayoutT​

comptime FragmentLayoutT = Layout[*?, *?]

num_colwise_lanes​

comptime num_colwise_lanes = UInt32(16)

num_colwise_tiles​

comptime num_colwise_tiles = num_m_mmas

num_colwise_warps​

comptime num_colwise_warps = num_warps_m

num_rowwise_lanes​

comptime num_rowwise_lanes = UInt32(2)

num_rowwise_tiles​

comptime num_rowwise_tiles = num_n_mmas

num_rowwise_warps​

comptime num_rowwise_warps = num_warps_n

row_layout​

comptime row_layout = row_major[num_m_mmas, SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].frag_num_rows]()

RowMaxTensorType​

comptime RowMaxTensorType = TileTensor[dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

RowSumTensorType​

comptime RowSumTensorType = SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].RowMaxTensorType

rowwise_lanes_stride​

comptime rowwise_lanes_stride = UInt32(16)

score_frag_layout​

comptime score_frag_layout = row_major[SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].num_colwise_tiles, SoftmaxRDNA[dtype, num_m_mmas, num_n_mmas, num_warps_m, num_warps_n, use_exp2].frag_num_rows]()

ScoreFragTensorType​

comptime ScoreFragTensorType = TileTensor[dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

warp_cols​

comptime warp_cols = ComptimeInt[2].static_value

warp_rows​

comptime warp_rows = ComptimeInt[16].static_value
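
The lane constants on this page (`num_colwise_lanes = 16`, `num_rowwise_lanes = 2`, `rowwise_lanes_stride = 16`, `warp_rows = 16`, `warp_cols = 2`) suggest a 32-lane wave arranged as 16 rows by 2 columns, where the two lanes that share a row sit 16 lane IDs apart. That reading is an assumption, not something this page states. Under that assumption, a row-wise reduction can be simulated in plain Python as follows; `rowwise_peers` and `rowwise_max` are hypothetical names used only for this sketch.

```python
# Values copied from the comptime members on this page.
NUM_COLWISE_LANES = 16
NUM_ROWWISE_LANES = 2
ROWWISE_LANES_STRIDE = 16

# Assumed wave size: 16 rows x 2 columns = 32 lanes.
WAVE_SIZE = NUM_COLWISE_LANES * NUM_ROWWISE_LANES


def rowwise_peers(lane: int) -> list[int]:
    """Lanes assumed to hold pieces of the same row as `lane`."""
    base = lane % ROWWISE_LANES_STRIDE
    return [base + i * ROWWISE_LANES_STRIDE for i in range(NUM_ROWWISE_LANES)]


def rowwise_max(per_lane: list[float]) -> list[float]:
    """Simulate a row-wise max: each lane ends up with the max over its peers."""
    return [
        max(per_lane[peer] for peer in rowwise_peers(lane))
        for lane in range(WAVE_SIZE)
    ]


if __name__ == "__main__":
    print(rowwise_peers(5))  # lanes sharing a row with lane 5
```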

WarpLayoutT​

comptime WarpLayoutT = Layout[*?, *?]

Methods​

__init__​

def __init__(out self)

calculate_qk_max​

def calculate_qk_max(self, score: TileTensor[dtype, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], warp_scratch: TileTensor[dtype, address_space=warp_scratch.address_space, linear_idx_type=warp_scratch.linear_idx_type, element_size=warp_scratch.element_size])

calculate_qk_sum​

def calculate_qk_sum(self, score: TileTensor[dtype, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], warp_scratch: TileTensor[dtype, address_space=warp_scratch.address_space, linear_idx_type=warp_scratch.linear_idx_type, element_size=warp_scratch.element_size])

exp​

def exp(self, score: TileTensor[dtype, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size])

calculate_correction​

def calculate_correction(self)

update_output​

def update_output(self, output: TileTensor[dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size])

update_sum​

def update_sum(self)

update_max​

def update_max(self)

full​

def full(self, output: TileTensor[dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], score: TileTensor[dtype, address_space=score.address_space, linear_idx_type=score.linear_idx_type, element_size=score.element_size], warp_scratch: TileTensor[dtype, address_space=warp_scratch.address_space, linear_idx_type=warp_scratch.linear_idx_type, element_size=warp_scratch.element_size])

Single-pass online softmax iteration: max -> exp -> sum -> correction -> update output -> update max/sum.
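
The step order above matches the standard online softmax recurrence. The sketch below is a scalar reference in plain Python for a single row, not the Mojo kernel: it assumes the textbook meaning of each step (running max, running sum, and a correction factor of exp(old_max - new_max)), which this page does not spell out. The function name `online_softmax_attention`, the scalar value blocks, and the inline weighted accumulation are all hypothetical and exist only to make the result checkable. Scores are assumed to be finite.

```python
import math


def online_softmax_attention(score_blocks, value_blocks):
    """Softmax-weighted average of values for one row, processed block by block.

    Reference sketch only; mirrors the step order
    max -> exp -> sum -> correction -> update output -> update max/sum.
    """
    row_max = float("-inf")  # running maximum over all blocks seen so far
    row_sum = 0.0            # running sum of exponentials, relative to row_max
    output = 0.0             # unnormalized accumulator

    for scores, values in zip(score_blocks, value_blocks):
        new_max = max(row_max, max(scores))               # max
        probs = [math.exp(s - new_max) for s in scores]   # exp
        block_sum = sum(probs)                            # sum
        correction = math.exp(row_max - new_max)          # correction
        output *= correction                              # update output
        # Accumulate this block's contribution (done here for testability).
        output += sum(p * v for p, v in zip(probs, values))
        row_max = new_max                                 # update max
        row_sum = row_sum * correction + block_sum        # update sum

    # Final normalization by the running sum.
    return output / row_sum


if __name__ == "__main__":
    scores = [[1.0, 2.0], [3.0, 0.5]]
    values = [[1.0, 2.0], [3.0, 4.0]]
    print(online_softmax_attention(scores, values))
```

Because earlier contributions are rescaled by the correction factor whenever the running max grows, the result matches a softmax computed over all scores at once, without a second pass over the data.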