
Mojo function

combine_kernel

combine_kernel[input_type: DType, num_threads: Int, input_tokens_layout: TensorLayout, src_info_layout: TensorLayout, output_tokens_layout: TensorLayout, n_sms: Int, top_k: Int, n_experts: Int, n_ranks: Int, msg_bytes: Int, max_tokens_per_rank: Int, p2p_world_size: Int, router_weights_wrapper: Optional[def[width: Int](token_idx: Int, topk_id: Int) capturing -> SIMD[DType.float32, width]] = None, fused_shared_expert: Bool = False, epilogue_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, skip_a2a: Bool = False, use_shmem: Bool = True, allreduce_world_size: Int = 1](input_tokens: TileTensor[input_type, input_tokens_layout, ImmutExternalOrigin], src_info: TileTensor[DType.int32, src_info_layout, ImmutExternalOrigin], output_tokens: TileTensor[input_type, output_tokens_layout, MutExternalOrigin], send_buf_p: UnsafePointer[UInt8, MutExternalOrigin], recv_buf_ptrs: InlineArray[UnsafePointer[UInt8, MutExternalOrigin], p2p_world_size], recv_count_ptrs: InlineArray[UnsafePointer[UInt64, MutExternalOrigin], p2p_world_size], ep_counters: EPLocalSyncCounters[n_experts], topk_ids_p: Optional[UnsafePointer[Int32, ImmutExternalOrigin]], my_rank: Int32)

Fused combine kernel that performs the work of combine_async and combine_wait in a single kernel launch.

This kernel sends processed tokens back to their original ranks, waits for all tokens to arrive, and then, if router_weights_wrapper is provided, computes the weighted sum of the routed expert outputs for each token.
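
The reduction step can be modeled on the host as a plain weighted sum over each token's top_k expert outputs. The following is a minimal Python sketch of the arithmetic only, not the kernel: it omits all communication and SIMD vectorization. It assumes that expert_outputs[t][k] holds the vector returned by token t's k-th selected expert and that the second argument of the weight callback indexes that top_k slot; weighted_combine and router_weight are hypothetical names.

```python
from typing import Callable, List


def weighted_combine(
    expert_outputs: List[List[List[float]]],
    router_weight: Callable[[int, int], float],
) -> List[List[float]]:
    """Reduce each token's top_k expert outputs to one hidden vector.

    expert_outputs[t][k] is the vector returned for token t by its k-th
    selected expert (assumed indexing). router_weight(t, k) stands in for
    router_weights_wrapper, returning one scalar instead of a SIMD vector.
    """
    combined: List[List[float]] = []
    for t, per_expert in enumerate(expert_outputs):
        hidden = len(per_expert[0])
        acc = [0.0] * hidden
        for k, vec in enumerate(per_expert):
            w = router_weight(t, k)  # router weight for (token t, slot k)
            for h in range(hidden):
                acc[h] += w * vec[h]
        combined.append(acc)
    return combined


# One token, top_k = 2, hidden size = 2.
weights = [[0.25, 0.75]]
print(weighted_combine([[[1.0, 2.0], [3.0, 4.0]]], lambda t, k: weights[t][k]))
# prints [[2.5, 3.5]]
```

Without a weight callback there is nothing to reduce with, which is why the kernel only produces one vector per token when router_weights_wrapper is supplied.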

In fused_shared_expert mode, the shared expert's outputs are added to the reduced routed-expert outputs using an elementwise lambda. This mode requires router_weights_wrapper, because the routed outputs must first be reduced to a single vector per token.
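
The shared-expert fusion can be sketched the same way: reduce the routed outputs, then add the shared expert's vector elementwise. This is a hypothetical host-side Python model, not the kernel's code; fused_shared_combine, shared_outputs, and router_weight are illustrative names, and the [num_tokens][top_k][hidden] layout is an assumption. It enforces the documented precondition by rejecting a missing weight callback.

```python
from typing import Callable, List, Optional

Matrix = List[List[float]]


def fused_shared_combine(
    expert_outputs: List[Matrix],  # assumed layout: [num_tokens][top_k][hidden]
    shared_outputs: Matrix,        # assumed layout: [num_tokens][hidden]
    router_weight: Optional[Callable[[int, int], float]],
) -> Matrix:
    # Mirrors the documented precondition: routed outputs must be reduced
    # to one vector per token before the shared expert can be added.
    if router_weight is None:
        raise ValueError("fused_shared_expert requires router_weights_wrapper")
    result: Matrix = []
    for t, per_expert in enumerate(expert_outputs):
        # Weighted sum of this token's routed-expert outputs.
        reduced = [
            sum(router_weight(t, k) * per_expert[k][h] for k in range(len(per_expert)))
            for h in range(len(per_expert[0]))
        ]
        # Elementwise add of the shared expert's output for the same token.
        result.append([r + s for r, s in zip(reduced, shared_outputs[t])])
    return result


print(fused_shared_combine([[[1.0, 2.0], [3.0, 4.0]]], [[10.0, 20.0]], lambda t, k: [0.25, 0.75][k]))
# prints [[12.5, 23.5]]
```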

Parameters:

  • input_type (DType): The type of the input/output tokens.
  • num_threads (Int): The number of threads in the block.
  • input_tokens_layout (TensorLayout): The layout of the input tokens.
  • src_info_layout (TensorLayout): The layout of the source token info.
  • output_tokens_layout (TensorLayout): The layout of the output tokens.
  • n_sms (Int): The total number of SMs in the device.
  • top_k (Int): The number of selected experts per token.
  • n_experts (Int): The total number of experts in the model.
  • n_ranks (Int): The total number of devices participating in the communication.
  • msg_bytes (Int): The number of bytes per token message.
  • max_tokens_per_rank (Int): The maximum number of tokens per rank.
  • p2p_world_size (Int): The size of the high-speed GPU interconnect (peer-to-peer) group.
  • router_weights_wrapper (Optional[def[width: Int](token_idx: Int, topk_id: Int) capturing -> SIMD[DType.float32, width]]): The wrapper for the router weights. If provided, the outputs of all routed experts for a token are weighted and summed. Required when fused_shared_expert is True.
  • fused_shared_expert (Bool): Whether to add the shared expert's output to the combined output. Requires router_weights_wrapper to be provided.
  • epilogue_fn (Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None]): Optional elementwise epilogue function applied after the combined output is computed. If provided, it is called with the output coordinates and values instead of the values being stored directly to the output.
  • skip_a2a (Bool): Whether to skip the all-to-all (A2A) communication. If True, tokens are only sent within the current device.
  • use_shmem (Bool): Whether to use the SHMEM API for the communication.
  • allreduce_world_size (Int): The world size of the allreduce operation. Only used when skip_a2a is True, to compute the workload distribution for the shared expert (if there is one).
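
The epilogue_fn contract (call a function with coordinates and values instead of storing them) can be illustrated with a small host-side model. This is a hypothetical Python sketch, not the kernel's store path: write_combined is an illustrative name, and it assumes the IndexList[2] coordinates are (token index, hidden offset) and that values are emitted in chunks of width elements.

```python
from typing import Callable, List, Optional, Tuple

# Stand-in for epilogue_fn: receives (token, hidden_offset) and a chunk of values.
Epilogue = Callable[[Tuple[int, int], List[float]], None]


def write_combined(
    combined: List[List[float]],  # reduced output, [num_tokens][hidden]
    output: List[List[float]],    # destination buffer of the same shape
    width: int,                   # vector width, analogous to the SIMD width
    epilogue_fn: Optional[Epilogue] = None,
) -> None:
    for t, row in enumerate(combined):
        for h0 in range(0, len(row), width):
            chunk = row[h0:h0 + width]
            if epilogue_fn is not None:
                # Hand coordinates and values to the epilogue; do not store.
                epilogue_fn((t, h0), chunk)
            else:
                # Default path: store the values directly into the output.
                output[t][h0:h0 + width] = chunk


calls = []
out = [[0.0] * 4]
write_combined([[1.0, 2.0, 3.0, 4.0]], out, 2, lambda c, v: calls.append((c, v)))
print(calls)  # prints [((0, 0), [1.0, 2.0]), ((0, 2), [3.0, 4.0])]
```

When an epilogue is supplied, the output buffer is left untouched, so any store (or fusion with a following operation) becomes the epilogue's responsibility.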