
Mojo function

dispatch_wait_kernel

dispatch_wait_kernel[num_threads: Int, row_offsets_layout: TensorLayout, expert_ids_layout: TensorLayout, src_info_layout: TensorLayout, n_sms: Int, n_experts: Int, n_ranks: Int, max_tokens_per_rank: Int, token_fmt_type: TokenFormat, input_scales_wrapper: Optional[def[dtype: DType](Int) capturing -> Scalar[dtype]] = None](format_handler: token_fmt_type, row_offsets: TileTensor[DType.uint32, row_offsets_layout, MutExternalOrigin], expert_ids: TileTensor[DType.int32, expert_ids_layout, MutExternalOrigin], src_info: TileTensor[DType.int32, src_info_layout, MutExternalOrigin], recv_buf_p: UnsafePointer[UInt8, MutExternalOrigin], recv_count_p: UnsafePointer[UInt64, MutExternalOrigin], ep_counters: EPLocalSyncCounters[n_experts], my_rank: Int32)

This kernel runs after dispatch_kernel to complete the communication. It polls the receive count buffer; once a count is no longer MAX_FINITE, the kernel can confirm that the communication from the corresponding remote rank is complete.
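The polling pattern described above can be illustrated with a small host-side sketch. This is not the kernel's code: it is plain Python using threads in place of remote ranks, and it assumes the sentinel is the largest unsigned 64-bit value. The names `SENTINEL`, `wait_for_counts`, and `fake_remote_rank` are hypothetical and exist only for this illustration.

```python
import threading
import time

# Assumption: the "not yet written" sentinel is the largest UInt64 value.
SENTINEL = 2**64 - 1


def wait_for_counts(recv_counts, poll_interval=0.001):
    """Poll until every slot holds a real count instead of the sentinel."""
    ready = [None] * len(recv_counts)
    pending = set(range(len(recv_counts)))
    while pending:
        for slot in list(pending):
            value = recv_counts[slot]
            if value != SENTINEL:
                # A non-sentinel value means this sender has finished.
                ready[slot] = value
                pending.discard(slot)
        if pending:
            time.sleep(poll_interval)
    return ready


def fake_remote_rank(recv_counts, slot, count, delay):
    """Hypothetical stand-in for a remote rank writing its token count."""
    time.sleep(delay)
    recv_counts[slot] = count


recv_counts = [SENTINEL] * 4
senders = [
    threading.Thread(target=fake_remote_rank, args=(recv_counts, i, 10 + i, 0.01 * i))
    for i in range(4)
]
for t in senders:
    t.start()
counts = wait_for_counts(recv_counts)
for t in senders:
    t.join()
```

The count slot doubles as a completion flag: the receiver needs no separate signal, because any value other than the sentinel is itself the token count.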

The kernel also aggregates the tokens from all experts and stores them in the output tensor using a ragged representation.
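As a rough illustration of what a ragged representation looks like, the sketch below packs per-expert token lists into one flat buffer plus an offsets array. It is plain Python, not the kernel's implementation. It assumes a conventional layout in which the offsets array has one more entry than there are experts and expert `e` owns rows `row_offsets[e]` up to (not including) `row_offsets[e + 1]`. Whether the `row_offsets` argument of this kernel follows exactly that convention is an assumption, and the helper names are hypothetical.

```python
from itertools import accumulate


def build_row_offsets(tokens_per_expert):
    """Exclusive prefix sum: offsets[e] is the first row owned by expert e."""
    return [0] + list(accumulate(tokens_per_expert))


def pack_ragged(tokens_by_expert):
    """Flatten per-expert token lists into one buffer plus row offsets."""
    counts = [len(tokens) for tokens in tokens_by_expert]
    row_offsets = build_row_offsets(counts)
    packed = [tok for tokens in tokens_by_expert for tok in tokens]
    return packed, row_offsets


def expert_rows(packed, row_offsets, expert):
    """Return the slice of the packed buffer that belongs to one expert."""
    return packed[row_offsets[expert]:row_offsets[expert + 1]]


# Three experts receiving 2, 0, and 3 tokens respectively.
tokens_by_expert = [["a0", "a1"], [], ["c0", "c1", "c2"]]
packed, row_offsets = pack_ragged(tokens_by_expert)
```

With this layout an expert that received no tokens costs no storage in the packed buffer: its two neighboring offsets are simply equal.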

Parameters:

  • num_threads (Int): The number of threads in the block.
  • row_offsets_layout (TensorLayout): The layout of the row offsets.
  • expert_ids_layout (TensorLayout): The layout of the expert IDs.
  • src_info_layout (TensorLayout): The layout of the source token info.
  • n_sms (Int): The total number of SMs on the device.
  • n_experts (Int): The number of experts on the device.
  • n_ranks (Int): The number of ranks.
  • max_tokens_per_rank (Int): The maximum number of tokens per rank.
  • token_fmt_type (TokenFormat): A type conforming to the TokenFormat trait that defines the token encoding scheme.
  • input_scales_wrapper (Optional[def[dtype: DType](Int) capturing -> Scalar[dtype]]): The wrapper for the input scales.

Args: