
Mojo struct

EPDispatchKernel

struct EPDispatchKernel[num_threads: Int, n_sms: Int, n_experts: Int, n_ranks: Int, max_tokens_per_rank: Int, p2p_world_size: Int, token_fmt_type: TokenFormat, use_shmem: Bool = True, fused_shared_expert: Bool = False, skip_a2a: Bool = False]

Implements dispatch_async and dispatch_wait kernel logic for Expert Parallelism.

This struct encapsulates the token dispatch operations used in MoE (Mixture of Experts) models with expert parallelism. It provides methods for:

  1. Async Dispatch:

    • monitor_and_signal_completion: Aux SMs count tokens per expert and signal completion when all tokens for an expert have been sent.
    • copy_and_send_tokens: Comm SMs copy tokens to send buffer and transfer them to destination ranks.
  2. Wait for Arrivals:

    • wait_for_arrivals_and_compute_offsets: Aux SMs wait for token arrivals and compute output tensor offsets. They also signal the other SMs to copy the tokens to the output tensor once the data is ready.
    • copy_received_tokens_to_output: Comm SMs copy received tokens to the output tensor.
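The Python sketch below illustrates how the SMs are split between these roles. It evaluates the comptime formulas listed on this page (n_warps, n_signal_sms, n_dispatch_async_comm_sms, n_offset_sms, n_dispatch_wait_comm_sms) on the host. It models only the arithmetic and is not the Mojo implementation. WARP_SIZE = 32 is an assumption: it holds on NVIDIA GPUs, while AMD GPUs use 64. The example parameter values are also hypothetical.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division, as used by the n_signal_sms formula."""
    return -(-a // b)


def sm_roles(num_threads: int, n_sms: int, n_experts: int, warp_size: int = 32) -> dict:
    """Evaluate the SM-partitioning comptime members for one configuration.

    warp_size = 32 is an assumption (NVIDIA); pass 64 to model AMD GPUs.
    """
    n_warps = num_threads // warp_size
    n_signal_sms = ceildiv(n_experts, n_warps)  # enough warps for one per expert
    return {
        "n_warps": n_warps,
        # dispatch_async: aux (signal) SMs monitor experts, comm SMs copy and send
        "n_signal_sms": n_signal_sms,
        "n_dispatch_async_comm_sms": n_sms - n_signal_sms,
        # dispatch_wait: one aux SM computes offsets, the rest copy to the output
        "n_offset_sms": 1,
        "n_dispatch_wait_comm_sms": n_sms - 1,
    }


# Hypothetical configuration: 1024 threads per block, 132 SMs, 256 experts.
roles = sm_roles(num_threads=1024, n_sms=132, n_experts=256)
print(roles["n_warps"], roles["n_signal_sms"], roles["n_dispatch_async_comm_sms"])  # 32 8 124
print(roles["n_dispatch_wait_comm_sms"])  # 131
```

With 256 experts and 32 warps per block, 8 aux SMs give every expert its own monitoring warp. That leaves 124 SMs for copying and sending during dispatch_async.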

Parameters

  • num_threads (Int): The number of threads per block.
  • n_sms (Int): The total number of SMs in the device.
  • n_experts (Int): The total number of experts in the model.
  • n_ranks (Int): The number of devices participating in communication.
  • max_tokens_per_rank (Int): The maximum number of tokens per rank.
  • p2p_world_size (Int): The number of devices in a high-speed GPU interconnect (P2P) group.
  • token_fmt_type (TokenFormat): A type conforming to the TokenFormat trait.
  • use_shmem (Bool): Whether to use the SHMEM API for communication.
  • fused_shared_expert (Bool): Whether to pack the shared expert inputs together with the routed experts' inputs.
  • skip_a2a (Bool): Whether to skip the all-to-all (A2A) communication. If True, tokens are only sent within the current device.

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

cleanup_counter_offset

comptime cleanup_counter_offset = (4 * n_experts)

hid_dim

comptime hid_dim = token_fmt_type.hid_dim

msg_bytes

comptime msg_bytes = token_fmt_type.msg_size()

n_dispatch_async_comm_sms

comptime n_dispatch_async_comm_sms = (n_sms - EPDispatchKernel[num_threads, n_sms, n_experts, n_ranks, max_tokens_per_rank, p2p_world_size, token_fmt_type, use_shmem, fused_shared_expert, skip_a2a].n_signal_sms)

n_dispatch_wait_comm_sms

comptime n_dispatch_wait_comm_sms = (n_sms - 1)

n_local_experts

comptime n_local_experts = (n_experts // n_ranks)

n_offset_sms

comptime n_offset_sms = 1

n_signal_sms

comptime n_signal_sms = ceildiv(n_experts, EPDispatchKernel[num_threads, n_sms, n_experts, n_ranks, max_tokens_per_rank, p2p_world_size, token_fmt_type, use_shmem, fused_shared_expert, skip_a2a].n_warps)

n_warps

comptime n_warps = (num_threads // WARP_SIZE)

rank_prefix_offset

comptime rank_prefix_offset = (2 * n_experts)

ready_flag_offset

comptime ready_flag_offset = ((4 * n_experts) + 1)

send_buf_ready_offset

comptime send_buf_ready_offset = ((4 * n_experts) + 2)

shared_expert_started_offset

comptime shared_expert_started_offset = ((4 * n_experts) + 3)

top_k

comptime top_k = token_fmt_type.top_k

work_counter_offset

comptime work_counter_offset = (3 * n_experts)
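The *_offset members appear to index into the Int32 atomic_counter buffer passed to the dispatch_wait methods. The Python sketch below lays out the regions those offsets imply. This is an inference from the formulas, not documented behavior. The page does not say what occupies indices [0, 2 * n_experts), and each region's size is assumed to run up to the next offset. send_buf_ready_offset and shared_expert_started_offset are adjacent. That adjacency is consistent with the two counters pack_shared_expert_inputs expects at fused_se_counter[0] and [1].

```python
def counter_regions(n_experts: int) -> list[tuple[str, int, int]]:
    """Return (name, start, end) half-open index ranges implied by the offsets.

    Region sizes are an assumption: each region extends up to the next offset.
    """
    E = n_experts
    return [
        ("undocumented", 0, 2 * E),                        # not described on this page
        ("rank_prefix", 2 * E, 3 * E),                     # rank_prefix_offset
        ("work_counter", 3 * E, 4 * E),                    # work_counter_offset
        ("cleanup_counter", 4 * E, 4 * E + 1),             # cleanup_counter_offset
        ("ready_flag", 4 * E + 1, 4 * E + 2),              # ready_flag_offset
        ("send_buf_ready", 4 * E + 2, 4 * E + 3),          # send_buf_ready_offset
        ("shared_expert_started", 4 * E + 3, 4 * E + 4),   # shared_expert_started_offset
    ]


for name, start, end in counter_regions(8):
    print(f"{name:>22}: [{start}, {end})")
```

The highest named slot is 4 * n_experts + 3. Under these assumptions, the buffer therefore needs at least 4 * n_experts + 4 entries, which is 36 for n_experts = 8.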

Methods

recv_buf_layout

static recv_buf_layout[out_dtype: DType = _get_index_type[Layout[*?, *?]](AddressSpace.GENERIC)](coord: Coord) -> Scalar[out_dtype]

Returns:

Scalar[out_dtype]

recv_count_layout

static recv_count_layout(coord: Coord) -> Int32

Returns:

Int32

send_buf_layout

static send_buf_layout(coord: Coord) -> Int32

Returns:

Int32

monitor_and_signal_completion

static monitor_and_signal_completion(topk_ids: TileTensor[DType.int32, address_space=topk_ids.address_space, linear_idx_type=topk_ids.linear_idx_type, element_size=topk_ids.element_size], recv_count_ptrs: InlineArray[UnsafePointer[UInt64, MutExternalOrigin], p2p_world_size], expert_reserved_counter: UnsafePointer[Int32, MutExternalOrigin], expert_finished_counter: UnsafePointer[Int32, MutExternalOrigin], rank_completion_counter: UnsafePointer[Int32, MutExternalOrigin], my_rank: Int32)

Auxiliary SM logic for dispatch_kernel.

Counts tokens per expert and signals completion when all tokens for an expert have been sent. Each warp handles one expert.
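A host-side Python model of this bookkeeping is sketched below, under two assumptions this page does not state. First, the global warp index sm_id * n_warps + warp_id selects the expert, which is consistent with n_signal_sms = ceildiv(n_experts, n_warps). Second, "all tokens sent" means a per-expert finished counter (compare expert_finished_counter) has reached the number of (token, expert) pairs routed to that expert. The real kernel waits on atomics; the model only checks the condition.

```python
def warp_to_expert(sm_id: int, warp_id: int, n_warps: int, n_experts: int):
    """Return the expert this warp monitors, or None for an idle warp (assumed mapping)."""
    expert = sm_id * n_warps + warp_id
    return expert if expert < n_experts else None


def count_tokens_per_expert(topk_ids: list[list[int]], n_experts: int) -> list[int]:
    """Count the (token, expert) pairs routed to each expert."""
    counts = [0] * n_experts
    for token_experts in topk_ids:  # one row of top_k expert IDs per token
        for expert in token_experts:
            counts[expert] += 1
    return counts


def ready_to_signal(expert: int, counts: list[int], finished: list[int]) -> bool:
    """True once every token routed to `expert` has been sent (assumed condition)."""
    return finished[expert] == counts[expert]


topk_ids = [[0, 3], [3, 5], [1, 3]]  # 3 tokens, top_k = 2, 6 experts
counts = count_tokens_per_expert(topk_ids, n_experts=6)
finished = [1, 1, 0, 2, 0, 1]        # expert 3 still has one token in flight
print(counts)  # [1, 1, 0, 3, 0, 1]
print([e for e in range(6) if ready_to_signal(e, counts, finished)])  # [0, 1, 2, 4, 5]
```

In this model, an expert that receives no tokens is ready immediately. Its destination can therefore still learn that the count is zero.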

Args:

copy_and_send_tokens

static copy_and_send_tokens[input_type: DType, //, input_scales_wrapper: Optional[def[dtype: DType](Int) capturing -> Scalar[dtype]] = None](input_tokens: TileTensor[input_type, address_space=input_tokens.address_space, linear_idx_type=input_tokens.linear_idx_type, element_size=input_tokens.element_size], topk_ids: TileTensor[DType.int32, address_space=topk_ids.address_space, linear_idx_type=topk_ids.linear_idx_type, element_size=topk_ids.element_size], send_buf_p: UnsafePointer[UInt8, MutExternalOrigin], recv_buf_ptrs: InlineArray[UnsafePointer[UInt8, MutExternalOrigin], p2p_world_size], expert_reserved_counter: UnsafePointer[Int32, MutExternalOrigin], expert_finished_counter: UnsafePointer[Int32, MutExternalOrigin], my_rank: Int32)

Communication SM logic for dispatch_kernel.

Copies tokens to send buffer and transfers them to destination ranks. Uses direct P2P transfers for same-node destinations and SHMEM for cross-node destinations.
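The destination and transport choice can be modelled in a few lines of Python. This is a sketch under stated assumptions, not the kernel's code:

- Experts are assigned to ranks in contiguous blocks of n_local_experts = n_experts // n_ranks.
- P2P groups are contiguous blocks of p2p_world_size ranks.
- The index into recv_buf_ptrs is dst_rank % p2p_world_size.

Route and route_token are hypothetical names.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Route:
    dst_rank: int
    local_expert: int
    transport: str             # "p2p" (same interconnect group) or "shmem"
    p2p_slot: Optional[int]    # assumed index into recv_buf_ptrs for P2P sends


def route_token(expert: int, my_rank: int, n_experts: int, n_ranks: int,
                p2p_world_size: int) -> Route:
    """Pick the destination rank, local expert, and transport for one expert ID."""
    n_local_experts = n_experts // n_ranks
    # Assumption: rank r hosts experts [r * n_local_experts, (r + 1) * n_local_experts).
    dst_rank, local_expert = divmod(expert, n_local_experts)
    # Assumption: ranks in the same contiguous block of p2p_world_size share a P2P group.
    if dst_rank // p2p_world_size == my_rank // p2p_world_size:
        return Route(dst_rank, local_expert, "p2p", dst_rank % p2p_world_size)
    return Route(dst_rank, local_expert, "shmem", None)


# Hypothetical setup: 256 experts over 16 ranks, 8 GPUs per P2P group, sender on rank 3.
print(route_token(40, 3, 256, 16, 8))   # rank 2 is in rank 3's group, so P2P
print(route_token(200, 3, 256, 16, 8))  # rank 12 is in another group, so SHMEM
```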

Args:

wait_for_arrivals_and_compute_offsets

static wait_for_arrivals_and_compute_offsets(format_handler: token_fmt_type, row_offsets: TileTensor[DType.uint32, address_space=row_offsets.address_space, linear_idx_type=row_offsets.linear_idx_type, element_size=row_offsets.element_size], expert_ids: TileTensor[DType.int32, address_space=expert_ids.address_space, linear_idx_type=expert_ids.linear_idx_type, element_size=expert_ids.element_size], recv_count_p: UnsafePointer[UInt64, MutExternalOrigin], atomic_counter: UnsafePointer[Int32, MutExternalOrigin], my_rank: Int32, reserved_shared_expert_tokens: UInt32 = UInt32(0))

Auxiliary SM logic for dispatch_wait_kernel.

Waits for token arrivals from all ranks and computes the output tensor offsets for each local expert. Also signals other SMs to copy the tokens to the output tensor once data is ready.
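Once every source rank's count has arrived, computing the offsets reduces to two prefix sums. The Python sketch below takes a hypothetical recv_counts[e][r] table, giving the tokens for local expert e from source rank r. It returns row_offsets as an exclusive prefix sum with n_local_experts + 1 entries. It also returns the within-expert per-rank prefix sums that copy_received_tokens_to_output uses to find rank boundaries. It ignores reserved_shared_expert_tokens and expert_ids, whose exact semantics this page does not describe.

```python
from itertools import accumulate


def compute_offsets(recv_counts: list[list[int]]) -> tuple[list[int], list[list[int]]]:
    """Return (row_offsets, rank_prefix) for a recv_counts[expert][rank] table.

    Local expert e owns output rows [row_offsets[e], row_offsets[e + 1]).
    Within that range, source rank r owns the rows
    [rank_prefix[e][r], rank_prefix[e][r + 1]), relative to row_offsets[e].
    """
    tokens_per_expert = [sum(per_rank) for per_rank in recv_counts]
    row_offsets = [0, *accumulate(tokens_per_expert)]
    rank_prefix = [[0, *accumulate(per_rank)] for per_rank in recv_counts]
    return row_offsets, rank_prefix


# Two local experts, three source ranks (hypothetical counts).
row_offsets, rank_prefix = compute_offsets([[2, 0, 1], [0, 3, 1]])
print(row_offsets)  # [0, 3, 7]
print(rank_prefix)  # [[0, 2, 2, 3], [0, 0, 3, 4]]
```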

Args:

copy_received_tokens_to_output

static copy_received_tokens_to_output(format_handler: token_fmt_type, row_offsets: TileTensor[DType.uint32, address_space=row_offsets.address_space, linear_idx_type=row_offsets.linear_idx_type, element_size=row_offsets.element_size], src_info: TileTensor[DType.int32, address_space=src_info.address_space, linear_idx_type=src_info.linear_idx_type, element_size=src_info.element_size], recv_buf_p: UnsafePointer[UInt8, MutExternalOrigin], atomic_counter: UnsafePointer[Int32, MutExternalOrigin], my_rank: Int32)

Communication SM logic for dispatch_wait_kernel.

Copies received tokens from the receive buffer to the output tensor. Each SM is assigned to one local expert and dynamically claims tiles via per-expert atomic counters. Tokens within a tile may come from multiple source ranks; rank boundaries are resolved via the within-expert prefix sums written by the auxiliary SM.
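The tile-claiming loop and the rank-boundary lookup can be modelled in Python as below. This is a hedged sketch:

- TileCounter is a hypothetical stand-in for one per-expert atomic counter (fetch_add in the kernel).
- tile_size is a hypothetical parameter.
- rank_prefix is assumed to be an exclusive prefix sum of one expert's per-rank token counts, so a binary search finds each row's source rank.

```python
from bisect import bisect_right
from itertools import accumulate


class TileCounter:
    """Hypothetical stand-in for a per-expert atomic counter."""

    def __init__(self) -> None:
        self.value = 0

    def fetch_add(self, n: int = 1) -> int:
        old = self.value
        self.value += n
        return old


def claim_tile(counter: TileCounter, tile_size: int, n_tokens: int):
    """Claim the next tile of an expert's rows; None once every row is taken."""
    start = counter.fetch_add(1) * tile_size
    if start >= n_tokens:
        return None
    return start, min(start + tile_size, n_tokens)


def locate_row(row: int, rank_prefix: list[int]) -> tuple[int, int]:
    """Map a row within an expert to (source rank, index within that rank's rows)."""
    rank = bisect_right(rank_prefix, row) - 1  # skips ranks that sent no tokens
    return rank, row - rank_prefix[rank]


# One expert receiving 2 tokens from rank 0, none from rank 1, and 3 from rank 2.
rank_prefix = [0, *accumulate([2, 0, 3])]  # [0, 2, 2, 5]
counter = TileCounter()
while (tile := claim_tile(counter, tile_size=3, n_tokens=rank_prefix[-1])) is not None:
    start, end = tile
    print(tile, [locate_row(row, rank_prefix) for row in range(start, end)])
# (0, 3) [(0, 0), (0, 1), (2, 0)]
# (3, 5) [(2, 1), (2, 2)]
```

The first tile spans two source ranks, which is why the boundary lookup is done per row rather than per tile.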

Args:

pack_shared_expert_inputs

static pack_shared_expert_inputs(format_handler: token_fmt_type, send_buf_p: UnsafePointer[UInt8, MutExternalOrigin], fused_se_counter: UnsafePointer[Int32, MutExternalOrigin], shared_expert_token_count: Int)

Copies the already-quantized shared expert tokens from send_buf to the output tensor.

Waits for the dispatch_async signal SMs to indicate that all tokens have been written to the send buffer, then copies the tokens tile by tile via copy_msg_tile_to_output_tensor. Only the SMs needed for the copy participate.

Args:

  • format_handler (token_fmt_type): An instance of token_fmt_type used to decode tokens.
  • send_buf_p (UnsafePointer[UInt8, MutExternalOrigin]): Pointer to the send buffer containing the serialized tokens.
  • fused_se_counter (UnsafePointer[Int32, MutExternalOrigin]): Pointer to the two fused shared expert atomic counters (send_buf_ready at [0], started at [1]).
  • shared_expert_token_count (Int): The number of shared expert tokens to copy.
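The small Python model below makes concrete the statement that only the SMs needed for the copy participate. It assumes a hypothetical tile height tile_rows and a round-robin assignment of tiles to SMs. Neither is documented here, and the wait on the send_buf_ready counter is not modelled.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division."""
    return -(-a // b)


def shared_expert_tiles(token_count: int, tile_rows: int, n_sms: int) -> dict[int, list[tuple[int, int]]]:
    """Assign row tiles to SMs round-robin; SMs beyond the tile count sit out."""
    n_tiles = ceildiv(token_count, tile_rows)
    n_active = min(n_sms, n_tiles)  # only as many SMs as there are tiles
    plan = {}
    for sm in range(n_active):
        plan[sm] = [(t * tile_rows, min((t + 1) * tile_rows, token_count))
                    for t in range(sm, n_tiles, n_active)]
    return plan


print(shared_expert_tiles(token_count=10, tile_rows=4, n_sms=8))
# {0: [(0, 4)], 1: [(4, 8)], 2: [(8, 10)]}
```

With 10 shared expert tokens and 4-row tiles, only 3 of the 8 SMs have work.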