Mojo struct

EPCombineKernel

struct EPCombineKernel[num_threads: Int, n_sms: Int, top_k: Int, n_experts: Int, n_ranks: Int, msg_bytes: Int, max_tokens_per_rank: Int, p2p_world_size: Int, use_shmem: Bool = True, fused_shared_expert: Bool = False, skip_a2a: Bool = False]

Implements combine_async and combine_wait kernel logic for Expert Parallelism.

This struct encapsulates the token combine operations used in MoE (Mixture of Experts) models with expert parallelism. It provides methods for:

  1. Async Combine:
    • send_tokens_back: Sends processed tokens back to their original ranks.
  2. Wait for Arrivals:
    • wait_for_all_arrivals: The auxiliary SM waits for all tokens to arrive.
    • reduce_and_copy_to_output: The communication SMs reduce the received tokens and copy them to the output.

Parameters​

  • ​num_threads (Int): The number of threads per block.
  • ​n_sms (Int): The total number of SMs in the device.
  • ​top_k (Int): The number of selected experts per token.
  • ​n_experts (Int): The total number of experts in the model.
  • ​n_ranks (Int): The number of devices participating in communication.
  • ​msg_bytes (Int): The number of bytes per token message.
  • ​max_tokens_per_rank (Int): The maximum number of tokens per rank.
  • p2p_world_size (Int): The number of devices in a high-speed GPU interconnect group.
  • ​use_shmem (Bool): Whether to use the SHMEM API for communication.
  • ​fused_shared_expert (Bool): Whether to filter out the shared expert's outputs.
  • skip_a2a (Bool): Whether to skip the all-to-all (A2A) communication. If true, tokens are received only from the current device.

Implemented traits​

AnyType, ImplicitlyDestructible

comptime members​

n_local_experts​

comptime n_local_experts = (n_experts // n_ranks)

n_reduce_sms​

comptime n_reduce_sms = (n_sms - 1)

n_wait_sms​

comptime n_wait_sms = 1

n_warps​

comptime n_warps = (num_threads // WARP_SIZE)
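
The derived constants above follow directly from the struct parameters. As a rough illustration, the plain-Python model below (not Mojo, and not part of the library) reproduces the same arithmetic. The helper name `derive_combine_constants` and the default `warp_size=32` are assumptions made for this sketch; the actual `WARP_SIZE` depends on the target GPU.

```python
def derive_combine_constants(num_threads, n_sms, n_experts, n_ranks, warp_size=32):
    """Hypothetical helper mirroring the comptime members of EPCombineKernel."""
    return {
        "n_local_experts": n_experts // n_ranks,  # experts hosted on each rank
        "n_reduce_sms": n_sms - 1,                # SMs that reduce and copy tokens
        "n_wait_sms": 1,                          # single SM that waits for arrivals
        "n_warps": num_threads // warp_size,      # warps per thread block
    }


# Example values chosen only for illustration.
consts = derive_combine_constants(num_threads=256, n_sms=132, n_experts=64, n_ranks=8)
print(consts)
```

With these definitions, n_reduce_sms + n_wait_sms always equals n_sms: one SM is set aside for waiting and every other SM takes part in the reduction.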

Methods​

send_buf_layout​

static send_buf_layout[out_dtype: DType = _get_index_type[Layout[*?, *?]](AddressSpace.GENERIC)](coord: Coord) -> Scalar[out_dtype]

Returns:

Scalar[out_dtype]

recv_buf_layout​

static recv_buf_layout(coord: Coord) -> Int32

Returns:

Int32

recv_count_layout​

static recv_count_layout(coord: Coord) -> Int32

Returns:

Int32

copy_shared_expert_outputs​

static copy_shared_expert_outputs[input_type: DType, //](input_tokens: TileTensor[input_type, address_space=input_tokens.address_space, linear_idx_type=input_tokens.linear_idx_type, element_size=input_tokens.element_size], output_tokens: TileTensor[input_type, address_space=output_tokens.address_space, linear_idx_type=output_tokens.linear_idx_type, element_size=output_tokens.element_size])

Copies shared expert outputs to the output tensor.

This method copies the shared expert's output tokens from the input tensor to the output tensor when fused_shared_expert is enabled.

send_tokens_back​

static send_tokens_back[input_type: DType, //](input_tokens: TileTensor[input_type, address_space=input_tokens.address_space, linear_idx_type=input_tokens.linear_idx_type, element_size=input_tokens.element_size], src_info: TileTensor[DType.int32, address_space=src_info.address_space, linear_idx_type=src_info.linear_idx_type, element_size=src_info.element_size], send_buf_p: UnsafePointer[UInt8, MutExternalOrigin], recv_buf_ptrs: InlineArray[UnsafePointer[UInt8, MutExternalOrigin], p2p_world_size], recv_count_ptrs: InlineArray[UnsafePointer[UInt64, MutExternalOrigin], p2p_world_size], atomic_counter: UnsafePointer[Int32, MutExternalOrigin], rank_completion_counter: UnsafePointer[Int32, MutExternalOrigin], my_rank: Int32)

Send processed tokens back to their original ranks.

Each SM handles one expert-rank pair, sending all tokens for that pair back to the original rank. Uses direct P2P transfers for same-node destinations and SHMEM for cross-node destinations.
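
To make the work split described above more concrete, here is a minimal plain-Python model (not Mojo kernel code). It rests on several assumptions that this page does not confirm: expert-rank pairs are assigned to SMs round-robin, a pair index decomposes as `local_expert * n_ranks + dst_rank`, and two ranks share a P2P group when `rank // p2p_world_size` matches. The function `plan_sends` is hypothetical.

```python
def plan_sends(n_sms, n_local_experts, n_ranks, p2p_world_size, my_rank):
    """Hypothetical model: assign (expert, dst_rank) pairs to SMs and pick a transport."""
    my_group = my_rank // p2p_world_size  # assumed grouping of ranks into P2P domains
    plan = []
    for pair in range(n_local_experts * n_ranks):
        sm = pair % n_sms                 # assumed round-robin SM assignment
        local_expert = pair // n_ranks    # assumed pair decomposition
        dst_rank = pair % n_ranks
        same_group = (dst_rank // p2p_world_size) == my_group
        transport = "p2p" if same_group else "shmem"
        plan.append((sm, local_expert, dst_rank, transport))
    return plan


# Toy configuration: 4 SMs, 2 local experts, 4 ranks, P2P groups of 2, running on rank 1.
plan = plan_sends(n_sms=4, n_local_experts=2, n_ranks=4, p2p_world_size=2, my_rank=1)
for entry in plan:
    print(entry)
```

In this toy setup, ranks 0 and 1 fall in the same group as the sender and are reached with the direct P2P path, while ranks 2 and 3 are treated as cross-node and go through SHMEM.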

wait_for_all_arrivals​

static wait_for_all_arrivals(recv_count_p: UnsafePointer[UInt64, MutExternalOrigin], atomic_counter: UnsafePointer[Int32, MutExternalOrigin])

Auxiliary SM logic for combine_wait_kernel.

Waits for all tokens to arrive from all ranks, then signals other SMs that they can start copying tokens to the output tensor.
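
The wait-then-signal pattern described above can be modeled on the host with ordinary threads. The sketch below is plain Python, not the kernel itself: the `NOT_ARRIVED` sentinel, the use of `threading.Event` in place of `atomic_counter`, and the one-slot-per-sender layout of `recv_counts` are all assumptions for illustration.

```python
import threading
import time

NOT_ARRIVED = None  # assumed sentinel; the real encoding in recv_count_p is not specified here


def wait_for_all_arrivals(recv_counts, ready):
    """Hypothetical model of the auxiliary SM: poll until every slot is filled, then signal."""
    while any(c is NOT_ARRIVED for c in recv_counts):
        time.sleep(0.001)  # stand-in for a GPU spin-wait
    ready.set()            # stand-in for updating atomic_counter


recv_counts = [NOT_ARRIVED] * 4  # one slot per sender in this toy setup
ready = threading.Event()
waiter = threading.Thread(target=wait_for_all_arrivals, args=(recv_counts, ready))
waiter.start()

for i in range(len(recv_counts)):  # simulate remote ranks reporting their token counts
    recv_counts[i] = 3

waiter.join(timeout=5)
print(ready.is_set())
```

The point of the design is that only one SM spends its time polling; the remaining SMs wait on a single signal and start copying once it is raised.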

reduce_and_copy_to_output​

static reduce_and_copy_to_output[output_type: DType, router_weights_wrapper: Optional[def[width: Int](token_idx: Int, topk_id: Int) capturing -> SIMD[DType.float32, width]] = None, elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](output_tokens: TileTensor[output_type, address_space=output_tokens.address_space, linear_idx_type=output_tokens.linear_idx_type, element_size=output_tokens.element_size], recv_buf_p: UnsafePointer[UInt8, MutExternalOrigin], atomic_counter: UnsafePointer[Int32, MutExternalOrigin], my_rank: Int32, topk_ids_p: Optional[UnsafePointer[Int32, ImmutExternalOrigin]] = None)

Communication SM logic for combine_wait_kernel.

Copies received tokens to the output tensor, optionally applying router weights and reduction across top-k experts.
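
As an illustration of the reduction across top-k experts, the plain-Python model below computes, for each token, the weighted sum of the vectors returned by its selected experts. It is a sketch under assumptions: the nested-list layout `recv[token][k]`, the function name `reduce_and_copy`, and the choice to use a plain sum when no router weights are given are not taken from this page.

```python
def reduce_and_copy(recv, weights=None):
    """Hypothetical model: recv[token][k] is the hidden vector from the k-th selected expert."""
    output = []
    for t, per_expert in enumerate(recv):
        hidden = len(per_expert[0])
        acc = [0.0] * hidden
        for k, vec in enumerate(per_expert):
            # Assumed behavior: unweighted sum when no router weights are supplied.
            w = 1.0 if weights is None else weights[t][k]
            for i in range(hidden):
                acc[i] += w * vec[i]
        output.append(acc)
    return output


recv = [
    [[1.0, 2.0], [3.0, 4.0]],  # token 0, top_k = 2, hidden size 2
    [[0.5, 0.5], [1.5, 2.5]],  # token 1
]
weights = [[0.25, 0.75], [0.5, 0.5]]  # illustrative router weights per token and expert

out = reduce_and_copy(recv, weights)
print(out)  # [[2.5, 3.5], [1.0, 1.5]]
```

In the real kernel the weights would come from `router_weights_wrapper` and the result could be passed through `elementwise_lambda_fn`; neither callback is modeled here.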
