Mojo function
ep_dispatch_wait_kernel_api
ep_dispatch_wait_kernel_api[token_fmt_type: TokenFormat, //, n_experts: Int, max_token_per_rank: Int, n_gpus_per_node: Int, n_nodes: Int, target: StringSlice[StaticConstantOrigin], input_scales_wrapper: Optional[def[dtype: DType](Int) capturing -> Scalar[dtype]] = None](token_handler: token_fmt_type, row_offsets: TileTensor[DType.uint32, address_space=row_offsets.address_space, linear_idx_type=row_offsets.linear_idx_type, element_size=row_offsets.element_size], expert_ids: TileTensor[DType.int32, address_space=expert_ids.address_space, linear_idx_type=expert_ids.linear_idx_type, element_size=expert_ids.element_size], src_info: TileTensor[DType.int32, address_space=src_info.address_space, linear_idx_type=src_info.linear_idx_type, element_size=src_info.element_size], recv_ptrs: TileTensor[DType.uint64, address_space=recv_ptrs.address_space, linear_idx_type=recv_ptrs.linear_idx_type, element_size=recv_ptrs.element_size], recv_count_ptrs: TileTensor[DType.uint64, address_space=recv_count_ptrs.address_space, linear_idx_type=recv_count_ptrs.linear_idx_type, element_size=recv_count_ptrs.element_size], atomic_counters: TileTensor[DType.int32, address_space=atomic_counters.address_space, linear_idx_type=atomic_counters.linear_idx_type, element_size=atomic_counters.element_size], context: DeviceContext, num_input_tokens: Int = -1)
Execute the Expert Parallelism dispatch completion kernel.
This function launches the dispatch_wait_kernel from ep_comm.mojo to complete the token dispatch phase. It waits for all inter-device communication to complete, then organizes the received tokens into a format suitable for grouped matmul computation.
Arguments:
token_handler: Token handler. Wrapper for the output token tensor.
row_offsets: Cumulative token counts for grouped matmul.
expert_ids: Local expert IDs for grouped matmul.
src_info: Source routing information for combine phase.
atomic_counters: EP kernel synchronization counters.
recv_ptrs: Receive buffer pointers for each local GPU.
recv_count_ptrs: Receive count buffer pointers for each local GPU.
context: Device context pointer.
num_input_tokens: Per-rank input token count for this layer. When >= 0, enables the decode-fast-path grid sizing. The default of -1 keeps the full-sm_count grid for backwards compatibility.
Parameters:
- token_fmt_type (TokenFormat): Token format type.
- n_experts (Int): Total experts across all devices.
- max_token_per_rank (Int): Maximum tokens any device can send.
- n_gpus_per_node (Int): GPUs per physical node.
- n_nodes (Int): Number of physical nodes.
- target (StringSlice[StaticConstantOrigin]): Target.
- input_scales_wrapper (Optional[def[dtype: DType](Int) capturing -> Scalar[dtype]]): The wrapper for the input scales.
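The row_offsets and expert_ids arguments are described above only as "cumulative token counts" and "local expert IDs" for grouped matmul. The following Python sketch shows one plausible reading of that layout; it is not the kernel's code. It assumes that row_offsets starts at 0 and has one more entry than expert_ids, and that experts with no received tokens are skipped. Neither assumption is confirmed by this page. The helper names build_row_offsets and rows_for_slot are hypothetical.

```python
from itertools import accumulate


def build_row_offsets(tokens_per_expert):
    """Build (row_offsets, expert_ids) from per-expert received-token counts.

    Assumption (not confirmed by the reference): experts that received no
    tokens are omitted, and row_offsets is a prefix sum starting at 0.
    """
    expert_ids = [e for e, n in enumerate(tokens_per_expert) if n > 0]
    counts = [tokens_per_expert[e] for e in expert_ids]
    row_offsets = [0] + list(accumulate(counts))
    return row_offsets, expert_ids


def rows_for_slot(row_offsets, expert_ids, slot):
    """Return the expert ID and the half-open row range for one matmul group."""
    return expert_ids[slot], range(row_offsets[slot], row_offsets[slot + 1])


# Hypothetical counts: four local experts, the second one received nothing.
tokens_per_expert = [3, 0, 2, 4]
row_offsets, expert_ids = build_row_offsets(tokens_per_expert)

# Walk the groups the way a grouped matmul might consume them.
groups = [
    rows_for_slot(row_offsets, expert_ids, slot)
    for slot in range(len(expert_ids))
]
```

Under these assumptions, each group is a contiguous block of rows in the output token tensor, so a grouped matmul can process every expert's tokens without a per-token lookup.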