
Mojo function

dispatch_kernel

dispatch_kernel[input_type: DType, num_threads: Int, input_tokens_layout: TensorLayout, topk_ids_layout: TensorLayout, row_offsets_layout: TensorLayout, expert_ids_layout: TensorLayout, src_info_layout: TensorLayout, n_sms: Int, n_experts: Int, n_ranks: Int, max_tokens_per_rank: Int, p2p_world_size: Int, token_fmt_type: TokenFormat, fused_shared_expert: Bool = False, input_scales_wrapper: Optional[def[dtype: DType](Int) capturing -> Scalar[dtype]] = None, skip_a2a: Bool = False, use_shmem: Bool = True, allreduce_world_size: Int = 1](input_tokens: TileTensor[input_type, input_tokens_layout, ImmutExternalOrigin], topk_ids: TileTensor[DType.int32, topk_ids_layout, ImmutExternalOrigin], format_handler: token_fmt_type, row_offsets: TileTensor[DType.uint32, row_offsets_layout, MutExternalOrigin], expert_ids: TileTensor[DType.int32, expert_ids_layout, MutExternalOrigin], src_info: TileTensor[DType.int32, src_info_layout, MutExternalOrigin], send_buf_p: UnsafePointer[UInt8, MutExternalOrigin], recv_buf_ptrs: InlineArray[UnsafePointer[UInt8, MutExternalOrigin], p2p_world_size], recv_count_ptrs: InlineArray[UnsafePointer[UInt64, MutExternalOrigin], p2p_world_size], ep_counters: EPLocalSyncCounters[n_experts], my_rank: Int32)

Fused dispatch kernel that combines dispatch_async and dispatch_wait functionality in a single kernel launch.

This kernel dispatches tokens to experts on remote ranks based on the top-k expert IDs, then waits for all tokens to arrive and aggregates them for grouped matmul computation.
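The routing described above can be modeled on the host to help reason about the kernel's outputs. The following is a minimal plain-Python sketch, not the kernel's implementation, and it rests on several assumptions that are not stated on this page:

- Experts are split evenly and contiguously across ranks, so expert `e` lives on rank `e // (n_experts // n_ranks)`.
- `row_offsets` is treated as a prefix-sum array that delimits one contiguous row group per local expert.
- `expert_ids` holds the local expert ID of each group.
- `src_info` records a `(source_rank, source_token)` pair for each received row.

The function name `reference_dispatch` is hypothetical.

```python
from collections import defaultdict


def reference_dispatch(topk_ids_per_rank, n_experts, n_ranks):
    """Hypothetical host-side model of MoE token dispatch.

    topk_ids_per_rank[r][t] lists the top-k expert IDs chosen for token t
    on rank r. Assumes experts are partitioned evenly and contiguously
    across ranks. Returns, per receiving rank, a tuple
    (row_offsets, expert_ids, src_info) laid out for a grouped matmul.
    """
    assert n_experts % n_ranks == 0
    experts_per_rank = n_experts // n_ranks

    # buckets[dst_rank][local_expert] -> list of (src_rank, src_token)
    buckets = [defaultdict(list) for _ in range(n_ranks)]
    for src_rank, tokens in enumerate(topk_ids_per_rank):
        for tok, experts in enumerate(tokens):
            for e in experts:
                dst_rank = e // experts_per_rank
                local_e = e % experts_per_rank
                buckets[dst_rank][local_e].append((src_rank, tok))

    results = []
    for dst_rank in range(n_ranks):
        row_offsets = [0]  # start row of each group, plus a final end offset
        expert_ids = []    # local expert ID of each group (empty experts skipped)
        src_info = []      # where each received row came from
        for local_e in sorted(buckets[dst_rank]):
            rows = buckets[dst_rank][local_e]
            expert_ids.append(local_e)
            src_info.extend(rows)
            row_offsets.append(row_offsets[-1] + len(rows))
        results.append((row_offsets, expert_ids, src_info))
    return results


# Two ranks, four experts, top-2 routing.
topk = [
    [[0, 2], [1, 3]],  # rank 0: token 0 -> experts 0, 2; token 1 -> experts 1, 3
    [[2, 3], [0, 1]],  # rank 1: token 0 -> experts 2, 3; token 1 -> experts 0, 1
]
print(reference_dispatch(topk, n_experts=4, n_ranks=2)[0])
# ([0, 2, 4], [0, 1], [(0, 0), (1, 1), (0, 1), (1, 1)])
```

In the real kernel, rows arrive from remote ranks concurrently. The order of rows inside each expert group may therefore differ from this sequential model, even when the group sizes match.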

Parameters:

  • ​input_type (DType): The data type of the input tokens.
  • ​num_threads (Int): The number of threads in the block.
  • ​input_tokens_layout (TensorLayout): The layout of the input tokens.
  • ​topk_ids_layout (TensorLayout): The layout of the top-k expert IDs.
  • ​row_offsets_layout (TensorLayout): The layout of the row offsets.
  • ​expert_ids_layout (TensorLayout): The layout of the expert IDs.
  • ​src_info_layout (TensorLayout): The layout of the source token info.
  • ​n_sms (Int): The total number of streaming multiprocessors (SMs) on the device.
  • ​n_experts (Int): The total number of experts in the model.
  • ​n_ranks (Int): The total number of devices participating in the communication.
  • ​max_tokens_per_rank (Int): The maximum number of tokens per rank.
  • ​p2p_world_size (Int): The size of a high-speed GPU interconnect group.
  • ​token_fmt_type (TokenFormat): A type conforming to the TokenFormat trait that defines the token encoding scheme.
  • ​fused_shared_expert (Bool): Whether to pack the shared expert inputs with the routed experts' inputs. When enabled, input_tokens is used as the shared expert inputs.
  • ​input_scales_wrapper (Optional[def[dtype: DType](Int) capturing -> Scalar[dtype]]): An optional wrapper function that takes an Int index and returns the corresponding input scale as a Scalar[dtype].
  • ​skip_a2a (Bool): Whether to skip the all-to-all (A2A) communication. If True, tokens are only sent within the current device.
  • ​use_shmem (Bool): Whether to use the SHMEM API for communication.
  • ​allreduce_world_size (Int): The world size of the allreduce operation. Only used when skip_a2a is True, to calculate the workload distribution for the shared expert (if there is one).

Args: