Mojo function
ep_dispatch_async_kernel_api
ep_dispatch_async_kernel_api[token_fmt_type: TokenFormat, n_experts: Int, max_token_per_rank: Int, n_gpus_per_node: Int, n_nodes: Int, target: StringSlice[StaticConstantOrigin], input_scales_wrapper: Optional[def[dtype: DType](Int) capturing -> Scalar[dtype]] = None, use_shmem: Bool = (n_nodes > 1)](atomic_counters: TileTensor[DType.int32, address_space=atomic_counters.address_space, linear_idx_type=atomic_counters.linear_idx_type, element_size=atomic_counters.element_size], input_tokens: TileTensor[address_space=input_tokens.address_space, linear_idx_type=input_tokens.linear_idx_type, element_size=input_tokens.element_size], topk_ids: TileTensor[DType.int32, address_space=topk_ids.address_space, linear_idx_type=topk_ids.linear_idx_type, element_size=topk_ids.element_size], send_ptrs: TileTensor[DType.uint64, address_space=send_ptrs.address_space, linear_idx_type=send_ptrs.linear_idx_type, element_size=send_ptrs.element_size], recv_ptrs: TileTensor[DType.uint64, address_space=recv_ptrs.address_space, linear_idx_type=recv_ptrs.linear_idx_type, element_size=recv_ptrs.element_size], recv_count_ptrs: TileTensor[DType.uint64, address_space=recv_count_ptrs.address_space, linear_idx_type=recv_count_ptrs.linear_idx_type, element_size=recv_count_ptrs.element_size], context: DeviceContext)
Execute the Expert Parallelism async dispatch kernel.
This function launches the dispatch_async_kernel from ep_comm.mojo to start distributing tokens across expert devices. In multi-node deployments, all communication buffers must be allocated with shmem_malloc.
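The routing that the dispatch performs can be pictured with a small Python model. This is not the kernel's implementation. It assumes, hypothetically, that experts are split into equal contiguous blocks across the `n_gpus_per_node * n_nodes` GPUs, and that `topk_ids` holds the expert IDs selected for each local token. `route_tokens` is an illustrative name, not part of the MAX API.

```python
from collections import defaultdict


def route_tokens(topk_ids, n_experts, n_gpus_per_node, n_nodes, max_token_per_rank):
    """Hypothetical model: group (token_idx, expert_id) pairs by destination rank.

    ASSUMPTION: expert e lives on rank e // experts_per_rank (contiguous
    blocks). The real kernel's expert layout may differ.
    """
    world_size = n_gpus_per_node * n_nodes  # total GPUs taking part in EP
    if n_experts % world_size != 0:
        raise ValueError("n_experts must divide evenly across ranks")
    if len(topk_ids) > max_token_per_rank:
        raise ValueError("this rank sends more than max_token_per_rank tokens")
    experts_per_rank = n_experts // world_size

    per_rank = defaultdict(list)
    for token_idx, experts in enumerate(topk_ids):
        for expert in experts:
            # Each (token, expert) pair is sent to the rank that owns the expert.
            per_rank[expert // experts_per_rank].append((token_idx, expert))
    return dict(per_rank)


# 8 experts on 2 nodes x 2 GPUs gives 2 experts per rank; each token picks top-2.
print(route_tokens([[0, 5], [3, 7], [1, 2]], n_experts=8,
                   n_gpus_per_node=2, n_nodes=2, max_token_per_rank=16))
```

In this model, a token selected for several experts is sent once per selected expert. The per-rank list lengths correspond conceptually to the receive counts that a destination GPU must learn about.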
Arguments:
- atomic_counters: EP kernel synchronization counters.
- input_tokens: Tokens to dispatch to experts.
- topk_ids: Expert assignments from the router.
- send_ptrs: Send buffer pointers for each local GPU.
- recv_ptrs: Receive buffer pointers for each local GPU.
- recv_count_ptrs: Receive count buffer pointers for each local GPU.
- context: Device context.
Parameters:
- token_fmt_type (TokenFormat): Token format type.
- n_experts (Int): Total number of experts across all devices.
- max_token_per_rank (Int): Maximum number of tokens any device can send.
- n_gpus_per_node (Int): Number of GPUs per physical node.
- n_nodes (Int): Number of physical nodes.
- target (StringSlice[StaticConstantOrigin]): Target.
- input_scales_wrapper (Optional[def[dtype: DType](Int) capturing -> Scalar[dtype]]): The wrapper for the input scales. Defaults to None.
- use_shmem (Bool): Whether to enable SHMEM communication. Defaults to (n_nodes > 1).
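The relationship between the topology parameters and the `use_shmem` default can be summarized in a short Python sketch. `EPDispatchParams` is a hypothetical helper for illustration, not a MAX type. It reproduces only the default shown in the signature (`use_shmem = (n_nodes > 1)`) and the total GPU count implied by `n_gpus_per_node` and `n_nodes`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class EPDispatchParams:
    """Hypothetical Python mirror of the compile-time parameters (not a MAX API)."""

    n_experts: int
    max_token_per_rank: int
    n_gpus_per_node: int
    n_nodes: int
    use_shmem: Optional[bool] = None  # None means "use the documented default"

    @property
    def world_size(self) -> int:
        # Total GPUs participating in expert parallelism.
        return self.n_gpus_per_node * self.n_nodes

    @property
    def effective_use_shmem(self) -> bool:
        # Default from the signature: use_shmem = (n_nodes > 1).
        if self.use_shmem is None:
            return self.n_nodes > 1
        return self.use_shmem


single_node = EPDispatchParams(n_experts=64, max_token_per_rank=128,
                               n_gpus_per_node=8, n_nodes=1)
multi_node = EPDispatchParams(n_experts=64, max_token_per_rank=128,
                              n_gpus_per_node=8, n_nodes=2)
print(single_node.world_size, single_node.effective_use_shmem)
print(multi_node.world_size, multi_node.effective_use_shmem)
```

Under this default, SHMEM is enabled only when the expert-parallel group spans more than one node. This matches the requirement that multi-node communication buffers come from shmem_malloc.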