IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo function

ep_dispatch_wait_kernel_api

ep_dispatch_wait_kernel_api[token_fmt_type: TokenFormat, //, n_experts: Int, max_token_per_rank: Int, n_gpus_per_node: Int, n_nodes: Int, target: StringSlice[StaticConstantOrigin], input_scales_wrapper: Optional[def[dtype: DType](Int) capturing -> Scalar[dtype]] = None](token_handler: token_fmt_type, row_offsets: TileTensor[DType.uint32, address_space=row_offsets.address_space, linear_idx_type=row_offsets.linear_idx_type, element_size=row_offsets.element_size], expert_ids: TileTensor[DType.int32, address_space=expert_ids.address_space, linear_idx_type=expert_ids.linear_idx_type, element_size=expert_ids.element_size], src_info: TileTensor[DType.int32, address_space=src_info.address_space, linear_idx_type=src_info.linear_idx_type, element_size=src_info.element_size], recv_ptrs: TileTensor[DType.uint64, address_space=recv_ptrs.address_space, linear_idx_type=recv_ptrs.linear_idx_type, element_size=recv_ptrs.element_size], recv_count_ptrs: TileTensor[DType.uint64, address_space=recv_count_ptrs.address_space, linear_idx_type=recv_count_ptrs.linear_idx_type, element_size=recv_count_ptrs.element_size], atomic_counters: TileTensor[DType.int32, address_space=atomic_counters.address_space, linear_idx_type=atomic_counters.linear_idx_type, element_size=atomic_counters.element_size], context: DeviceContext, num_input_tokens: Int = -1)

Execute the Expert Parallelism dispatch completion kernel.

This function launches the dispatch_wait_kernel from ep_comm.mojo to complete the token dispatch phase. It waits for all inter-device communication to complete, then organizes the received tokens into a format suitable for grouped matmul computation.
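The reorganization step can be modeled in plain Python to make the output layout concrete. This is a hypothetical host-side sketch, not the Mojo kernel. It assumes three things: each received token arrives tagged with its destination local expert, source rank, and source token index; row_offsets holds one more entry than the number of groups; and src_info stores one (source rank, source token index) pair per output row. The names ReceivedToken and organize_for_grouped_matmul are illustrative only. Skipping experts that receive no tokens is also an assumption.

```python
from dataclasses import dataclass


@dataclass
class ReceivedToken:
    # Hypothetical record for one token that arrived in a receive buffer.
    local_expert: int     # destination expert on this GPU
    src_rank: int         # rank that sent the token
    src_index: int        # token index on the sending rank
    hidden: list[float]   # token activations


def organize_for_grouped_matmul(tokens, n_local_experts):
    """Group received tokens by local expert (models the layout only)."""
    buckets = [[] for _ in range(n_local_experts)]
    for tok in tokens:
        buckets[tok.local_expert].append(tok)

    out_rows, row_offsets, expert_ids, src_info = [], [0], [], []
    for expert, bucket in enumerate(buckets):
        if not bucket:
            continue  # assumption: experts with no tokens get no group
        for tok in bucket:
            out_rows.append(tok.hidden)
            # Remember where each row came from so combine can send it back.
            src_info.append((tok.src_rank, tok.src_index))
        expert_ids.append(expert)
        # Cumulative count: group g spans row_offsets[g]..row_offsets[g + 1].
        row_offsets.append(row_offsets[-1] + len(bucket))
    return out_rows, row_offsets, expert_ids, src_info


tokens = [
    ReceivedToken(local_expert=1, src_rank=0, src_index=5, hidden=[1.0]),
    ReceivedToken(local_expert=0, src_rank=1, src_index=2, hidden=[2.0]),
    ReceivedToken(local_expert=1, src_rank=1, src_index=7, hidden=[3.0]),
]
rows, offsets, ids, info = organize_for_grouped_matmul(tokens, n_local_experts=3)
# offsets == [0, 1, 3], ids == [0, 1]
```

Grouping rows contiguously per expert is what lets a grouped matmul treat each expert's tokens as one dense block. The per-row source record is what the later combine phase would need to route results back.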

Arguments:

token_handler: Token handler wrapping the output token tensor.
row_offsets: Cumulative token counts for grouped matmul.
expert_ids: Local expert IDs for grouped matmul.
src_info: Source routing information for the combine phase.
atomic_counters: EP kernel synchronization counters.
recv_ptrs: Receive buffer pointers for each local GPU.
recv_count_ptrs: Receive count buffer pointers for each local GPU.
context: Device context (DeviceContext) used to launch the kernel.
num_input_tokens: Per-rank input token count for this layer. When >= 0, enables decode fast-path grid sizing. The default, -1, keeps the full sm_count grid for backward compatibility.
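A plain-Python sketch can show how row_offsets and expert_ids would typically be consumed together by a grouped matmul. It rests on assumptions not stated in this reference. row_offsets is assumed to have length len(expert_ids) + 1, so group g covers rows row_offsets[g] up to (not including) row_offsets[g + 1]. expert_ids[g] is assumed to select that group's weight matrix. The helpers matmul and grouped_matmul are hypothetical stand-ins for the real GPU kernel.

```python
def matmul(a, b):
    """Naive dense matmul on nested lists: (m x k) @ (k x n)."""
    return [
        [sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def grouped_matmul(tokens, row_offsets, expert_ids, weights):
    """Multiply each contiguous row group by its expert's weight matrix."""
    assert len(row_offsets) == len(expert_ids) + 1
    out = []
    for g, expert in enumerate(expert_ids):
        start, end = row_offsets[g], row_offsets[g + 1]
        out.extend(matmul(tokens[start:end], weights[expert]))
    return out


# Three tokens with hidden size 2, split into two groups: rows [0, 1) and [1, 3).
tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
weights = {
    0: [[1.0, 0.0], [0.0, 1.0]],  # identity for expert 0
    1: [[2.0, 0.0], [0.0, 2.0]],  # scale-by-2 for expert 1
}
result = grouped_matmul(tokens, [0, 1, 3], [0, 1], weights)
# result == [[1.0, 2.0], [6.0, 8.0], [10.0, 12.0]]
```

Because each group is a contiguous slice, the consumer never needs per-token expert lookups. It only walks the offset array once.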

Parameters: