
Mojo function

ep_fused_dispatch_kernel_api

ep_fused_dispatch_kernel_api[token_fmt_type: TokenFormat, dispatch_dtype: DType, //, n_experts: Int, max_token_per_rank: Int, n_gpus_per_node: Int, n_nodes: Int, fused_shared_expert: Bool, target: StringSlice[StaticConstantOrigin], input_scales_wrapper: Optional[def[dtype: DType](Int) capturing -> Scalar[dtype]] = None, skip_a2a: Bool = False, use_shmem: Bool = (n_nodes > 1), allreduce_world_size: Int = 1](token_handler: token_fmt_type, row_offsets: TileTensor[DType.uint32, address_space=row_offsets.address_space, linear_idx_type=row_offsets.linear_idx_type, element_size=row_offsets.element_size], expert_ids: TileTensor[DType.int32, address_space=expert_ids.address_space, linear_idx_type=expert_ids.linear_idx_type, element_size=expert_ids.element_size], src_info: TileTensor[DType.int32, address_space=src_info.address_space, linear_idx_type=src_info.linear_idx_type, element_size=src_info.element_size], atomic_counters: TileTensor[DType.int32, address_space=atomic_counters.address_space, linear_idx_type=atomic_counters.linear_idx_type, element_size=atomic_counters.element_size], input_tokens: TileTensor[dispatch_dtype, address_space=input_tokens.address_space, linear_idx_type=input_tokens.linear_idx_type, element_size=input_tokens.element_size], topk_ids: TileTensor[DType.int32, address_space=topk_ids.address_space, linear_idx_type=topk_ids.linear_idx_type, element_size=topk_ids.element_size], send_ptrs: TileTensor[DType.uint64, address_space=send_ptrs.address_space, linear_idx_type=send_ptrs.linear_idx_type, element_size=send_ptrs.element_size], recv_ptrs: TileTensor[DType.uint64, address_space=recv_ptrs.address_space, linear_idx_type=recv_ptrs.linear_idx_type, element_size=recv_ptrs.element_size], recv_count_ptrs: TileTensor[DType.uint64, address_space=recv_count_ptrs.address_space, linear_idx_type=recv_count_ptrs.linear_idx_type, element_size=recv_count_ptrs.element_size], context: DeviceContext)

Execute the fused Expert Parallelism dispatch kernel.

This function launches the fused dispatch_kernel from ep_comm.mojo that combines both dispatch_async and dispatch_wait functionality in a single kernel launch. It distributes input tokens to expert devices based on top-k routing decisions, then waits for all tokens to arrive and aggregates them for grouped matmul computation.
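The routing step described above can be illustrated with a small host-side reference. This is only a sketch under stated assumptions, not the kernel's implementation: it runs on the CPU in a single process, treats all experts as local, and the helper name `reference_dispatch` and the exact layout of its outputs (`row_offsets` as a prefix sum over active experts, `src_info` as `(token, top-k slot)` pairs) are hypothetical choices made for illustration.

```python
from typing import List, Tuple


def reference_dispatch(
    topk_ids: List[List[int]], n_experts: int
) -> Tuple[List[int], List[int], List[Tuple[int, int]]]:
    """Group (token, slot) pairs by expert, as a grouped matmul would need.

    Hypothetical CPU reference; the real kernel's buffer layouts may differ.
    """
    # One bucket per expert, holding the (token index, top-k slot) pairs
    # that the router assigned to that expert.
    buckets: List[List[Tuple[int, int]]] = [[] for _ in range(n_experts)]
    for token_idx, experts in enumerate(topk_ids):
        for slot, expert in enumerate(experts):
            buckets[expert].append((token_idx, slot))

    row_offsets = [0]  # prefix sum of rows per active expert
    expert_ids = []    # experts that received at least one token
    src_info = []      # where each aggregated row came from
    for expert, bucket in enumerate(buckets):
        if not bucket:
            continue  # skip experts that received no tokens
        expert_ids.append(expert)
        src_info.extend(bucket)
        row_offsets.append(len(src_info))
    return row_offsets, expert_ids, src_info


# Three tokens, top-2 routing over four experts (expert 1 gets nothing).
topk_ids = [[0, 2], [2, 3], [0, 3]]
row_offsets, expert_ids, src_info = reference_dispatch(topk_ids, n_experts=4)
```

In this sketch, rows `row_offsets[i]` through `row_offsets[i + 1] - 1` of the aggregated buffer belong to `expert_ids[i]`, and `src_info` records the origin of each row so that a later combine phase could route results back.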

Arguments:

  • token_handler: Token handler. Wrapper for the output token tensor.
  • row_offsets: Row offsets for grouped matmul.
  • expert_ids: Expert IDs for grouped matmul.
  • src_info: Source routing information for the combine phase.
  • atomic_counters: EP kernel synchronization counters.
  • input_tokens: Tokens to dispatch to experts.
  • topk_ids: Expert assignments from the router.
  • send_ptrs: Send buffer pointers for each local GPU.
  • recv_ptrs: Receive buffer pointers for each local GPU.
  • recv_count_ptrs: Receive count buffer pointers for each local GPU.
  • context: Device context (DeviceContext).

Parameters:

  • ​token_fmt_type (TokenFormat): Token format type.
  • ​dispatch_dtype (DType): Data type of the dispatched tokens.
  • ​n_experts (Int): Total experts across all devices.
  • ​max_token_per_rank (Int): Maximum tokens any device can send.
  • ​n_gpus_per_node (Int): GPUs per physical node.
  • ​n_nodes (Int): Number of physical nodes.
  • ​fused_shared_expert (Bool): Whether to pack shared expert inputs with routed experts' inputs.
  • ​target (StringSlice[StaticConstantOrigin]): Target.
  • ​input_scales_wrapper (Optional[def[dtype: DType](Int) capturing -> Scalar[dtype]]): The wrapper for the input scales.
  • ​skip_a2a (Bool): Whether to skip the A2A communication. If true, tokens are only sent within the current device. Defaults to False.
  • ​use_shmem (Bool): Whether to enable SHMEM communication. Defaults to n_nodes > 1.
  • ​allreduce_world_size (Int): The world size of the allreduce operation. Only needed when skip_a2a is true. Used to calculate the workload distribution for the shared expert (if there is one). Defaults to 1.
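
How the defaulted parameters resolve can be sketched on the host side. The following is a hypothetical Python mirror of the compile-time parameters, not part of the Mojo API: the class name `DispatchConfig`, the `resolved_use_shmem` property, and the derived `n_ranks` value (assumed here to be `n_gpus_per_node * n_nodes`) are illustrative assumptions. Only the defaults themselves (`skip_a2a = False`, `use_shmem = (n_nodes > 1)`, `allreduce_world_size = 1`) come from the signature above.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class DispatchConfig:
    """Hypothetical host-side mirror of the kernel's compile-time parameters."""

    n_experts: int
    max_token_per_rank: int
    n_gpus_per_node: int
    n_nodes: int
    fused_shared_expert: bool
    skip_a2a: bool = False
    use_shmem: Optional[bool] = None  # None means "use the signature default"
    allreduce_world_size: int = 1

    @property
    def resolved_use_shmem(self) -> bool:
        # Signature default: SHMEM is enabled only for multi-node runs.
        if self.use_shmem is None:
            return self.n_nodes > 1
        return self.use_shmem

    @property
    def n_ranks(self) -> int:
        # Assumption: one rank per GPU across all nodes.
        return self.n_gpus_per_node * self.n_nodes


single_node = DispatchConfig(
    n_experts=64, max_token_per_rank=128,
    n_gpus_per_node=8, n_nodes=1, fused_shared_expert=False,
)
multi_node = DispatchConfig(
    n_experts=64, max_token_per_rank=128,
    n_gpus_per_node=8, n_nodes=2, fused_shared_expert=False,
)
```

With these values, the single-node configuration leaves SHMEM off by default, while the two-node configuration turns it on unless `use_shmem` is set explicitly.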