Mojo function
ep_fused_combine_kernel_api
ep_fused_combine_kernel_api[combine_dtype: DType, //, hidden_size: Int, top_k: Int, n_experts: Int, max_token_per_rank: Int, n_gpus_per_node: Int, n_nodes: Int, target: StringSlice[StaticConstantOrigin], router_weights_wrapper: Optional[def[width: Int](token_idx: Int, topk_id: Int) capturing -> SIMD[DType.float32, width]] = None, epilogue_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, fused_shared_expert: Bool = False, skip_a2a: Bool = False, use_shmem: Bool = (n_nodes > 1), allreduce_world_size: Int = 1](output_tokens: TileTensor[combine_dtype, address_space=output_tokens.address_space, linear_idx_type=output_tokens.linear_idx_type, element_size=output_tokens.element_size], atomic_counters: TileTensor[DType.int32, address_space=atomic_counters.address_space, linear_idx_type=atomic_counters.linear_idx_type, element_size=atomic_counters.element_size], input_tokens: TileTensor[combine_dtype, address_space=input_tokens.address_space, linear_idx_type=input_tokens.linear_idx_type, element_size=input_tokens.element_size], src_info: TileTensor[DType.int32, address_space=src_info.address_space, linear_idx_type=src_info.linear_idx_type, element_size=src_info.element_size], send_ptrs: TileTensor[DType.uint64, address_space=send_ptrs.address_space, linear_idx_type=send_ptrs.linear_idx_type, element_size=send_ptrs.element_size], recv_ptrs: TileTensor[DType.uint64, address_space=recv_ptrs.address_space, linear_idx_type=recv_ptrs.linear_idx_type, element_size=recv_ptrs.element_size], recv_count_ptrs: TileTensor[DType.uint64, address_space=recv_count_ptrs.address_space, linear_idx_type=recv_count_ptrs.linear_idx_type, element_size=recv_count_ptrs.element_size], context: DeviceContext, topk_ids_p: Optional[UnsafePointer[Int32, ImmutExternalOrigin]] = None)
Execute the fused Expert Parallelism combine kernel.
This function launches the fused combine_kernel from ep_comm.mojo, which merges the combine_async and combine_wait phases into a single kernel launch. The kernel sends expert outputs back to their original devices, waits for all transfers to complete, and then computes the weighted sum of the routed expert outputs for each token.
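The per-token reduction that happens after the transfers complete can be modeled on the host. The sketch below is a hypothetical pure-Python reference, not the Mojo API. The nested-list layout and the name `combine_reference` are assumptions made for illustration. It treats `fused_shared_expert=True` as adding the shared expert output unweighted, which follows the parameter description but is not confirmed by the kernel source.

```python
# Hypothetical host-side reference model of the combine reduction; NOT the Mojo kernel.
# Layout assumptions (illustrative only, not taken from the API):
#   expert_out[t][k] -> hidden_size floats produced by the k-th routed expert for token t
#   weights[t][k]    -> router weight of that expert for token t
#   shared_out[t]    -> optional shared-expert output for token t


def combine_reference(expert_out, weights, shared_out=None):
    combined = []
    for t, per_expert in enumerate(expert_out):
        hidden_size = len(per_expert[0])
        acc = [0.0] * hidden_size
        # Weighted sum over the top_k routed experts of token t.
        for k, vec in enumerate(per_expert):
            w = weights[t][k]
            for h in range(hidden_size):
                acc[h] += w * vec[h]
        # Models fused_shared_expert=True: add the shared expert output as-is.
        if shared_out is not None:
            for h in range(hidden_size):
                acc[h] += shared_out[t][h]
        combined.append(acc)
    return combined


# One token, top_k = 2, hidden_size = 3.
out = combine_reference(
    expert_out=[[[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]],
    weights=[[0.75, 0.25]],
)
print(out)  # [[3.25, 6.5, 9.75]]
```

A reference like this is mainly useful for checking kernel output on small inputs. The real kernel operates on TileTensor buffers spread across devices.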
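Arguments:
- output_tokens (TileTensor): Final output tensor holding the combined expert results.
- atomic_counters (TileTensor): Synchronization counters used by the EP kernel.
- input_tokens (TileTensor): Expert output tokens to send back to their source devices.
- src_info (TileTensor): Source routing information from the dispatch phase.
- send_ptrs (TileTensor): Send buffer pointers for each local GPU.
- recv_ptrs (TileTensor): Receive buffer pointers for each local GPU.
- recv_count_ptrs (TileTensor): Receive count buffer pointers for each local GPU.
- context (DeviceContext): Device context used to launch the kernel.

Router weights for the current device are not passed as an argument; they are supplied through the router_weights_wrapper parameter.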
Parameters:
- combine_dtype (DType): Data type of the tokens during the combine phase.
- hidden_size (Int): Model hidden dimension size.
- top_k (Int): Number of experts each token was routed to.
- n_experts (Int): Total number of experts across all devices.
- max_token_per_rank (Int): Maximum number of tokens any device can receive.
- n_gpus_per_node (Int): Number of GPUs per physical node.
- n_nodes (Int): Number of physical nodes.
- target (StringSlice[StaticConstantOrigin]): Compilation target for the kernel.
- router_weights_wrapper (Optional[def[width: Int](token_idx: Int, topk_id: Int) capturing -> SIMD[DType.float32, width]]): Optional closure that supplies the router weights. Given token_idx and topk_id, it returns a float32 SIMD vector of the requested width.
- epilogue_fn (Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None]): Optional elementwise epilogue function applied after the combined output is computed. It receives a 2-D index and a SIMD vector of values.
- fused_shared_expert (Bool): Whether to add the shared expert outputs to the combined routed expert outputs.
- skip_a2a (Bool): Whether to skip the all-to-all (A2A) communication. If True, tokens are only sent within the current device.
- use_shmem (Bool): Whether to enable SHMEM communication. Defaults to True when n_nodes > 1.
- allreduce_world_size (Int): World size of the allreduce operation. Only needed when skip_a2a is True, where it is used to compute the workload distribution for the shared expert, if there is one.
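The router_weights_wrapper and epilogue_fn hooks can be pictured with a small host-side model of a vectorized combine loop. This is a hypothetical Python sketch, not the kernel's implementation. We assume the following, none of which is confirmed by the API:
- The hidden dimension is processed in chunks of `width` lanes.
- The weights hook returns one router weight broadcast across those lanes, mirroring `SIMD[DType.float32, width]`.
- The epilogue receives `(token_idx, hidden_offset)` coordinates, standing in for `IndexList[2]`, and replaces the default store.

The names `combine_with_hooks`, `uniform_weights`, and `scale_epilogue` are illustrative.

```python
# Hypothetical model of the router-weights and epilogue hooks; NOT the Mojo API.
# Assumptions: hidden dim processed in `width`-lane chunks; the weights hook
# returns a per-lane list (standing in for SIMD[DType.float32, width]); when an
# epilogue is given, it replaces the default store into `output`.


def combine_with_hooks(expert_out, router_weights, hidden_size, width,
                       output, epilogue=None):
    assert hidden_size % width == 0, "model assumes hidden_size divisible by width"
    for t, per_expert in enumerate(expert_out):
        for h0 in range(0, hidden_size, width):
            acc = [0.0] * width
            for k, vec in enumerate(per_expert):
                w = router_weights(width, t, k)  # one weight value per lane
                for lane in range(width):
                    acc[lane] += w[lane] * vec[h0 + lane]
            if epilogue is not None:
                epilogue((t, h0), acc)           # epilogue owns the store
            else:
                output[t][h0:h0 + width] = acc   # default: plain store
    return output


def uniform_weights(width, token_idx, topk_id):
    # Hypothetical weights hook: every routed expert gets weight 0.5.
    return [0.5] * width


stored = {}


def scale_epilogue(coords, values):
    # Hypothetical epilogue: scale by 2 and record by (token, hidden offset).
    stored[coords] = [2.0 * v for v in values]


tokens = [[[1.0, 2.0, 3.0, 4.0], [3.0, 2.0, 1.0, 0.0]]]  # 1 token, top_k = 2
out = combine_with_hooks(tokens, uniform_weights, 4, 2, [[0.0] * 4])
print(out)  # [[2.0, 2.0, 2.0, 2.0]]
combine_with_hooks(tokens, uniform_weights, 4, 2, [[0.0] * 4], scale_epilogue)
print(stored)  # {(0, 0): [4.0, 4.0], (0, 2): [4.0, 4.0]}
```

Passing the weights through a closure rather than a tensor argument lets the caller decide where the weights live and how they are loaded. This is consistent with router weights not appearing in the argument list.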