IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

BlockwiseFP8TokenFormat

struct BlockwiseFP8TokenFormat[fp8_dtype: DType, scales_dtype: DType, output_layout: TensorLayout, scales_layout: TensorLayout, //, _hid_dim: Int, _top_k: Int, _alignment: Int = 0]

Fields

  • output_tokens (BlockwiseFP8TokenFormat[_hid_dim, _top_k, _alignment].TensorType)
  • output_scales (BlockwiseFP8TokenFormat[_hid_dim, _top_k, _alignment].ScalesTensorType)

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TokenFormat, TrivialRegisterPassable

comptime members

alignment

comptime alignment = _alignment if _alignment.__bool__() else get_device_alignment()

device_type

comptime device_type = BlockwiseFP8TokenFormat[_hid_dim, _top_k, _alignment]

dispatch_smem_size

comptime dispatch_smem_size = 0

dispatch_wait_tile_shape

comptime dispatch_wait_tile_shape = Tuple(128, 1)

expert_m_padding

comptime expert_m_padding = (16 // size_of[scales_dtype]())
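
As a worked instance of this expression, the sketch below evaluates `16 // size_of` for a few scale element widths. It is plain Python, not Mojo: the helper name `expert_m_padding` and the table of byte widths are illustrative assumptions, and reading the constant 16 as a byte count is inferred from the expression rather than stated on this page.

```python
# Illustrative model of `16 // size_of[scales_dtype]()`; not the Mojo API.
def expert_m_padding(scale_size_bytes: int) -> int:
    """Return the padding unit (in rows) for a scale element of the given byte width."""
    return 16 // scale_size_bytes


# Assumed byte widths for a few possible scale dtypes (hypothetical table).
SCALE_SIZES = {"float32": 4, "bfloat16": 2, "float8": 1}

paddings = {name: expert_m_padding(size) for name, size in SCALE_SIZES.items()}
print(paddings)
```

Under these assumptions a 4-byte scale type yields a padding of 4, which is the value used in the pad_expert_offsets example on this page.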

group_size

comptime group_size = 128
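
To illustrate what a group size of 128 means for blockwise FP8 scaling in general, the sketch below computes one scale per 128 consecutive values of a token's hidden vector. It is plain Python and an assumption about the scheme, not this struct's implementation: the names `blockwise_scales` and `scale_values`, the max-abs scaling rule, and the constant 448 (the largest finite float8_e4m3fn value) are illustrative choices.

```python
GROUP_SIZE = 128       # mirrors the `group_size` member
FP8_E4M3_MAX = 448.0   # largest finite float8_e4m3fn value (assumed target dtype)


def blockwise_scales(values: list[float], group_size: int = GROUP_SIZE) -> list[float]:
    """Return one max-abs scale per group of `group_size` consecutive values."""
    if len(values) % group_size != 0:
        raise ValueError("length must be a multiple of group_size")
    scales = []
    for start in range(0, len(values), group_size):
        group = values[start:start + group_size]
        max_abs = max(abs(v) for v in group)
        # Fall back to 1.0 so an all-zero group never produces a zero scale.
        scales.append(max_abs / FP8_E4M3_MAX if max_abs > 0 else 1.0)
    return scales


def scale_values(values: list[float], scales: list[float],
                 group_size: int = GROUP_SIZE) -> list[float]:
    """Divide each value by its group's scale so it fits the FP8 range."""
    return [v / scales[i // group_size] for i, v in enumerate(values)]


hidden = [float(i - 128) for i in range(256)]  # hypothetical 256-wide token
scales = blockwise_scales(hidden)              # two groups, so two scales
scaled = scale_values(hidden, scales)          # values now within about +/-448
```

In this model a hidden dimension of 256 produces two scales per token. The actual cast to an FP8 dtype is omitted because plain Python has no FP8 type.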

hid_dim

comptime hid_dim = _hid_dim

ScalesTensorType

comptime ScalesTensorType = TileTensor[scales_dtype, scales_layout, MutExternalOrigin]

TensorType

comptime TensorType = TileTensor[fp8_dtype, output_layout, MutExternalOrigin]

top_k

comptime top_k = _top_k

Methods

__init__

__init__(output_tokens: TileTensor[fp8_dtype, output_layout, address_space=output_tokens.address_space, linear_idx_type=output_tokens.linear_idx_type, element_size=output_tokens.element_size], output_scales: TileTensor[scales_dtype, scales_layout, address_space=output_scales.address_space, linear_idx_type=output_scales.linear_idx_type, element_size=output_scales.element_size]) -> Self

get_type_name

static get_type_name() -> String

Returns:

String

fp8_quant_size

static fp8_quant_size() -> Int

Returns:

Int

scales_size

static scales_size() -> Int

Returns:

Int

token_size

static token_size() -> Int

Returns:

Int

scales_offset

static scales_offset() -> Int

Returns:

Int

pad_expert_offsets

pad_expert_offsets[n_groups: Int](self, row_offsets: UnsafePointer[UInt32, address_space=row_offsets.address_space])

The Mojo blockwise FP8 grouped matmul requires each group's M dimension (its row count) to be a multiple of expert_m_padding. This function updates row_offsets to satisfy that requirement.

For example, if expert_m_padding is 4 and row_offsets is [0, 10, 20, 30, 40], each group holds 10 rows, which rounds up to 12, so the function updates row_offsets to [0, 12, 24, 36, 48].
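
The padding rule described above can be modeled in a few lines. This is a plain Python sketch of one reading that is consistent with the example, not the Mojo implementation: the function name `pad_offsets_model` is hypothetical, it returns a new list rather than updating a UInt32 buffer in place, and it assumes row_offsets is a non-empty list of cumulative offsets.

```python
def pad_offsets_model(row_offsets: list[int], padding: int) -> list[int]:
    """Round each group's row count up to a multiple of `padding`
    and rebuild the cumulative offsets from the padded counts."""
    padded = [row_offsets[0]]
    for start, end in zip(row_offsets, row_offsets[1:]):
        rows = end - start
        # Ceiling division, then scale back up to the next multiple of `padding`.
        padded_rows = (rows + padding - 1) // padding * padding
        padded.append(padded[-1] + padded_rows)
    return padded


result = pad_offsets_model([0, 10, 20, 30, 40], 4)
print(result)  # [0, 12, 24, 36, 48]
```

Because the offsets are cumulative, padding one group shifts every later offset, which is why the sketch rebuilds the list from per-group counts instead of rounding each offset independently.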

copy_token_to_send_buf

static copy_token_to_send_buf[src_type: DType, block_size: Int, buf_addr_space: AddressSpace = AddressSpace.GENERIC](buf_p: UnsafePointer[UInt8, address_space=buf_addr_space], src_p: UnsafePointer[Scalar[src_type], address_space=src_p.address_space], input_scale: Float32)

copy_msg_to_output_tensor

copy_msg_to_output_tensor[buf_addr_space: AddressSpace = AddressSpace.GENERIC](self, buf_p: UnsafePointer[UInt8, address_space=buf_addr_space], token_index: Int)