Mojo function
combine_kernel
combine_kernel[input_type: DType, num_threads: Int, input_tokens_layout: TensorLayout, src_info_layout: TensorLayout, output_tokens_layout: TensorLayout, n_sms: Int, top_k: Int, n_experts: Int, n_ranks: Int, msg_bytes: Int, max_tokens_per_rank: Int, p2p_world_size: Int, router_weights_wrapper: Optional[def[width: Int](token_idx: Int, topk_id: Int) capturing -> SIMD[DType.float32, width]] = None, fused_shared_expert: Bool = False, epilogue_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, skip_a2a: Bool = False, use_shmem: Bool = True, allreduce_world_size: Int = 1](input_tokens: TileTensor[input_type, input_tokens_layout, ImmutExternalOrigin], src_info: TileTensor[DType.int32, src_info_layout, ImmutExternalOrigin], output_tokens: TileTensor[input_type, output_tokens_layout, MutExternalOrigin], send_buf_p: UnsafePointer[UInt8, MutExternalOrigin], recv_buf_ptrs: InlineArray[UnsafePointer[UInt8, MutExternalOrigin], p2p_world_size], recv_count_ptrs: InlineArray[UnsafePointer[UInt64, MutExternalOrigin], p2p_world_size], ep_counters: EPLocalSyncCounters[n_experts], topk_ids_p: Optional[UnsafePointer[Int32, ImmutExternalOrigin]], my_rank: Int32)
Fused combine kernel that performs the work of combine_async and combine_wait in a single kernel launch.
This kernel sends processed tokens back to their original ranks, then waits for all tokens to arrive and computes the weighted sum of routed expert outputs for each token.
In fused_shared_expert mode, the shared expert outputs are added to the reduced routed expert outputs using an elementwise lambda. This mode requires router_weights_wrapper to be provided, because the output must already be reduced before the shared expert outputs can be added.
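The reduction described above can be sketched as a host-side reference model. This is plain Python, not the kernel's implementation: the function name `combine_reference`, the nested-list data layout, and the behavior of returning unreduced per-expert outputs when no weights are given are all assumptions made for illustration. It only models the arithmetic (weighted sum over the top_k routed expert outputs, plus the optional shared expert output), not the cross-rank communication.

```python
def combine_reference(expert_outputs, router_weights=None, shared_output=None):
    """Hypothetical reference model of the combine arithmetic.

    expert_outputs[t][k] is the hidden vector produced for token t by its
    k-th selected expert. router_weights[t][k] is the matching router weight.
    shared_output[t] is the shared expert's hidden vector for token t.
    """
    if shared_output is not None and router_weights is None:
        # Mirrors the documented constraint: the fused shared expert mode
        # needs router weights so that the output is reduced first.
        raise ValueError("shared_output requires router_weights")

    combined = []
    for t, per_expert in enumerate(expert_outputs):
        if router_weights is None:
            # Assumption: without weights, outputs stay unreduced.
            combined.append([list(vec) for vec in per_expert])
            continue

        # Weighted sum over the top_k expert outputs for this token.
        acc = [0.0] * len(per_expert[0])
        for k, vec in enumerate(per_expert):
            w = router_weights[t][k]
            for i, x in enumerate(vec):
                acc[i] += w * x

        # Optional elementwise add of the shared expert output.
        if shared_output is not None:
            acc = [a + s for a, s in zip(acc, shared_output[t])]
        combined.append(acc)
    return combined
```

For one token with top_k = 2, outputs `[1.0, 2.0]` and `[3.0, 4.0]`, and weights `0.5` and `0.25`, the reduced vector is `[1.25, 2.0]`. Adding a shared expert output of `[1.0, 1.0]` gives `[2.25, 3.0]`.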
Parameters:
- input_type (DType): The type of the input/output tokens.
- num_threads (Int): The number of threads in the block.
- input_tokens_layout (TensorLayout): The layout of the input tokens.
- src_info_layout (TensorLayout): The layout of the source token info.
- output_tokens_layout (TensorLayout): The layout of the output tokens.
- n_sms (Int): The total number of SMs in the device.
- top_k (Int): The number of selected experts per token.
- n_experts (Int): The total number of experts in the model.
- n_ranks (Int): The number of all devices participating in the communication.
- msg_bytes (Int): The number of bytes per token message.
- max_tokens_per_rank (Int): The maximum number of tokens per rank.
- p2p_world_size (Int): The size of a high-speed GPU interconnect group.
- router_weights_wrapper (Optional[def[width: Int](token_idx: Int, topk_id: Int) capturing -> SIMD[DType.float32, width]]): The wrapper for the router weights. If provided, all routed experts' outputs for a token are weighted and summed. Required when fused_shared_expert is True.
- fused_shared_expert (Bool): Whether to add the shared expert's output to the combined output. Requires router_weights_wrapper to be provided.
- epilogue_fn (Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None]): Optional elementwise epilogue function applied after computing the combined output. If provided, this function is called with coordinates and values instead of storing directly to the output.
- skip_a2a (Bool): Whether to skip the A2A communication. If True, tokens are only sent within the current device.
- use_shmem (Bool): Whether to use the SHMEM API for the communication.
- allreduce_world_size (Int): The world size of the allreduce operation. Only needed when skip_a2a is True. Used to calculate the workload distribution for the shared expert (if there is one).
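The epilogue_fn behavior (the callback receives coordinates and values instead of the kernel storing to the output) can be illustrated with a small host-side sketch. This is plain Python and not the kernel's code: `store_row`, the `(token_idx, column)` meaning of the two coordinates, and the chunking by `width` are assumptions made for illustration.

```python
def store_row(output, token_idx, values, width, epilogue_fn=None):
    """Hypothetical model of the final store step for one token.

    Walks the combined vector in chunks of `width` elements. With an
    epilogue, each chunk is handed to the callback together with its
    coordinates; without one, the chunk is written to `output`.
    """
    for col in range(0, len(values), width):
        chunk = values[col:col + width]
        if epilogue_fn is not None:
            # Assumption: coordinates are (token index, column offset).
            epilogue_fn((token_idx, col), chunk)
        else:
            output[token_idx][col:col + width] = chunk


# Direct store: the combined values land in the output buffer.
direct = [[0.0, 0.0, 0.0, 0.0]]
store_row(direct, 0, [1.0, 2.0, 3.0, 4.0], width=2)

# Epilogue: the callback sees every chunk and the buffer is left untouched.
seen = []
untouched = [[0.0, 0.0, 0.0, 0.0]]
store_row(untouched, 0, [1.0, 2.0, 3.0, 4.0], width=2,
          epilogue_fn=lambda coords, vals: seen.append((coords, vals)))
```

In this model the epilogue fully replaces the store, so a caller that still wants the values in memory has to write them from inside the callback.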