
Mojo trait

TokenFormat

Implemented traits

AnyType, DevicePassable, ImplicitlyDestructible

comptime members

alignment

comptime alignment

device_type

comptime device_type

Indicates the type used on accelerator devices.

dispatch_smem_size

comptime dispatch_smem_size

dispatch_wait_tile_shape

comptime dispatch_wait_tile_shape

hid_dim

comptime hid_dim

top_k

comptime top_k

Required methods

token_size

static token_size() -> Int

Returns the size of the (quantized) token in bytes.

Returns:

Int
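The page does not say what a quantized token contains, so the following is only an illustration. It is written in Python to model the byte arithmetic, not in Mojo, and it assumes a hypothetical format that stores each of the hid_dim elements as one FP8 byte plus one float32 scale per group of group_size elements (group_size is a hypothetical parameter, not a member of this trait).

```python
# Hypothetical quantized-token layout (an assumption for illustration only):
#   hid_dim FP8 values (1 byte each), followed by one float32 scale
#   per group of `group_size` consecutive elements.
FP8_BYTES = 1
FLOAT32_BYTES = 4


def token_size(hid_dim: int, group_size: int) -> int:
    """Size in bytes of one quantized token under the assumed layout."""
    if hid_dim % group_size != 0:
        raise ValueError("hid_dim must be a multiple of group_size")
    n_groups = hid_dim // group_size
    return hid_dim * FP8_BYTES + n_groups * FLOAT32_BYTES


print(token_size(7168, 128))  # 7168 + 56 * 4 = 7392
```

A concrete TokenFormat may store scales differently (or not at all), which is why token_size is a required method that each format implements itself.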

copy_token_to_send_buf

static copy_token_to_send_buf[src_type: DType, block_size: Int, buf_addr_space: AddressSpace = AddressSpace.GENERIC](buf_p: UnsafePointer[UInt8, address_space=buf_addr_space], src_p: UnsafePointer[Scalar[src_type], address_space=src_p.address_space], input_scale: Float32)

Copy the token to the send buffer. This function must be called by all threads in the block.
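Requiring every thread in the block to call the function is typical of a block-cooperative copy, in which thread tid handles elements tid, tid + block_size, tid + 2 * block_size, and so on. The Python model below simulates that partitioning sequentially. Two assumptions are labeled here and in the comments: that block_size is the number of threads in the block, and that quantization is "multiply by input_scale, round, clamp to a signed byte" (a hypothetical stand-in for whatever the concrete format actually does).

```python
def quantize(x: float, input_scale: float) -> int:
    """Hypothetical quantizer: scale, round, and clamp to the int8 range."""
    q = round(x * input_scale)
    return max(-128, min(127, q))


def copy_token_to_send_buf(buf: bytearray, src: list,
                           input_scale: float, block_size: int) -> None:
    """Sequential simulation of a block-cooperative strided copy.

    Assumes block_size is the thread count of the block; each simulated
    thread `tid` writes elements tid, tid + block_size, ...
    """
    n = len(src)
    for tid in range(block_size):
        for i in range(tid, n, block_size):
            buf[i] = quantize(src[i], input_scale) & 0xFF  # store as raw byte


src = [0.5, -1.0, 2.0, 0.25, -0.75]
buf = bytearray(len(src))
copy_token_to_send_buf(buf, src, input_scale=100.0, block_size=2)
print(list(buf))  # [50, 156, 127, 25, 181]
```

Because the strided loops cover every index exactly once, the result is the same for any block_size; on a GPU the threads run these loops concurrently instead of one after another.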

copy_msg_to_output_tensor

copy_msg_to_output_tensor[buf_addr_space: AddressSpace = AddressSpace.GENERIC](self: _Self, buf_p: UnsafePointer[UInt8, address_space=buf_addr_space], token_index: Int)

Copy the message to the output tensor. This function must be called by all threads in a warp.

get_type_name

static get_type_name() -> String

Gets the name of the host type (the type implementing this trait). For example, Int returns "Int" and DeviceBuffer[DType.float32] returns "DeviceBuffer[DType.float32]". The name is used in error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive.

Returns:

String: The host type's name.

Provided methods

src_info_size

static src_info_size() -> Int

Returns the size of the source info in bytes. Currently, source info is a single int32 that stores a token's index in the original rank.

Returns:

Int
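Since source info is a single int32, src_info_size is 4 bytes. The Python sketch below models packing a token's original index into those 4 bytes; the little-endian byte order and the helper name pack_src_info are assumptions for illustration, not part of the Mojo API.

```python
import struct

# Source info: one int32 holding the token's index on its original rank.
# Little-endian byte order ("<") is an assumption for this sketch.
SRC_INFO_FMT = "<i"


def src_info_size() -> int:
    return struct.calcsize(SRC_INFO_FMT)


def pack_src_info(token_index: int) -> bytes:
    """Hypothetical helper: encode a token index as source info."""
    return struct.pack(SRC_INFO_FMT, token_index)


print(src_info_size())         # 4
print(pack_src_info(5).hex())  # 05000000
```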

topk_info_size

static topk_info_size() -> Int

Returns the size of the top-k info in bytes. Currently, top-k info is an array of uint16 that stores a token's top-k expert IDs.

Returns:

Int
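If the array holds exactly one uint16 per selected expert (an assumption: the description says only "an array of uint16", though the trait does expose a top_k member), topk_info_size works out to 2 * top_k bytes. The Python model below shows that size and a round trip through a hypothetical packing helper; it is not the Mojo implementation.

```python
import struct


def topk_info_size(top_k: int) -> int:
    # Assumed: exactly top_k uint16 expert IDs, no padding.
    return struct.calcsize(f"<{top_k}H")


def pack_topk_info(expert_ids: list) -> bytes:
    """Hypothetical helper; uint16 limits expert IDs to 0..65535."""
    return struct.pack(f"<{len(expert_ids)}H", *expert_ids)


print(topk_info_size(8))  # 16
print(struct.unpack("<4H", pack_topk_info([3, 17, 250, 64])))  # (3, 17, 250, 64)
```

Using uint16 rather than int32 halves the per-token metadata, at the cost of capping the number of addressable experts at 65,536.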

msg_size

static msg_size() -> Int

Returns the size of the message in bytes.

Returns:

Int
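This page does not state how msg_size combines its parts. As a sketch only, assuming the message is the quantized token followed by the int32 source info and top_k uint16 expert IDs, with the total rounded up to a hypothetical byte alignment, the computation in Python would be:

```python
def round_up(x: int, align: int) -> int:
    """Smallest multiple of `align` that is >= x."""
    return (x + align - 1) // align * align


def msg_size(token_size: int, top_k: int, align: int = 16) -> int:
    src_info_size = 4            # one int32, per src_info_size above
    topk_info_size = 2 * top_k   # top_k uint16 expert IDs (assumed count)
    raw = token_size + src_info_size + topk_info_size
    return round_up(raw, align)  # padding to `align` is an assumption


print(msg_size(7392, 8))  # 7392 + 4 + 16 = 7412, padded to 7424
```

Padding each message to a fixed alignment, if the real format does so, would keep every message in a contiguous buffer starting at an aligned address, which favors wide vectorized loads and stores.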

src_info_offset

static src_info_offset() -> Int

Returns the offset of the source info in the message.

Returns:

Int

topk_info_offset

static topk_info_offset() -> Int

Returns the offset of the top-k info in the message.

Returns:

Int
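The two offset methods locate the metadata inside a message. The Python model below assumes the ordering [token bytes][int32 source info][top_k uint16 expert IDs] with no padding between parts; that ordering and the helpers make_layout, pack_msg, and unpack_msg are hypothetical and only illustrate how a receiver could use the offsets to decode a message.

```python
import struct


def make_layout(token_size: int, top_k: int) -> dict:
    # Assumed ordering: [token][int32 src info][top_k x uint16 top-k info]
    src_info_offset = token_size
    topk_info_offset = src_info_offset + 4
    return {"src": src_info_offset, "topk": topk_info_offset,
            "end": topk_info_offset + 2 * top_k}


def pack_msg(token: bytes, src_index: int, expert_ids: list) -> bytes:
    lay = make_layout(len(token), len(expert_ids))
    msg = bytearray(lay["end"])
    msg[:len(token)] = token
    struct.pack_into("<i", msg, lay["src"], src_index)
    struct.pack_into(f"<{len(expert_ids)}H", msg, lay["topk"], *expert_ids)
    return bytes(msg)


def unpack_msg(msg: bytes, token_size: int, top_k: int):
    lay = make_layout(token_size, top_k)
    (src_index,) = struct.unpack_from("<i", msg, lay["src"])
    expert_ids = struct.unpack_from(f"<{top_k}H", msg, lay["topk"])
    return msg[:token_size], src_index, list(expert_ids)


msg = pack_msg(b"\x01\x02\x03\x04", src_index=9, expert_ids=[2, 7])
print(unpack_msg(msg, token_size=4, top_k=2))  # (b'\x01\x02\x03\x04', 9, [2, 7])
```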

pad_expert_offsets

pad_expert_offsets[n_groups: Int](self: _Self, row_offsets: UnsafePointer[UInt32, address_space=row_offsets.address_space])

Pad the offsets to satisfy the grouped matmul alignment requirement.
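Grouped matmul kernels often require each group's row range to start at a multiple of some tile or alignment size. The Python sketch below models one way to pad offsets for that purpose. It assumes (neither is stated on this page) that row_offsets holds n_groups + 1 prefix-sum entries, so group g owns rows [row_offsets[g], row_offsets[g + 1]), and that each group's row count is rounded up to a multiple of a hypothetical align value.

```python
def pad_expert_offsets(row_offsets: list, n_groups: int, align: int) -> None:
    """Round each group's row count up to a multiple of `align`, in place.

    Assumes row_offsets has n_groups + 1 prefix-sum entries, so group g
    owns rows [row_offsets[g], row_offsets[g + 1]).
    """
    counts = [row_offsets[g + 1] - row_offsets[g] for g in range(n_groups)]
    total = row_offsets[0]
    for g in range(n_groups):
        row_offsets[g] = total
        total += (counts[g] + align - 1) // align * align
    row_offsets[n_groups] = total


offsets = [0, 5, 5, 21, 24]  # row counts per group: 5, 0, 16, 3
pad_expert_offsets(offsets, n_groups=4, align=8)
print(offsets)  # [0, 8, 8, 24, 32]
```

Empty groups stay empty, and groups that are already aligned (16 rows here) keep their size; only the tails of partially filled groups gain padding rows.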

init_smem_resources

init_smem_resources(self: _Self)

Initialize the shared memory resources for the token format.

copy_msg_tile_to_output_tensor

copy_msg_tile_to_output_tensor[extract_topk_info_func: def(UnsafePointer[UInt8, MutExternalOrigin], Int) -> None, recv_buf_ptr_func: def(Int) -> UnsafePointer[UInt8, MutExternalOrigin], //, n_warps: Int, shared_expert_offset: Int = 0](self: _Self, expert_id: Int, expert_start_pos: Int, tile_id: Int, tile_end: Int, extract_topk_info_functor: extract_topk_info_func, recv_buf_ptr_functor: recv_buf_ptr_func)

Copy a tile of tokens from the receive buffer to the output tensor.