Skip to main content

Mojo struct

Softmax

struct Softmax[dtype: DType, score_layout_by_mma_unit: Layout, block_layout_by_warp: Layout, warp_layout: Layout, fragment_layout: Layout, use_exp2: Bool = False]

Fields

  • rowmax_tensor (Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].RowMaxTensorType)
  • rowsum_tensor (Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].RowSumTensorType)
  • score_frag_rowmax (LayoutTensor[dtype, Layout.row_major(Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].num_colwise_tiles, Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].frag_num_rows), MutAnyOrigin, address_space=AddressSpace.LOCAL])
  • score_frag_rowsum (LayoutTensor[dtype, Layout.row_major(Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].num_colwise_tiles, Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].frag_num_rows), MutAnyOrigin, address_space=AddressSpace.LOCAL])
  • correction (LayoutTensor[dtype, Layout.row_major(Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].num_colwise_tiles, Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].frag_num_rows), MutAnyOrigin, address_space=AddressSpace.LOCAL])

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

__del__is_trivial

comptime __del__is_trivial = True

exp_function

comptime exp_function = _exp2_concrete if use_exp2 else _exp_concrete

frag_is_row_vector

comptime frag_is_row_vector = (Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].frag_num_rows == 1)

frag_num_cols

comptime frag_num_cols = fragment_layout.shape[1].value()

frag_num_rows

comptime frag_num_rows = fragment_layout.shape[0].value()

num_colwise_lanes

comptime num_colwise_lanes = SIMD(warp_layout.shape[0].value())

num_colwise_tiles

comptime num_colwise_tiles = score_layout_by_mma_unit.shape[0].value()

num_colwise_warps

comptime num_colwise_warps = block_layout_by_warp.shape[0].value()

num_m_mmas

comptime num_m_mmas = score_layout_by_mma_unit.shape[0].value()

num_rows_per_thread

comptime num_rows_per_thread = (Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].num_colwise_tiles * Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].frag_num_rows)

num_rowwise_lanes

comptime num_rowwise_lanes = SIMD(warp_layout.shape[1].value())

num_rowwise_tiles

comptime num_rowwise_tiles = score_layout_by_mma_unit.shape[1].value()

num_rowwise_warps

comptime num_rowwise_warps = block_layout_by_warp.shape[1].value()

num_shuffles_per_row

comptime num_shuffles_per_row = log2_floor(warp_layout.shape[1].value())

row_layout

comptime row_layout = Layout.row_major(Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].num_m_mmas, fragment_layout.shape[0].value())

RowMaxTensorType

comptime RowMaxTensorType = LayoutTensor[dtype, Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].row_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]

RowSumTensorType

comptime RowSumTensorType = Softmax[dtype, score_layout_by_mma_unit, block_layout_by_warp, warp_layout, fragment_layout, use_exp2].RowMaxTensorType

rowwise_lanes_stride

comptime rowwise_lanes_stride = SIMD(warp_layout.stride[1].value())

Methods

__init__

__init__(out self)

calculate_qk_max

calculate_qk_max(self, score_reg_tile: LayoutTensor[dtype, score_reg_tile.layout, score_reg_tile.origin, address_space=score_reg_tile.address_space, element_layout=score_reg_tile.element_layout, layout_int_type=score_reg_tile.layout_int_type, linear_idx_type=score_reg_tile.linear_idx_type, masked=score_reg_tile.masked, alignment=score_reg_tile.alignment], warp_scratch: LayoutTensor[dtype, warp_scratch.layout, warp_scratch.origin, address_space=warp_scratch.address_space, element_layout=warp_scratch.element_layout, layout_int_type=warp_scratch.layout_int_type, linear_idx_type=warp_scratch.linear_idx_type, masked=warp_scratch.masked, alignment=warp_scratch.alignment])

calculate_qk_sum

calculate_qk_sum(self, score_reg_tile: LayoutTensor[dtype, score_reg_tile.layout, score_reg_tile.origin, address_space=score_reg_tile.address_space, element_layout=score_reg_tile.element_layout, layout_int_type=score_reg_tile.layout_int_type, linear_idx_type=score_reg_tile.linear_idx_type, masked=score_reg_tile.masked, alignment=score_reg_tile.alignment], warp_scratch: LayoutTensor[dtype, warp_scratch.layout, warp_scratch.origin, address_space=warp_scratch.address_space, element_layout=warp_scratch.element_layout, layout_int_type=warp_scratch.layout_int_type, linear_idx_type=warp_scratch.linear_idx_type, masked=warp_scratch.masked, alignment=warp_scratch.alignment])

exp

exp[start: Int = 0, stride: Int = 1](self, score_reg_tile: LayoutTensor[dtype, score_reg_tile.layout, score_reg_tile.origin, address_space=score_reg_tile.address_space, element_layout=score_reg_tile.element_layout, layout_int_type=score_reg_tile.layout_int_type, linear_idx_type=score_reg_tile.linear_idx_type, masked=score_reg_tile.masked, alignment=score_reg_tile.alignment])

calculate_correction

calculate_correction(self)

update_output

update_output(self, output_reg_tile: LayoutTensor[dtype, output_reg_tile.layout, output_reg_tile.origin, address_space=output_reg_tile.address_space, element_layout=output_reg_tile.element_layout, layout_int_type=output_reg_tile.layout_int_type, linear_idx_type=output_reg_tile.linear_idx_type, masked=output_reg_tile.masked, alignment=output_reg_tile.alignment])

update_sum

update_sum(self)

update_max

update_max(self)

full

full(self, output_reg_tile: LayoutTensor[dtype, output_reg_tile.layout, output_reg_tile.origin, address_space=output_reg_tile.address_space, element_layout=output_reg_tile.element_layout, layout_int_type=output_reg_tile.layout_int_type, linear_idx_type=output_reg_tile.linear_idx_type, masked=output_reg_tile.masked, alignment=output_reg_tile.alignment], score_reg_tile: LayoutTensor[dtype, score_reg_tile.layout, score_reg_tile.origin, address_space=score_reg_tile.address_space, element_layout=score_reg_tile.element_layout, layout_int_type=score_reg_tile.layout_int_type, linear_idx_type=score_reg_tile.linear_idx_type, masked=score_reg_tile.masked, alignment=score_reg_tile.alignment], warp_scratch: LayoutTensor[dtype, warp_scratch.layout, warp_scratch.origin, address_space=warp_scratch.address_space, element_layout=warp_scratch.element_layout, layout_int_type=warp_scratch.layout_int_type, linear_idx_type=warp_scratch.linear_idx_type, masked=warp_scratch.masked, alignment=warp_scratch.alignment])

Was this page helpful?