Mojo struct
EPDispatchKernel
struct EPDispatchKernel[num_threads: Int, n_sms: Int, n_experts: Int, n_ranks: Int, max_tokens_per_rank: Int, p2p_world_size: Int, token_fmt_type: TokenFormat, use_shmem: Bool = True, fused_shared_expert: Bool = False, skip_a2a: Bool = False]
Implements dispatch_async and dispatch_wait kernel logic for Expert Parallelism.
This struct encapsulates the token dispatch operations used in MoE (Mixture of Experts) models with expert parallelism. It provides methods for:
- Async Dispatch:
  - monitor_and_signal_completion: Aux SMs count tokens per expert and signal completion when all tokens for an expert have been sent.
  - copy_and_send_tokens: Comm SMs copy tokens to send buffer and transfer them to destination ranks.
- Wait for Arrivals:
  - wait_for_arrivals_and_compute_offsets: Aux SMs wait for token arrivals and compute output tensor offsets. Also signals other SMs to copy the tokens to the output tensor once data is ready.
  - copy_received_tokens_to_output: Comm SMs copy received tokens to the output tensor.
Parameters
- num_threads (Int): The number of threads per block.
- n_sms (Int): The total number of SMs in the device.
- n_experts (Int): The total number of experts in the model.
- n_ranks (Int): The number of devices participating in communication.
- max_tokens_per_rank (Int): The maximum number of tokens per rank.
- p2p_world_size (Int): Size of a high-speed GPU interconnect group.
- token_fmt_type (TokenFormat): Type conforming to TokenFormat trait.
- use_shmem (Bool): Whether to use the SHMEM API for communication.
- fused_shared_expert (Bool): Whether to pack the shared expert inputs with the routed experts' inputs.
- skip_a2a (Bool): Whether to skip the A2A communication. If true, we will only send tokens within the current device.
Implemented traits
AnyType, ImplicitlyDestructible
comptime members
cleanup_counter_offset
comptime cleanup_counter_offset = (4 * n_experts)
hid_dim
comptime hid_dim = token_fmt_type.hid_dim
msg_bytes
comptime msg_bytes = token_fmt_type.msg_size()()
n_dispatch_async_comm_sms
comptime n_dispatch_async_comm_sms = (n_sms - EPDispatchKernel[num_threads, n_sms, n_experts, n_ranks, max_tokens_per_rank, p2p_world_size, token_fmt_type, use_shmem, fused_shared_expert, skip_a2a].n_signal_sms)
n_dispatch_wait_comm_sms
comptime n_dispatch_wait_comm_sms = (n_sms - 1)
n_local_experts
comptime n_local_experts = (n_experts // n_ranks)
n_offset_sms
comptime n_offset_sms = 1
n_signal_sms
comptime n_signal_sms = ceildiv(n_experts, EPDispatchKernel[num_threads, n_sms, n_experts, n_ranks, max_tokens_per_rank, p2p_world_size, token_fmt_type, use_shmem, fused_shared_expert, skip_a2a].n_warps)
n_warps
comptime n_warps = (num_threads // WARP_SIZE)
rank_prefix_offset
comptime rank_prefix_offset = (2 * n_experts)
ready_flag_offset
comptime ready_flag_offset = ((4 * n_experts) + 1)
send_buf_ready_offset
comptime send_buf_ready_offset = ((4 * n_experts) + 2)
shared_expert_started_offset
comptime shared_expert_started_offset = ((4 * n_experts) + 3)
top_k
comptime top_k = token_fmt_type.top_k
work_counter_offset
comptime work_counter_offset = (3 * n_experts)
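The comptime members above are all simple functions of the struct parameters. As a rough host-side illustration (plain Python, not Mojo), the sketch below recomputes them for one configuration. It assumes WARP_SIZE is 32, which this page does not state; the `derived_constants` helper and the example parameter values are hypothetical and not part of the API.

```python
WARP_SIZE = 32  # assumption: 32-thread warps; this page does not give the value


def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division, standing in for the ceildiv used above."""
    return -(-a // b)


def derived_constants(num_threads: int, n_sms: int, n_experts: int, n_ranks: int) -> dict:
    """Recompute the documented comptime members from the struct parameters."""
    n_warps = num_threads // WARP_SIZE
    # Each warp handles one expert, so signal SMs = ceil(experts / warps per SM).
    n_signal_sms = ceildiv(n_experts, n_warps)
    return {
        "n_warps": n_warps,
        "n_signal_sms": n_signal_sms,
        "n_offset_sms": 1,
        "n_dispatch_async_comm_sms": n_sms - n_signal_sms,
        "n_dispatch_wait_comm_sms": n_sms - 1,
        "n_local_experts": n_experts // n_ranks,
        "rank_prefix_offset": 2 * n_experts,
        "work_counter_offset": 3 * n_experts,
        "cleanup_counter_offset": 4 * n_experts,
        "ready_flag_offset": 4 * n_experts + 1,
        "send_buf_ready_offset": 4 * n_experts + 2,
        "shared_expert_started_offset": 4 * n_experts + 3,
    }


# Hypothetical configuration, chosen only for illustration.
consts = derived_constants(num_threads=1024, n_sms=132, n_experts=256, n_ranks=8)
```

Under these example values, dispatch_async would use 8 signal SMs and 124 comm SMs, while dispatch_wait would use 1 offset SM and 131 comm SMs.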
Methods
recv_buf_layout
static recv_buf_layout[out_dtype: DType = _get_index_type[Layout[*?, *?]](AddressSpace.GENERIC)](coord: Coord) -> Scalar[out_dtype]
Returns:
Scalar[out_dtype]
recv_count_layout
send_buf_layout
monitor_and_signal_completion
static monitor_and_signal_completion(topk_ids: TileTensor[DType.int32, address_space=topk_ids.address_space, linear_idx_type=topk_ids.linear_idx_type, element_size=topk_ids.element_size], recv_count_ptrs: InlineArray[UnsafePointer[UInt64, MutExternalOrigin], p2p_world_size], expert_reserved_counter: UnsafePointer[Int32, MutExternalOrigin], expert_finished_counter: UnsafePointer[Int32, MutExternalOrigin], rank_completion_counter: UnsafePointer[Int32, MutExternalOrigin], my_rank: Int32)
Auxiliary SM logic for dispatch_kernel.
Counts tokens per expert and signals completion when all tokens for an expert have been sent. Each warp handles one expert.
Args:
- topk_ids (TileTensor[DType.int32, address_space=topk_ids.address_space, linear_idx_type=topk_ids.linear_idx_type, element_size=topk_ids.element_size]): The top-k expert IDs for each token.
- recv_count_ptrs (InlineArray[UnsafePointer[UInt64, MutExternalOrigin], p2p_world_size]): Array of pointers to receive count buffers.
- expert_reserved_counter (UnsafePointer[Int32, MutExternalOrigin]): Counter for reserved slots per expert.
- expert_finished_counter (UnsafePointer[Int32, MutExternalOrigin]): Counter for finished sends per expert.
- rank_completion_counter (UnsafePointer[Int32, MutExternalOrigin]): Counter for per-rank completion tracking.
- my_rank (Int32): The rank of the current device.
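The counting and completion check described for this method can be modeled on the host in a few lines. The following is a minimal Python sketch, not the kernel itself: it assumes topk_ids is a tokens-by-top_k table of expert IDs and that an expert counts as complete once its finished-send counter equals its expected token count. Both helper names are hypothetical.

```python
from collections import Counter


def count_tokens_per_expert(topk_ids: list[list[int]], n_experts: int) -> list[int]:
    """Count how many (token, top-k slot) pairs route to each expert."""
    counts = Counter(e for row in topk_ids for e in row)
    return [counts.get(e, 0) for e in range(n_experts)]


def completed_experts(expected: list[int], finished: list[int]) -> list[int]:
    """Experts whose finished-send counter has reached the expected count."""
    return [e for e, (want, done) in enumerate(zip(expected, finished)) if done == want]


# Made-up routing: 3 tokens, top_k = 2, 4 experts.
topk_ids = [[0, 2], [2, 3], [0, 2]]
expected = count_tokens_per_expert(topk_ids, n_experts=4)
```

In this model an expert that receives no tokens is complete immediately, since its expected count is zero.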
copy_and_send_tokens
static copy_and_send_tokens[input_type: DType, //, input_scales_wrapper: Optional[def[dtype: DType](Int) capturing -> Scalar[dtype]] = None](input_tokens: TileTensor[input_type, address_space=input_tokens.address_space, linear_idx_type=input_tokens.linear_idx_type, element_size=input_tokens.element_size], topk_ids: TileTensor[DType.int32, address_space=topk_ids.address_space, linear_idx_type=topk_ids.linear_idx_type, element_size=topk_ids.element_size], send_buf_p: UnsafePointer[UInt8, MutExternalOrigin], recv_buf_ptrs: InlineArray[UnsafePointer[UInt8, MutExternalOrigin], p2p_world_size], expert_reserved_counter: UnsafePointer[Int32, MutExternalOrigin], expert_finished_counter: UnsafePointer[Int32, MutExternalOrigin], my_rank: Int32)
Communication SM logic for dispatch_kernel.
Copies tokens to send buffer and transfers them to destination ranks. Uses direct P2P transfers for same-node destinations and SHMEM for cross-node destinations.
Args:
- input_tokens (TileTensor[input_type, address_space=input_tokens.address_space, linear_idx_type=input_tokens.linear_idx_type, element_size=input_tokens.element_size]): The input tokens to be dispatched.
- topk_ids (TileTensor[DType.int32, address_space=topk_ids.address_space, linear_idx_type=topk_ids.linear_idx_type, element_size=topk_ids.element_size]): The top-k expert IDs for each token.
- send_buf_p (UnsafePointer[UInt8, MutExternalOrigin]): Pointer to the send buffer.
- recv_buf_ptrs (InlineArray[UnsafePointer[UInt8, MutExternalOrigin], p2p_world_size]): Array of pointers to receive buffers.
- expert_reserved_counter (UnsafePointer[Int32, MutExternalOrigin]): Counter for reserved slots per expert.
- expert_finished_counter (UnsafePointer[Int32, MutExternalOrigin]): Counter for finished sends per expert.
- my_rank (Int32): The rank of the current device.
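To make the P2P versus SHMEM split concrete, here is a small Python sketch of how a destination and transport might be chosen per token. It rests on two assumptions that this page does not confirm: experts are assigned to ranks in contiguous blocks of n_local_experts, and ranks are grouped into consecutive interconnect groups of p2p_world_size. The function names are hypothetical.

```python
def destination_rank(expert_id: int, n_experts: int, n_ranks: int) -> int:
    """Rank that hosts expert_id, assuming contiguous blocks of experts per rank."""
    n_local_experts = n_experts // n_ranks
    return expert_id // n_local_experts


def choose_transport(dst_rank: int, my_rank: int, p2p_world_size: int) -> str:
    """Pick direct P2P inside the same interconnect group, SHMEM otherwise.

    Assumes group membership is rank // p2p_world_size.
    """
    same_group = dst_rank // p2p_world_size == my_rank // p2p_world_size
    return "p2p" if same_group else "shmem"


# Example: 16 experts over 8 ranks, interconnect groups of 4 ranks.
dst = destination_rank(expert_id=5, n_experts=16, n_ranks=8)
transport = choose_transport(dst, my_rank=0, p2p_world_size=4)
```

Here expert 5 lands on rank 2, which shares a group with rank 0, so the sketch selects the P2P path.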
wait_for_arrivals_and_compute_offsets
static wait_for_arrivals_and_compute_offsets(format_handler: token_fmt_type, row_offsets: TileTensor[DType.uint32, address_space=row_offsets.address_space, linear_idx_type=row_offsets.linear_idx_type, element_size=row_offsets.element_size], expert_ids: TileTensor[DType.int32, address_space=expert_ids.address_space, linear_idx_type=expert_ids.linear_idx_type, element_size=expert_ids.element_size], recv_count_p: UnsafePointer[UInt64, MutExternalOrigin], atomic_counter: UnsafePointer[Int32, MutExternalOrigin], my_rank: Int32, reserved_shared_expert_tokens: UInt32 = UInt32(0))
Auxiliary SM logic for dispatch_wait_kernel.
Waits for token arrivals from all ranks and computes the output tensor offsets for each local expert. Also signals other SMs to copy the tokens to the output tensor once data is ready.
Args:
- format_handler (token_fmt_type): Instance of token_fmt_type for token decoding.
- row_offsets (TileTensor[DType.uint32, address_space=row_offsets.address_space, linear_idx_type=row_offsets.linear_idx_type, element_size=row_offsets.element_size]): Output row offsets for grouped matmul.
- expert_ids (TileTensor[DType.int32, address_space=expert_ids.address_space, linear_idx_type=expert_ids.linear_idx_type, element_size=expert_ids.element_size]): Output expert IDs for grouped matmul.
- recv_count_p (UnsafePointer[UInt64, MutExternalOrigin]): Pointer to receive count buffer.
- atomic_counter (UnsafePointer[Int32, MutExternalOrigin]): Atomic counter for synchronization.
- my_rank (Int32): The rank of the current device.
- reserved_shared_expert_tokens (UInt32): The number of tokens reserved for the shared expert.
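The offset computation can be pictured as a prefix sum over per-expert arrival counts. The Python sketch below is a host-side illustration under stated assumptions: recv_counts is indexed as [local_expert][source_rank], and any rows reserved for the shared expert sit before the routed experts' rows. The actual layout used by the kernel is not described on this page, and `compute_row_offsets` is a hypothetical name.

```python
from itertools import accumulate


def compute_row_offsets(
    recv_counts: list[list[int]], reserved_shared_expert_tokens: int = 0
) -> list[int]:
    """Return n_local_experts + 1 offsets; expert i owns rows [off[i], off[i + 1])."""
    per_expert = [sum(row) for row in recv_counts]
    # Exclusive prefix sum, shifted by the rows assumed reserved for the shared expert.
    return list(accumulate(per_expert, initial=reserved_shared_expert_tokens))


# Made-up arrival counts: 3 local experts, 2 source ranks.
recv_counts = [[2, 1], [0, 3], [4, 0]]
offsets = compute_row_offsets(recv_counts)
```

For these counts the routed experts receive 3, 3, and 4 tokens, so the offsets are 0, 3, 6, and 10.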
copy_received_tokens_to_output
static copy_received_tokens_to_output(format_handler: token_fmt_type, row_offsets: TileTensor[DType.uint32, address_space=row_offsets.address_space, linear_idx_type=row_offsets.linear_idx_type, element_size=row_offsets.element_size], src_info: TileTensor[DType.int32, address_space=src_info.address_space, linear_idx_type=src_info.linear_idx_type, element_size=src_info.element_size], recv_buf_p: UnsafePointer[UInt8, MutExternalOrigin], atomic_counter: UnsafePointer[Int32, MutExternalOrigin], my_rank: Int32)
Communication SM logic for dispatch_wait_kernel.
Copies received tokens from the receive buffer to the output tensor. Each SM is assigned to one local expert and dynamically claims tiles via per-expert atomic counters. Tokens within a tile may come from multiple source ranks; rank boundaries are resolved via the within-expert prefix sums written by the auxiliary SM.
Args:
- format_handler (token_fmt_type): Instance of token_fmt_type for token decoding.
- row_offsets (TileTensor[DType.uint32, address_space=row_offsets.address_space, linear_idx_type=row_offsets.linear_idx_type, element_size=row_offsets.element_size]): Output row offsets for grouped matmul.
- src_info (TileTensor[DType.int32, address_space=src_info.address_space, linear_idx_type=src_info.linear_idx_type, element_size=src_info.element_size]): Output tensor for source token info.
- recv_buf_p (UnsafePointer[UInt8, MutExternalOrigin]): Pointer to the receive buffer.
- atomic_counter (UnsafePointer[Int32, MutExternalOrigin]): Atomic counter for synchronization.
- my_rank (Int32): The rank of the current device.
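Resolving rank boundaries inside a tile amounts to a lookup in the within-expert prefix sums. The Python sketch below illustrates that one step only (it does not model the atomic tile claiming). It assumes inclusive prefix sums over one expert's per-rank token counts; `tile_size` and all function names are hypothetical.

```python
from bisect import bisect_right
from itertools import accumulate


def rank_prefix_sums(counts_by_rank: list[int]) -> list[int]:
    """Inclusive prefix sums of one expert's per-rank token counts."""
    return list(accumulate(counts_by_rank))


def source_rank(token_idx: int, prefix: list[int]) -> int:
    """Rank whose slice of the expert's tokens contains token_idx."""
    return bisect_right(prefix, token_idx)


def split_tile_by_rank(tile_start: int, tile_size: int, prefix: list[int]) -> list[tuple[int, int, int]]:
    """Split [tile_start, tile_start + tile_size) into (rank, start, end) segments."""
    total = prefix[-1] if prefix else 0
    end = min(tile_start + tile_size, total)
    segments = []
    pos = tile_start
    while pos < end:
        rank = source_rank(pos, prefix)
        seg_end = min(end, prefix[rank])  # stop at this rank's boundary
        segments.append((rank, pos, seg_end))
        pos = seg_end
    return segments


# Made-up counts: ranks 0..3 sent 3, 0, 5, and 2 tokens to this expert.
prefix = rank_prefix_sums([3, 0, 5, 2])
segments = split_tile_by_rank(tile_start=2, tile_size=4, prefix=prefix)
```

In this example the tile covering tokens 2 through 5 spans two source ranks: token 2 comes from rank 0 and tokens 3 through 5 come from rank 2, with the empty rank 1 skipped.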
pack_shared_expert_inputs
static pack_shared_expert_inputs(format_handler: token_fmt_type, send_buf_p: UnsafePointer[UInt8, MutExternalOrigin], fused_se_counter: UnsafePointer[Int32, MutExternalOrigin], shared_expert_token_count: Int)
Copies already-quantized shared expert tokens from send_buf to output.
Waits for dispatch_async signal SMs to indicate all tokens have been written to the send buffer, then uses tile-based copy via copy_msg_tile_to_output_tensor. Only SMs needed for the copy participate.
Args:
- format_handler (token_fmt_type): Instance of token_fmt_type for token decoding.
- send_buf_p (UnsafePointer[UInt8, MutExternalOrigin]): Pointer to the send buffer containing serialized tokens.
- fused_se_counter (UnsafePointer[Int32, MutExternalOrigin]): Pointer to the two fused shared expert atomic counters (send_buf_ready at [0], started at [1]).
- shared_expert_token_count (Int): Number of shared expert tokens to copy.