Mojo function
ep_combine_async_kernel_api
ep_combine_async_kernel_api[combine_dtype: DType, hidden_size: Int, top_k: Int, n_experts: Int, max_token_per_rank: Int, n_gpus_per_node: Int, n_nodes: Int, target: StringSlice[StaticConstantOrigin], use_shmem: Bool = (n_nodes > 1)](atomic_counters: TileTensor[DType.int32, address_space=atomic_counters.address_space, linear_idx_type=atomic_counters.linear_idx_type, element_size=atomic_counters.element_size], input_tokens: TileTensor[combine_dtype, address_space=input_tokens.address_space, linear_idx_type=input_tokens.linear_idx_type, element_size=input_tokens.element_size], src_info: TileTensor[DType.int32, address_space=src_info.address_space, linear_idx_type=src_info.linear_idx_type, element_size=src_info.element_size], send_ptrs: TileTensor[DType.uint64, address_space=send_ptrs.address_space, linear_idx_type=send_ptrs.linear_idx_type, element_size=send_ptrs.element_size], recv_ptrs: TileTensor[DType.uint64, address_space=recv_ptrs.address_space, linear_idx_type=recv_ptrs.linear_idx_type, element_size=recv_ptrs.element_size], recv_count_ptrs: TileTensor[DType.uint64, address_space=recv_count_ptrs.address_space, linear_idx_type=recv_count_ptrs.linear_idx_type, element_size=recv_count_ptrs.element_size], context: DeviceContext)
Execute the Expert Parallelism combine kernel.
This function launches the combine_async_kernel from ep_comm.mojo to
initiate sending expert outputs back to their original devices. The kernel
uses source routing information to determine destinations. This kernel might
also filter out the shared expert's outputs and store them in a separate
tensor. In multi-node scenarios, all the communication buffers need to be
allocated using shmem_malloc.
Arguments:
- output_tokens: Output tokens for the shared experts.
- atomic_counters: EP kernel synchronization counters. Used to coordinate between different thread blocks.
- input_tokens: Expert output tokens to send back to original devices.
- src_info: Source routing information from dispatch phase.
- send_ptrs: Send buffer pointers for each local GPU.
- recv_ptrs: Receive buffer pointers for each local GPU.
- recv_count_ptrs: Receive count buffer pointers for each local GPU.
- context: Device context pointer.
Parameters:
- combine_dtype (DType): Data type for tokens during combine phase.
- hidden_size (Int): Model hidden dimension size.
- top_k (Int): Number of experts each token was routed to.
- n_experts (Int): Total experts across all devices.
- max_token_per_rank (Int): Maximum tokens any device can send.
- n_gpus_per_node (Int): GPUs per physical node.
- n_nodes (Int): Number of physical nodes.
- target (StringSlice[StaticConstantOrigin]): Target.
- use_shmem (Bool): Whether to enable SHMEM communication.
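In the signature, use_shmem defaults to the compile-time expression (n_nodes > 1), so SHMEM communication is enabled by default only when more than one node is involved. The sketch below illustrates how a dependent default of this kind resolves. It assumes nothing about the kernel itself: resolved_use_shmem is a hypothetical helper written only for this illustration and is not part of the documented API.

```mojo
# Hypothetical helper (not part of ep_comm.mojo). It only mirrors the default
# expression from the documented signature: use_shmem = (n_nodes > 1).
def resolved_use_shmem[n_nodes: Int, use_shmem: Bool = (n_nodes > 1)]() -> Bool:
    return use_shmem


def main():
    # Single node: the default evaluates to False.
    print(resolved_use_shmem[1]())
    # Two nodes: the default evaluates to True.
    print(resolved_use_shmem[2]())
    # An explicit value overrides the default, even on a single node.
    print(resolved_use_shmem[1, True]())
```

Because the default is derived from n_nodes, callers in a multi-node setup pick up SHMEM communication without passing use_shmem explicitly. As noted in the description, the communication buffers then need to be allocated using shmem_malloc.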