Mojo struct
EPCombineKernel
struct EPCombineKernel[num_threads: Int, n_sms: Int, top_k: Int, n_experts: Int, n_ranks: Int, msg_bytes: Int, max_tokens_per_rank: Int, p2p_world_size: Int, use_shmem: Bool = True, fused_shared_expert: Bool = False, skip_a2a: Bool = False]
Implements combine_async and combine_wait kernel logic for Expert Parallelism.
This struct encapsulates the token combine operations used in MoE (Mixture of Experts) models with expert parallelism. It provides methods for:
- Async Combine:
  - send_tokens_back: Send processed tokens back to their original ranks.
- Wait for Arrivals:
  - wait_for_all_arrivals: Aux SMs wait for all tokens to arrive.
  - reduce_and_copy_to_output: Comm SMs reduce and copy tokens to output.
Parameters
- num_threads (Int): The number of threads per block.
- n_sms (Int): The total number of SMs in the device.
- top_k (Int): The number of selected experts per token.
- n_experts (Int): The total number of experts in the model.
- n_ranks (Int): The number of devices participating in communication.
- msg_bytes (Int): The number of bytes per token message.
- max_tokens_per_rank (Int): The maximum number of tokens per rank.
- p2p_world_size (Int): Size of a high-speed GPU interconnect group.
- use_shmem (Bool): Whether to use the SHMEM API for communication.
- fused_shared_expert (Bool): Whether to filter out the shared expert's outputs.
- skip_a2a (Bool): Whether to skip the A2A communication. If true, we will only receive tokens from the current device.
Implemented traits
AnyType, ImplicitlyDestructible
comptime members
n_local_experts
comptime n_local_experts = (n_experts // n_ranks)
n_reduce_sms
comptime n_reduce_sms = (n_sms - 1)
n_wait_sms
comptime n_wait_sms = 1
n_warps
comptime n_warps = (num_threads // WARP_SIZE)
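To make the arithmetic behind these derived constants concrete, the following is a minimal host-side sketch in Python rather than Mojo. It assumes WARP_SIZE is 32 (as on NVIDIA GPUs), and the helper name `derived_constants` and the sample parameter values are hypothetical, chosen only for illustration.

```python
WARP_SIZE = 32  # assumption: 32-lane warps; the real value is target-dependent


def derived_constants(num_threads, n_sms, n_experts, n_ranks, warp_size=WARP_SIZE):
    """Mirror the four comptime members using integer (floor) division."""
    return {
        "n_local_experts": n_experts // n_ranks,  # experts hosted on each rank
        "n_reduce_sms": n_sms - 1,                # every SM except the waiting one
        "n_wait_sms": 1,                          # a single SM waits for arrivals
        "n_warps": num_threads // warp_size,      # warps per thread block
    }


# Illustrative values only; they are not taken from the documentation.
consts = derived_constants(num_threads=512, n_sms=132, n_experts=256, n_ranks=8)
print(consts)
```

Because both divisions are floor divisions, parameter choices where n_experts is not a multiple of n_ranks, or num_threads is not a multiple of the warp size, would silently drop the remainder in this model.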
Methods
send_buf_layout
static send_buf_layout[out_dtype: DType = _get_index_type[Layout[*?, *?]](AddressSpace.GENERIC)](coord: Coord) -> Scalar[out_dtype]
Returns:
Scalar[out_dtype]
recv_buf_layout
recv_count_layout
copy_shared_expert_outputs
static copy_shared_expert_outputs[input_type: DType, //](input_tokens: TileTensor[input_type, address_space=input_tokens.address_space, linear_idx_type=input_tokens.linear_idx_type, element_size=input_tokens.element_size], output_tokens: TileTensor[input_type, address_space=output_tokens.address_space, linear_idx_type=output_tokens.linear_idx_type, element_size=output_tokens.element_size])
Copies shared expert outputs to the output tensor.
This method copies the shared expert's output tokens from the input tensor to the output tensor when fused_shared_expert is enabled.
Args:
- input_tokens (TileTensor[input_type, address_space=input_tokens.address_space, linear_idx_type=input_tokens.linear_idx_type, element_size=input_tokens.element_size]): The input tokens containing shared expert outputs.
- output_tokens (TileTensor[input_type, address_space=output_tokens.address_space, linear_idx_type=output_tokens.linear_idx_type, element_size=output_tokens.element_size]): The output tensor to copy shared expert outputs to.
send_tokens_back
static send_tokens_back[input_type: DType, //](input_tokens: TileTensor[input_type, address_space=input_tokens.address_space, linear_idx_type=input_tokens.linear_idx_type, element_size=input_tokens.element_size], src_info: TileTensor[DType.int32, address_space=src_info.address_space, linear_idx_type=src_info.linear_idx_type, element_size=src_info.element_size], send_buf_p: UnsafePointer[UInt8, MutExternalOrigin], recv_buf_ptrs: InlineArray[UnsafePointer[UInt8, MutExternalOrigin], p2p_world_size], recv_count_ptrs: InlineArray[UnsafePointer[UInt64, MutExternalOrigin], p2p_world_size], atomic_counter: UnsafePointer[Int32, MutExternalOrigin], rank_completion_counter: UnsafePointer[Int32, MutExternalOrigin], my_rank: Int32)
Send processed tokens back to their original ranks.
Each SM handles one expert-rank pair, sending all tokens for that pair back to the original rank. Uses direct P2P transfers for same-node destinations and SHMEM for cross-node destinations.
Args:
- input_tokens (TileTensor[input_type, address_space=input_tokens.address_space, linear_idx_type=input_tokens.linear_idx_type, element_size=input_tokens.element_size]): The tokens to be sent back.
- src_info (TileTensor[DType.int32, address_space=src_info.address_space, linear_idx_type=src_info.linear_idx_type, element_size=src_info.element_size]): Source token info (original position and top-k ID).
- send_buf_p (UnsafePointer[UInt8, MutExternalOrigin]): Pointer to the send buffer.
- recv_buf_ptrs (InlineArray[UnsafePointer[UInt8, MutExternalOrigin], p2p_world_size]): Array of pointers to receive buffers.
- recv_count_ptrs (InlineArray[UnsafePointer[UInt64, MutExternalOrigin], p2p_world_size]): Array of pointers to receive count buffers.
- atomic_counter (UnsafePointer[Int32, MutExternalOrigin]): Atomic counter for synchronization.
- rank_completion_counter (UnsafePointer[Int32, MutExternalOrigin]): Counter for per-rank completion tracking.
- my_rank (Int32): The rank of the current device.
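The description of send_tokens_back, one SM per expert-rank pair with P2P inside a node and SHMEM across nodes, can be pictured with a small host-side Python model. This is only a sketch under stated assumptions: the index mapping in `pair_for_sm` and the rule that ranks sharing `rank // p2p_world_size` are on the same node are both hypothetical, and the kernel's actual scheduling may differ.

```python
def pair_for_sm(sm_id, n_ranks):
    """Hypothetical mapping from an SM index to a (local_expert, dst_rank) pair."""
    return sm_id // n_ranks, sm_id % n_ranks


def transport_for(dst_rank, my_rank, p2p_world_size):
    """Pick a transport, ASSUMING consecutive ranks form nodes of p2p_world_size."""
    same_node = dst_rank // p2p_world_size == my_rank // p2p_world_size
    return "p2p" if same_node else "shmem"


# Illustrative configuration, not taken from the documentation.
n_experts, n_ranks, p2p_world_size, my_rank = 8, 4, 2, 1
n_local_experts = n_experts // n_ranks

plan = []
for sm_id in range(n_local_experts * n_ranks):
    local_expert, dst_rank = pair_for_sm(sm_id, n_ranks)
    transport = transport_for(dst_rank, my_rank, p2p_world_size)
    plan.append((sm_id, local_expert, dst_rank, transport))

for row in plan:
    print(row)
```

In this model, rank 1 reaches ranks 0 and 1 over P2P and ranks 2 and 3 over SHMEM, for each of its two local experts.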
wait_for_all_arrivals
static wait_for_all_arrivals(recv_count_p: UnsafePointer[UInt64, MutExternalOrigin], atomic_counter: UnsafePointer[Int32, MutExternalOrigin])
Auxiliary SM logic for combine_wait_kernel.
Waits for all tokens to arrive from all ranks, then signals other SMs that they can start copying tokens to the output tensor.
Args:
- recv_count_p (UnsafePointer[UInt64, MutExternalOrigin]): Pointer to the receive count buffer.
- atomic_counter (UnsafePointer[Int32, MutExternalOrigin]): Atomic counter for synchronization.
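The wait-then-signal pattern described for wait_for_all_arrivals can be modeled on the host with Python threads. This is a loose analogy rather than the kernel's implementation: a `threading.Event` stands in for the atomic counter, `None` is an assumed "not yet arrived" marker, and the function name `run_wait_then_reduce` is hypothetical.

```python
import threading
import time


def run_wait_then_reduce(n_ranks, n_reducers):
    recv_counts = [None] * n_ranks  # None: nothing has arrived from this rank yet
    ready = threading.Event()       # stands in for the atomic counter
    seen = []                       # what each reducer observed once released
    lock = threading.Lock()

    def waiter():
        # Poll until every rank has published a count, then release the reducers.
        while any(c is None for c in recv_counts):
            time.sleep(0.001)
        ready.set()

    def reducer(idx):
        ready.wait()  # block until the waiter signals that all data is present
        with lock:
            seen.append((idx, sum(recv_counts)))

    threads = [threading.Thread(target=waiter)]
    threads += [threading.Thread(target=reducer, args=(i,)) for i in range(n_reducers)]
    for t in threads:
        t.start()
    # Simulate arrivals: each rank publishes its token count.
    for rank in range(n_ranks):
        recv_counts[rank] = rank + 1
    for t in threads:
        t.join()
    return sorted(seen)


result = run_wait_then_reduce(n_ranks=4, n_reducers=3)
print(result)
```

Dedicating one waiter keeps the polling traffic to a single participant, which matches the split implied by n_wait_sms = 1 and n_reduce_sms = n_sms - 1.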
reduce_and_copy_to_output
static reduce_and_copy_to_output[output_type: DType, router_weights_wrapper: Optional[def[width: Int](token_idx: Int, topk_id: Int) capturing -> SIMD[DType.float32, width]] = None, elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](output_tokens: TileTensor[output_type, address_space=output_tokens.address_space, linear_idx_type=output_tokens.linear_idx_type, element_size=output_tokens.element_size], recv_buf_p: UnsafePointer[UInt8, MutExternalOrigin], atomic_counter: UnsafePointer[Int32, MutExternalOrigin], my_rank: Int32, topk_ids_p: Optional[UnsafePointer[Int32, ImmutExternalOrigin]] = None)
Communication SM logic for combine_wait_kernel.
Copies received tokens to the output tensor, optionally applying router weights and reduction across top-k experts.
Args:
- output_tokens (TileTensor[output_type, address_space=output_tokens.address_space, linear_idx_type=output_tokens.linear_idx_type, element_size=output_tokens.element_size]): The tensor to store the output tokens.
- recv_buf_p (UnsafePointer[UInt8, MutExternalOrigin]): Pointer to the receive buffer.
- atomic_counter (UnsafePointer[Int32, MutExternalOrigin]): Atomic counter for synchronization.
- my_rank (Int32): The rank of the current device.
- topk_ids_p (Optional[UnsafePointer[Int32, ImmutExternalOrigin]]): Pointer to the top-k IDs for each token, only required if skip_a2a is True.
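The reduction across top-k experts can be illustrated with a small reference computation in plain Python. It is a sketch, not the kernel's code: the function `combine` and its nested-list layout are hypothetical, and it assumes that when no router weights are supplied the top-k contributions are simply summed.

```python
def combine(recv, weights=None):
    """Reduce per-expert outputs into one vector per token.

    recv[t][k] is the vector returned for token t by its k-th selected expert.
    weights[t][k], if given, plays the role of the router weight for that pair.
    """
    out = []
    for t, per_expert in enumerate(recv):
        acc = [0.0] * len(per_expert[0])
        for k, vec in enumerate(per_expert):
            w = 1.0 if weights is None else weights[t][k]
            for i, x in enumerate(vec):
                acc[i] += w * x
        out.append(acc)
    return out


recv = [
    [[1.0, 2.0], [3.0, 4.0]],  # token 0, top_k = 2
    [[0.0, 1.0], [2.0, 0.0]],  # token 1, top_k = 2
]
weights = [[0.5, 0.5], [0.25, 0.75]]

unweighted = combine(recv)
weighted = combine(recv, weights)
print(unweighted)
print(weighted)
```

A reference like this can be useful for checking a kernel's output on small inputs, since the per-token result depends only on that token's top_k received vectors and weights.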