Mojo trait
TokenFormat
Implemented traits
AnyType, DevicePassable, ImplicitlyDestructible
comptime members
alignment
comptime alignment
device_type
comptime device_type
Indicates the type used on accelerator devices.
dispatch_smem_size
comptime dispatch_smem_size
dispatch_wait_tile_shape
comptime dispatch_wait_tile_shape
hid_dim
comptime hid_dim
top_k
comptime top_k
Required methods
token_size
copy_token_to_send_buf
static copy_token_to_send_buf[src_type: DType, block_size: Int, buf_addr_space: AddressSpace = AddressSpace.GENERIC](buf_p: UnsafePointer[UInt8, address_space=buf_addr_space], src_p: UnsafePointer[Scalar[src_type], address_space=src_p.address_space], input_scale: Float32)
Copy the token to the send buffer. This function needs to be called by all threads in the block.
copy_msg_to_output_tensor
copy_msg_to_output_tensor[buf_addr_space: AddressSpace = AddressSpace.GENERIC](self: _Self, buf_p: UnsafePointer[UInt8, address_space=buf_addr_space], token_index: Int)
Copy the message to the output tensor. This function needs to be called by all threads in a warp.
get_type_name
static get_type_name() -> String
Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer[DType.float32] would return "DeviceBuffer[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive.
Returns:
String: The host type's name.
Provided methods
src_info_size
static src_info_size() -> Int
Returns the size of the source info in bytes. Currently, source info is a single int32 that stores a token's index in the original rank.
Returns:
Int: The size of the source info in bytes.
topk_info_size
static topk_info_size() -> Int
Returns the size of the top-k info in bytes. Currently, top-k info is an array of uint16 that stores a token's top-k expert IDs.
Returns:
Int: The size of the top-k info in bytes.
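The two size descriptions above reduce to simple byte arithmetic. The following is a minimal Python sketch of that arithmetic, not the trait's implementation. It assumes the sizes carry no extra alignment padding, and it passes top_k as an explicit argument, whereas the trait reads it from the comptime top_k member. The function names mirror the trait's methods for illustration only.

```python
import struct

# Assumption: source info is exactly one little-endian int32 (the token's
# index in its original rank), with no trailing padding.
def src_info_size() -> int:
    return struct.calcsize("<i")

# Assumption: top-k info is a packed array of `top_k` uint16 expert IDs.
# The real trait may round this up to meet an alignment requirement.
def topk_info_size(top_k: int) -> int:
    return top_k * struct.calcsize("<H")
```

Under these assumptions, a format with top_k = 8 would carry 4 bytes of source info and 16 bytes of top-k info per token.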
msg_size
src_info_offset
topk_info_offset
pad_expert_offsets
pad_expert_offsets[n_groups: Int](self: _Self, row_offsets: UnsafePointer[UInt32, address_space=row_offsets.address_space])
Pad the offsets to satisfy the grouped matmul alignment requirement.
init_smem_resources
init_smem_resources(self: _Self)
Initialize the shared memory resources for the token format.
copy_msg_tile_to_output_tensor
copy_msg_tile_to_output_tensor[extract_topk_info_func: def(UnsafePointer[UInt8, MutExternalOrigin], Int) -> None, recv_buf_ptr_func: def(Int) -> UnsafePointer[UInt8, MutExternalOrigin], //, n_warps: Int, shared_expert_offset: Int = 0](self: _Self, expert_id: Int, expert_start_pos: Int, tile_id: Int, tile_end: Int, extract_topk_info_functor: extract_topk_info_func, recv_buf_ptr_functor: recv_buf_ptr_func)
Copy a tile of tokens from the receive buffer to the output tensor.