
Mojo function

ep_fused_combine_kernel_api

ep_fused_combine_kernel_api[combine_dtype: DType, //, hidden_size: Int, top_k: Int, n_experts: Int, max_token_per_rank: Int, n_gpus_per_node: Int, n_nodes: Int, target: StringSlice[StaticConstantOrigin], router_weights_wrapper: Optional[def[width: Int](token_idx: Int, topk_id: Int) capturing -> SIMD[DType.float32, width]] = None, epilogue_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, fused_shared_expert: Bool = False, skip_a2a: Bool = False, use_shmem: Bool = (n_nodes > 1), allreduce_world_size: Int = 1](output_tokens: TileTensor[combine_dtype, address_space=output_tokens.address_space, linear_idx_type=output_tokens.linear_idx_type, element_size=output_tokens.element_size], atomic_counters: TileTensor[DType.int32, address_space=atomic_counters.address_space, linear_idx_type=atomic_counters.linear_idx_type, element_size=atomic_counters.element_size], input_tokens: TileTensor[combine_dtype, address_space=input_tokens.address_space, linear_idx_type=input_tokens.linear_idx_type, element_size=input_tokens.element_size], src_info: TileTensor[DType.int32, address_space=src_info.address_space, linear_idx_type=src_info.linear_idx_type, element_size=src_info.element_size], send_ptrs: TileTensor[DType.uint64, address_space=send_ptrs.address_space, linear_idx_type=send_ptrs.linear_idx_type, element_size=send_ptrs.element_size], recv_ptrs: TileTensor[DType.uint64, address_space=recv_ptrs.address_space, linear_idx_type=recv_ptrs.linear_idx_type, element_size=recv_ptrs.element_size], recv_count_ptrs: TileTensor[DType.uint64, address_space=recv_count_ptrs.address_space, linear_idx_type=recv_count_ptrs.linear_idx_type, element_size=recv_count_ptrs.element_size], context: DeviceContext, topk_ids_p: Optional[UnsafePointer[Int32, ImmutExternalOrigin]] = None)

Execute the fused Expert Parallelism combine kernel.

This function launches the fused combine_kernel from ep_comm.mojo, which merges the combine_async and combine_wait steps into a single kernel launch. The kernel sends expert outputs back to the devices their tokens originated from, waits for all transfers to complete, and then computes the weighted sum of the routed expert outputs for each token.
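The final reduction step can be illustrated with a small host-side reference, written in plain Python rather than Mojo. This is only a sketch of the weighted-sum semantics described above, not the kernel itself: it ignores the cross-device transfers and synchronization entirely. The names `combine_reference`, `expert_outputs`, and `router_weight` are hypothetical, and the `[num_tokens][top_k][hidden_size]` layout is an assumption made for illustration. The `router_weight(token_idx, topk_id)` callable loosely mirrors the shape of the `router_weights_wrapper` parameter in the signature.

```python
from typing import Callable, List


def combine_reference(
    expert_outputs: List[List[List[float]]],
    router_weight: Callable[[int, int], float],
) -> List[List[float]]:
    """Weighted sum of routed expert outputs for each token.

    expert_outputs is assumed to be laid out as
    [num_tokens][top_k][hidden_size] (hypothetical layout).
    router_weight(token_idx, topk_id) returns the scalar weight
    applied to that token's topk_id-th routed expert output.
    """
    combined = []
    for token_idx, per_expert in enumerate(expert_outputs):
        hidden_size = len(per_expert[0])
        acc = [0.0] * hidden_size
        for topk_id, row in enumerate(per_expert):
            w = router_weight(token_idx, topk_id)
            for h in range(hidden_size):
                acc[h] += w * row[h]
        combined.append(acc)
    return combined


# One token, top_k = 2, hidden_size = 2.
weights = [[0.75, 0.25]]
outputs = [[[1.0, 2.0], [3.0, 4.0]]]
result = combine_reference(outputs, lambda t, k: weights[t][k])
print(result)  # [[1.5, 2.5]]
```

In the real kernel this reduction runs on the GPU after the expert outputs have arrived in the receive buffers. The sketch is useful mainly as a mental model, or as a reference to compare against when validating output_tokens.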

Arguments:
output_tokens: Final output tensor with expert results.
atomic_counters: EP kernel synchronization counters.
input_tokens: Expert output tokens to send back.
src_info: Source routing information from dispatch phase.
send_ptrs: Send buffer pointers for each local GPU.
recv_ptrs: Receive buffer pointers for each local GPU.
recv_count_ptrs: Receive count buffer pointers for each local GPU.
router_weights: Router weights for the current device.
context: Device context pointer.

Parameters: