
Mojo function

ep_combine_async_kernel_api

ep_combine_async_kernel_api[combine_dtype: DType, hidden_size: Int, top_k: Int, n_experts: Int, max_token_per_rank: Int, n_gpus_per_node: Int, n_nodes: Int, target: StringSlice[StaticConstantOrigin], use_shmem: Bool = (n_nodes > 1)](atomic_counters: TileTensor[DType.int32, address_space=atomic_counters.address_space, linear_idx_type=atomic_counters.linear_idx_type, element_size=atomic_counters.element_size], input_tokens: TileTensor[combine_dtype, address_space=input_tokens.address_space, linear_idx_type=input_tokens.linear_idx_type, element_size=input_tokens.element_size], src_info: TileTensor[DType.int32, address_space=src_info.address_space, linear_idx_type=src_info.linear_idx_type, element_size=src_info.element_size], send_ptrs: TileTensor[DType.uint64, address_space=send_ptrs.address_space, linear_idx_type=send_ptrs.linear_idx_type, element_size=send_ptrs.element_size], recv_ptrs: TileTensor[DType.uint64, address_space=recv_ptrs.address_space, linear_idx_type=recv_ptrs.linear_idx_type, element_size=recv_ptrs.element_size], recv_count_ptrs: TileTensor[DType.uint64, address_space=recv_count_ptrs.address_space, linear_idx_type=recv_count_ptrs.linear_idx_type, element_size=recv_count_ptrs.element_size], context: DeviceContext)

Execute the Expert Parallelism combine kernel.

This function launches the combine_async_kernel from ep_comm.mojo, which starts sending expert outputs back to the devices their tokens originated from. The kernel uses the source routing information from the dispatch phase to determine each destination. It may also filter out the shared expert's outputs and store them in a separate tensor. In multi-node scenarios, all communication buffers must be allocated with shmem_malloc.
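The routing step described above can be modeled on the host as a plain scatter. The sketch below is a conceptual illustration in Python, not the Mojo kernel: the `SrcInfo` record, its fields, and the nested-list receive buffers are all hypothetical, since the actual layout of src_info and of the communication buffers is not specified on this page.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SrcInfo:
    """Hypothetical routing record; the real src_info layout is an assumption."""
    src_rank: int   # device the token originally came from
    token_idx: int  # token's position on that device
    slot: int       # which of the token's top_k expert slots this output fills


def combine(expert_outputs, src_info, n_ranks, max_token_per_rank, top_k):
    """Scatter each expert output row into the receive buffer of its source rank."""
    # recv[rank][token][slot] holds one hidden vector, or None if never written.
    recv = [
        [[None] * top_k for _ in range(max_token_per_rank)]
        for _ in range(n_ranks)
    ]
    for row, info in zip(expert_outputs, src_info):
        recv[info.src_rank][info.token_idx][info.slot] = row
    return recv


# Two ranks, one token per rank, top_k = 2, hidden_size = 2.
outputs = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
routing = [SrcInfo(0, 0, 0), SrcInfo(1, 0, 1), SrcInfo(0, 0, 1)]
recv = combine(outputs, routing, n_ranks=2, max_token_per_rank=1, top_k=2)
```

In this model each source rank ends up with up to top_k expert outputs per token, which a later step would reduce into a single hidden vector. The real kernel performs the transfer asynchronously across GPUs rather than through a host-side loop.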

Arguments:

  • output_tokens: Output tokens for the shared experts.
  • atomic_counters: EP kernel synchronization counters. Used to coordinate between different thread blocks.
  • input_tokens: Expert output tokens to send back to original devices.
  • src_info: Source routing information from dispatch phase.
  • send_ptrs: Send buffer pointers for each local GPU.
  • recv_ptrs: Receive buffer pointers for each local GPU.
  • recv_count_ptrs: Receive count buffer pointers for each local GPU.
  • context: Device context pointer.

Parameters:

  • ​combine_dtype (DType): Data type for tokens during combine phase.
  • ​hidden_size (Int): Model hidden dimension size.
  • ​top_k (Int): Number of experts each token was routed to.
  • ​n_experts (Int): Total experts across all devices.
  • ​max_token_per_rank (Int): Maximum tokens any device can send.
  • ​n_gpus_per_node (Int): GPUs per physical node.
  • ​n_nodes (Int): Number of physical nodes.
  • ​target (StringSlice[StaticConstantOrigin]): Target.
  • ​use_shmem (Bool): Whether to enable SHMEM communication. Defaults to n_nodes > 1.
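To show how the topology parameters relate to each other, here is a small Python sketch that derives a few quantities from them. `EPConfig` and its properties are hypothetical helpers, not part of the Mojo API. The even split of experts across devices is an assumption; only the `use_shmem` default of `n_nodes > 1` comes from the signature above.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class EPConfig:
    """Hypothetical container mirroring a subset of the compile-time parameters."""
    n_experts: int
    n_gpus_per_node: int
    n_nodes: int
    use_shmem: Optional[bool] = None  # None means "fall back to the default"

    @property
    def world_size(self) -> int:
        # Total number of devices taking part in expert parallelism.
        return self.n_gpus_per_node * self.n_nodes

    @property
    def experts_per_device(self) -> int:
        # Assumption: experts are split evenly across all devices.
        if self.n_experts % self.world_size != 0:
            raise ValueError("n_experts must be divisible by the device count")
        return self.n_experts // self.world_size

    @property
    def shmem_enabled(self) -> bool:
        # Mirrors the signature's default: use_shmem = (n_nodes > 1).
        if self.use_shmem is None:
            return self.n_nodes > 1
        return self.use_shmem


multi_node = EPConfig(n_experts=64, n_gpus_per_node=8, n_nodes=2)
single_node = EPConfig(n_experts=64, n_gpus_per_node=8, n_nodes=1)
forced = EPConfig(n_experts=64, n_gpus_per_node=8, n_nodes=1, use_shmem=True)
```

With these values the multi-node configuration spans 16 devices with 4 experts each and enables SHMEM by default, while the single-node configuration leaves it off unless it is requested explicitly.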