Mojo function
ep_combine_wait_kernel_api
ep_combine_wait_kernel_api[combine_dtype: DType, //, hidden_size: Int, top_k: Int, n_experts: Int, max_token_per_rank: Int, n_gpus_per_node: Int, n_nodes: Int, target: StringSlice[StaticConstantOrigin], router_weights_wrapper: Optional[def[width: Int](token_idx: Int, topk_id: Int) capturing -> SIMD[DType.float32, width]] = None, epilogue_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](output_tokens: TileTensor[combine_dtype, address_space=output_tokens.address_space, linear_idx_type=output_tokens.linear_idx_type, element_size=output_tokens.element_size], atomic_counters: TileTensor[DType.int32, address_space=atomic_counters.address_space, linear_idx_type=atomic_counters.linear_idx_type, element_size=atomic_counters.element_size], recv_ptrs: TileTensor[DType.uint64, address_space=recv_ptrs.address_space, linear_idx_type=recv_ptrs.linear_idx_type, element_size=recv_ptrs.element_size], recv_count_ptrs: TileTensor[DType.uint64, address_space=recv_count_ptrs.address_space, linear_idx_type=recv_count_ptrs.linear_idx_type, element_size=recv_count_ptrs.element_size], context: DeviceContext, num_input_tokens: Int = -1)
Execute the Expert Parallelism combine completion kernel.
This function launches the combine_wait_kernel from ep_comm.mojo to complete the token combine phase. It waits for all inter-device communication to complete, then computes the weighted sum of routed expert outputs for each token.
Arguments:
- output_tokens: Final output tensor with expert results.
- atomic_counters: EP kernel synchronization counters. Used to coordinate between different thread blocks.
- recv_ptrs: Receive buffer pointers for each local GPU.
- recv_count_ptrs: Receive count buffer pointers for each local GPU.
- context: Device context pointer.
Parameters:
- combine_dtype (DType): Data type for tokens during combine phase.
- hidden_size (Int): Model hidden dimension size.
- top_k (Int): Number of experts each token was routed to.
- n_experts (Int): Total experts across all devices.
- max_token_per_rank (Int): Maximum tokens any device can receive.
- n_gpus_per_node (Int): GPUs per physical node.
- n_nodes (Int): Number of physical nodes.
- target (StringSlice[StaticConstantOrigin]): Target.
- router_weights_wrapper (Optional[def[width: Int](token_idx: Int, topk_id: Int) capturing -> SIMD[DType.float32, width]]): Wrapper for the optional router weights.
- epilogue_fn (Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None]): Optional elementwise epilogue function applied after computing combined output.
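The weighted sum that the combine phase computes for each token can be illustrated with a small host-side reference. This is a minimal pure-Python sketch of the arithmetic only, not the kernel or its Mojo API: the function name combine_reference, the nested-list layout of expert_outputs, and the scalar form of the two callbacks are all assumptions made for illustration. It does not model the receive buffers, the atomic counters, or the wait on inter-device communication, and it does not cover what the kernel does when no router weights wrapper is supplied.

```python
def combine_reference(expert_outputs, router_weight, epilogue=None):
    """Weighted sum over the top-k expert outputs of each token.

    Hypothetical layout (not the kernel's receive buffer):
      expert_outputs[t][k] is the hidden_size-long vector produced by the
      k-th expert that token t was routed to.
    router_weight(t, k) returns the scalar weight for that (token, slot)
    pair, loosely analogous to router_weights_wrapper(token_idx, topk_id).
    epilogue, if given, is called as epilogue((t, h), value) for every
    combined element, loosely analogous to epilogue_fn.
    """
    output = []
    for t, per_token in enumerate(expert_outputs):
        hidden_size = len(per_token[0])
        combined = [0.0] * hidden_size
        # Accumulate weight * expert output across the top_k slots.
        for k, vec in enumerate(per_token):
            w = router_weight(t, k)
            for h in range(hidden_size):
                combined[h] += w * vec[h]
        # Assumption: the epilogue sees each combined element exactly once.
        if epilogue is not None:
            for h in range(hidden_size):
                epilogue((t, h), combined[h])
        output.append(combined)
    return output


# Two tokens, top_k = 2, hidden_size = 2 (made-up values).
outs = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[0.5, 0.5], [1.5, 2.5]],
]
weights = [[0.25, 0.75], [0.5, 0.5]]

result = combine_reference(outs, lambda t, k: weights[t][k])

# Collect what a hypothetical epilogue would observe.
seen = {}
combine_reference(
    outs,
    lambda t, k: weights[t][k],
    epilogue=lambda idx, value: seen.__setitem__(idx, value),
)
```

For token 0 the first element is 0.25 * 1.0 + 0.75 * 3.0 = 2.5. A reference like this can be useful for checking a kernel's output on small inputs, assuming the output tensor is indexed by token and hidden position in the same way.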