Mojo function
dispatch_kernel
dispatch_kernel[input_type: DType, num_threads: Int, input_tokens_layout: TensorLayout, topk_ids_layout: TensorLayout, row_offsets_layout: TensorLayout, expert_ids_layout: TensorLayout, src_info_layout: TensorLayout, n_sms: Int, n_experts: Int, n_ranks: Int, max_tokens_per_rank: Int, p2p_world_size: Int, token_fmt_type: TokenFormat, fused_shared_expert: Bool = False, input_scales_wrapper: Optional[def[dtype: DType](Int) capturing -> Scalar[dtype]] = None, skip_a2a: Bool = False, use_shmem: Bool = True, allreduce_world_size: Int = 1](input_tokens: TileTensor[input_type, input_tokens_layout, ImmutExternalOrigin], topk_ids: TileTensor[DType.int32, topk_ids_layout, ImmutExternalOrigin], format_handler: token_fmt_type, row_offsets: TileTensor[DType.uint32, row_offsets_layout, MutExternalOrigin], expert_ids: TileTensor[DType.int32, expert_ids_layout, MutExternalOrigin], src_info: TileTensor[DType.int32, src_info_layout, MutExternalOrigin], send_buf_p: UnsafePointer[UInt8, MutExternalOrigin], recv_buf_ptrs: InlineArray[UnsafePointer[UInt8, MutExternalOrigin], p2p_world_size], recv_count_ptrs: InlineArray[UnsafePointer[UInt64, MutExternalOrigin], p2p_world_size], ep_counters: EPLocalSyncCounters[n_experts], my_rank: Int32)
Fused dispatch kernel that combines dispatch_async and dispatch_wait functionality in a single kernel launch.
This kernel dispatches tokens to experts on remote ranks based on the top-k expert IDs, then waits for all tokens to arrive and aggregates them for grouped matmul computation.
Parameters:
- input_type (DType): The type of the input tokens.
- num_threads (Int): The number of threads in the block.
- input_tokens_layout (TensorLayout): The layout of the input tokens.
- topk_ids_layout (TensorLayout): The layout of the top-k expert IDs.
- row_offsets_layout (TensorLayout): The layout of the row offsets.
- expert_ids_layout (TensorLayout): The layout of the expert IDs.
- src_info_layout (TensorLayout): The layout of the source token info.
- n_sms (Int): The total number of SMs in the device.
- n_experts (Int): The total number of experts in the model.
- n_ranks (Int): The total number of devices participating in the communication.
- max_tokens_per_rank (Int): The maximum number of tokens per rank.
- p2p_world_size (Int): The size of a high-speed GPU interconnect group.
- token_fmt_type (TokenFormat): A type conforming to the TokenFormat trait that defines the token encoding scheme.
- fused_shared_expert (Bool): Whether to pack the shared expert inputs with the routed experts' inputs. When enabled, input_tokens is used as the shared expert inputs.
- input_scales_wrapper (Optional[def[dtype: DType](Int) capturing -> Scalar[dtype]]): The wrapper for the input scales.
- skip_a2a (Bool): Whether to skip the A2A communication. If True, tokens are sent only within the current device.
- use_shmem (Bool): Whether to use the SHMEM API for the communication.
- allreduce_world_size (Int): The world size of the allreduce operation. Only needed when skip_a2a is True. Used to calculate the workload distribution for the shared expert, if there is one.
Args:
- input_tokens (TileTensor[input_type, input_tokens_layout, ImmutExternalOrigin]): The input tokens to be dispatched. Also used as shared expert inputs when fused_shared_expert is True.
- topk_ids (TileTensor[DType.int32, topk_ids_layout, ImmutExternalOrigin]): The top-k expert IDs for each token.
- format_handler (token_fmt_type): An instance of token_fmt_type that performs token decoding and manages output tensor writes.
- row_offsets (TileTensor[DType.uint32, row_offsets_layout, MutExternalOrigin]): The row offsets to be updated. Consumed by the grouped_matmul kernel.
- expert_ids (TileTensor[DType.int32, expert_ids_layout, MutExternalOrigin]): The expert IDs to be updated. Consumed by the grouped_matmul kernel.
- src_info (TileTensor[DType.int32, src_info_layout, MutExternalOrigin]): The source token info to be updated. Once the expert computation is complete, tokens are sent back to their original rank using the information in this tensor.
- send_buf_p (UnsafePointer[UInt8, MutExternalOrigin]): The pointer to the send buffer. Must be allocated using shmem_alloc if use_shmem is True.
- recv_buf_ptrs (InlineArray[UnsafePointer[UInt8, MutExternalOrigin], p2p_world_size]): An array of pointers to the receive buffers for each device in the p2p world.
- recv_count_ptrs (InlineArray[UnsafePointer[UInt64, MutExternalOrigin], p2p_world_size]): An array of pointers to the receive count buffers for each device in the p2p world.
- ep_counters (EPLocalSyncCounters[n_experts]): EP atomic counters for kernel synchronization.
- my_rank (Int32): The rank of the current device.
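The routing and aggregation described above can be illustrated with a small host-side reference in Python. This is only a sketch of the bookkeeping, not the kernel itself, and it rests on several assumptions that this page does not confirm: experts are partitioned evenly and contiguously across ranks, row_offsets holds prefix sums of per-expert row counts, and src_info records (source rank, token index) pairs. The function name reference_dispatch is hypothetical, and the real tensors may use a different encoding.

```python
from collections import defaultdict


def reference_dispatch(topk_ids, n_experts, n_ranks, my_rank):
    """CPU reference for the rows that one rank receives (illustrative only).

    topk_ids[r][t] is the list of expert IDs selected for token t on rank r.
    ASSUMPTION: experts are split evenly and contiguously across ranks.
    ASSUMPTION: src_info entries are (source rank, token index) pairs.
    """
    experts_per_rank = n_experts // n_ranks
    first_local = my_rank * experts_per_rank

    # Collect, per local expert, every token copy that is routed to it.
    buckets = defaultdict(list)
    for src_rank, rank_topk in enumerate(topk_ids):
        for token_idx, experts in enumerate(rank_topk):
            for expert in experts:
                if first_local <= expert < first_local + experts_per_rank:
                    buckets[expert - first_local].append((src_rank, token_idx))

    # Lay the rows out contiguously per expert, as a grouped matmul expects.
    row_offsets = [0]
    expert_ids = []
    src_info = []
    for local_expert in sorted(buckets):
        rows = buckets[local_expert]
        expert_ids.append(local_expert)
        src_info.extend(rows)
        row_offsets.append(row_offsets[-1] + len(rows))
    return row_offsets, expert_ids, src_info
```

With two ranks, four experts, and top-2 routing, calling reference_dispatch([[[0, 3], [1, 2]], [[2, 3], [0, 1]]], 4, 2, 1) groups the four token copies routed to rank 1 into two contiguous per-expert segments. In this sketch, src_info is what a later combine step would use to return each row to its originating rank.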