Mojo function
dispatch_async_kernel
dispatch_async_kernel[input_type: DType, num_threads: Int, input_tokens_layout: TensorLayout, topk_ids_layout: TensorLayout, n_sms: Int, n_experts: Int, n_ranks: Int, max_tokens_per_rank: Int, p2p_world_size: Int, token_fmt_type: TokenFormat, input_scales_wrapper: Optional[def[dtype: DType](Int) capturing -> Scalar[dtype]] = None, use_shmem: Bool = True](input_tokens: TileTensor[input_type, input_tokens_layout, ImmutExternalOrigin], topk_ids: TileTensor[DType.int32, topk_ids_layout, ImmutExternalOrigin], send_buf_p: UnsafePointer[UInt8, MutExternalOrigin], recv_buf_ptrs: InlineArray[UnsafePointer[UInt8, MutExternalOrigin], p2p_world_size], recv_count_ptrs: InlineArray[UnsafePointer[UInt64, MutExternalOrigin], p2p_world_size], ep_counters: EPLocalSyncCounters[n_experts], my_rank: Int32)
Dispatch tokens to experts on remote ranks based on the top-k expert IDs. If use_shmem is True, this kernel uses the non-blocking SHMEM API and returns immediately after initiating the communication. The communication is considered complete only after dispatch_wait_kernel has been called.
Parameters:
- input_type (DType): The type of the input tokens.
- num_threads (Int): The number of threads in the block.
- input_tokens_layout (TensorLayout): The layout of the input tokens.
- topk_ids_layout (TensorLayout): The layout of the top-k expert IDs.
- n_sms (Int): The total number of SMs in the device.
- n_experts (Int): The total number of experts in the model.
- n_ranks (Int): The total number of devices participating in the communication.
- max_tokens_per_rank (Int): The maximum number of tokens per rank.
- p2p_world_size (Int): The size of a high-speed GPU interconnect (P2P) group.
- token_fmt_type (TokenFormat): A type conforming to the TokenFormat trait that defines the token encoding scheme.
- input_scales_wrapper (Optional[def[dtype: DType](Int) capturing -> Scalar[dtype]]): An optional capturing function that wraps the input scales; it takes an Int and returns a Scalar[dtype].
- use_shmem (Bool): Whether to use the SHMEM API for the communication.
Args:
- input_tokens (TileTensor[input_type, input_tokens_layout, ImmutExternalOrigin]): The input tokens to be dispatched.
- topk_ids (TileTensor[DType.int32, topk_ids_layout, ImmutExternalOrigin]): The top-k expert IDs for each token.
- send_buf_p (UnsafePointer[UInt8, MutExternalOrigin]): A pointer to the send buffer. The underlying buffer has shape (max_tokens_per_rank, msg_bytes). It must be allocated with shmem_alloc if use_shmem is True.
- recv_buf_ptrs (InlineArray[UnsafePointer[UInt8, MutExternalOrigin], p2p_world_size]): An array of pointers to the receive buffers, one for each device in the P2P world. Each buffer has shape (n_local_experts, n_ranks, max_tokens_per_rank, msg_bytes). The buffers must be allocated with shmem_alloc if use_shmem is True.
- recv_count_ptrs (InlineArray[UnsafePointer[UInt64, MutExternalOrigin], p2p_world_size]): An array of pointers to the receive count buffers, one for each device in the P2P world. Each buffer has shape (n_local_experts, n_ranks). The buffers must be allocated with shmem_alloc if use_shmem is True.
- ep_counters (EPLocalSyncCounters[n_experts]): EP atomic counters used for kernel synchronization.
- my_rank (Int32): The rank of the current device.
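As a rough sizing aid, the buffer shapes given for send_buf_p, recv_buf_ptrs, and recv_count_ptrs translate into the byte counts below. This is a sketch rather than part of the API: it assumes experts are split evenly across ranks (n_local_experts = n_experts / n_ranks), and it treats msg_bytes, the per-token message size that this page does not define, as a given input.

```latex
\begin{aligned}
n_{\text{local\_experts}} &= n_{\text{experts}} / n_{\text{ranks}} && \text{(assumed even split)} \\
\text{send buffer bytes} &= \text{max\_tokens\_per\_rank} \times \text{msg\_bytes} \\
\text{receive buffer bytes (per device)} &= n_{\text{local\_experts}} \times n_{\text{ranks}} \times \text{max\_tokens\_per\_rank} \times \text{msg\_bytes} \\
\text{receive count bytes (per device)} &= n_{\text{local\_experts}} \times n_{\text{ranks}} \times 8 && \text{(one UInt64 per entry)}
\end{aligned}
```

For example, with the hypothetical values n_experts = 256, n_ranks = 8, max_tokens_per_rank = 128, and msg_bytes = 7168, each rank holds 32 local experts. The send buffer is 128 × 7168 = 917,504 bytes (896 KiB), each receive buffer is 32 × 8 × 128 × 7168 = 234,881,024 bytes (224 MiB), and each receive count buffer is 32 × 8 × 8 = 2,048 bytes. Under the even-split assumption, n_local_experts × n_ranks equals n_experts, so the receive buffer grows linearly with the total number of experts.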