
Mojo function

dispatch_async_kernel

dispatch_async_kernel[input_type: DType, num_threads: Int, input_tokens_layout: TensorLayout, topk_ids_layout: TensorLayout, n_sms: Int, n_experts: Int, n_ranks: Int, max_tokens_per_rank: Int, p2p_world_size: Int, token_fmt_type: TokenFormat, input_scales_wrapper: Optional[def[dtype: DType](Int) capturing -> Scalar[dtype]] = None, use_shmem: Bool = True](input_tokens: TileTensor[input_type, input_tokens_layout, ImmutExternalOrigin], topk_ids: TileTensor[DType.int32, topk_ids_layout, ImmutExternalOrigin], send_buf_p: UnsafePointer[UInt8, MutExternalOrigin], recv_buf_ptrs: InlineArray[UnsafePointer[UInt8, MutExternalOrigin], p2p_world_size], recv_count_ptrs: InlineArray[UnsafePointer[UInt64, MutExternalOrigin], p2p_world_size], ep_counters: EPLocalSyncCounters[n_experts], my_rank: Int32)

Dispatches tokens to experts on remote ranks based on the top-k expert IDs. This kernel uses the non-blocking SHMEM API if use_shmem is True, and returns immediately after initiating the communication. The communication is considered complete only after dispatch_wait_kernel has been called.
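The routing idea behind the dispatch step can be modeled on the host with a minimal Python sketch. It is not the Mojo kernel and does not call any MAX or SHMEM API. The helper name `route_tokens` is hypothetical, and the sketch assumes that experts are split into equal contiguous blocks across ranks, which this page does not state.

```python
from collections import defaultdict


def route_tokens(topk_ids, n_experts, n_ranks):
    """Group (token index, expert ID) pairs by destination rank.

    ASSUMPTION: expert e lives on rank e // (n_experts // n_ranks).
    The real kernel's expert placement may differ.
    """
    if n_experts % n_ranks != 0:
        raise ValueError("this sketch requires n_experts divisible by n_ranks")
    experts_per_rank = n_experts // n_ranks

    sends = defaultdict(list)
    for token_idx, expert_ids in enumerate(topk_ids):
        for expert_id in expert_ids:
            # Each token is sent once per selected expert.
            dst_rank = expert_id // experts_per_rank
            sends[dst_rank].append((token_idx, expert_id))
    return dict(sends)


# Three tokens, top-2 routing, 8 experts spread over 4 ranks.
topk_ids = [[0, 5], [3, 6], [1, 2]]
plan = route_tokens(topk_ids, n_experts=8, n_ranks=4)

for rank in sorted(plan):
    print(f"rank {rank}: {plan[rank]}")
```

In this model each rank's list corresponds to the tokens that would be written into that rank's receive buffer. In the actual kernel the transfer is only initiated at this point, and the data should not be read until dispatch_wait_kernel has run.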

Parameters:

  • input_type (DType): The type of the input tokens.
  • num_threads (Int): The number of threads in the block.
  • input_tokens_layout (TensorLayout): The layout of the input tokens.
  • topk_ids_layout (TensorLayout): The layout of the top-k expert IDs.
  • n_sms (Int): The total number of SMs in the device.
  • n_experts (Int): The total number of experts in the model.
  • n_ranks (Int): The total number of devices participating in the communication.
  • max_tokens_per_rank (Int): The maximum number of tokens per rank.
  • p2p_world_size (Int): The size of a high-speed GPU interconnect group.
  • token_fmt_type (TokenFormat): A type conforming to the TokenFormat trait that defines the token encoding scheme.
  • input_scales_wrapper (Optional[def[dtype: DType](Int) capturing -> Scalar[dtype]]): The wrapper for the input scales.
  • use_shmem (Bool): Whether to use the SHMEM API for the communication.

Args: