Mojo function
dispatch_wait_kernel
dispatch_wait_kernel[num_threads: Int, row_offsets_layout: TensorLayout, expert_ids_layout: TensorLayout, src_info_layout: TensorLayout, n_sms: Int, n_experts: Int, n_ranks: Int, max_tokens_per_rank: Int, token_fmt_type: TokenFormat, input_scales_wrapper: Optional[def[dtype: DType](Int) capturing -> Scalar[dtype]] = None](format_handler: token_fmt_type, row_offsets: TileTensor[DType.uint32, row_offsets_layout, MutExternalOrigin], expert_ids: TileTensor[DType.int32, expert_ids_layout, MutExternalOrigin], src_info: TileTensor[DType.int32, src_info_layout, MutExternalOrigin], recv_buf_p: UnsafePointer[UInt8, MutExternalOrigin], recv_count_p: UnsafePointer[UInt64, MutExternalOrigin], ep_counters: EPLocalSyncCounters[n_experts], my_rank: Int32)
This kernel runs after dispatch_kernel to complete the communication. It polls the receive count buffer; once a count is no longer MAX_FINITE, the kernel knows that the transfer from the corresponding remote rank is complete.
The kernel also aggregates the tokens from all experts and stores them in the output tensor using a ragged representation.
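The polling and aggregation described above can be modeled on the host with a short Python sketch. This is an illustration only, not the kernel's implementation: SENTINEL stands in for MAX_FINITE (assumed here to be the largest UInt64 value), the poll hook and the pending arrival list are hypothetical, and building the ragged row offsets as a prefix sum over per-expert token totals is an assumed convention.

```python
# Host-side model only; the real kernel runs on the GPU.
SENTINEL = 2**64 - 1  # assumed stand-in for MAX_FINITE of a UInt64 slot


def wait_for_count(recv_count, expert, rank, poll):
    """Spin until the (expert, rank) slot no longer holds the sentinel."""
    while recv_count[expert][rank] == SENTINEL:
        poll()  # hypothetical hook; a GPU kernel would simply re-read memory
    return recv_count[expert][rank]


def build_row_offsets(counts):
    """Prefix-sum per-expert token totals into ragged row offsets (assumed convention)."""
    offsets = [0]
    for per_rank in counts:
        offsets.append(offsets[-1] + sum(per_rank))
    return offsets


# Two local experts, two ranks; every slot starts at the sentinel.
recv_count = [[SENTINEL, SENTINEL], [SENTINEL, SENTINEL]]
# Hypothetical arrivals as (expert, rank, token_count).
pending = [(0, 0, 3), (0, 1, 2), (1, 0, 0), (1, 1, 4)]


def poll():
    """Simulate one remote rank publishing its token count."""
    expert, rank, n = pending.pop(0)
    recv_count[expert][rank] = n


counts = [[wait_for_count(recv_count, e, r, poll) for r in range(2)] for e in range(2)]
row_offsets = build_row_offsets(counts)
```

In this model, expert e owns rows row_offsets[e] up to row_offsets[e + 1] of the aggregated output, so an expert that received no tokens contributes an empty range.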
Parameters:
- num_threads (Int): The number of threads in the block.
- row_offsets_layout (TensorLayout): The layout of the row offsets.
- expert_ids_layout (TensorLayout): The layout of the expert IDs.
- src_info_layout (TensorLayout): The layout of the source token info.
- n_sms (Int): The total number of SMs in the device.
- n_experts (Int): The number of experts in the device.
- n_ranks (Int): The number of ranks.
- max_tokens_per_rank (Int): The maximum number of tokens per rank.
- token_fmt_type (TokenFormat): Type conforming to the TokenFormat trait that defines the token encoding scheme.
- input_scales_wrapper (Optional[def[dtype: DType](Int) capturing -> Scalar[dtype]]): The wrapper for the input scales.
Args:
- format_handler (token_fmt_type): Instance of token_fmt_type that performs token decoding and manages output tensor writes.
- row_offsets (TileTensor[DType.uint32, row_offsets_layout, MutExternalOrigin]): The row offsets to be updated. Will be consumed by the grouped_matmul kernel.
- expert_ids (TileTensor[DType.int32, expert_ids_layout, MutExternalOrigin]): The expert IDs to be updated. Will be consumed by the grouped_matmul kernel.
- src_info (TileTensor[DType.int32, src_info_layout, MutExternalOrigin]): The source token info to be updated. Once the expert computation is complete, tokens will be sent back to the original rank using the information in this tensor.
- recv_buf_p (UnsafePointer[UInt8, MutExternalOrigin]): The pointer to the receive buffer. Needs to be allocated using shmem_alloc if use_shmem is True. The underlying buffer is of shape (n_local_experts, n_ranks, max_tokens_per_rank, msg_bytes).
- recv_count_p (UnsafePointer[UInt64, MutExternalOrigin]): The pointer to the receive count buffer. Needs to be allocated using shmem_alloc if use_shmem is True. The underlying buffer is of shape (n_local_experts, n_ranks).
- ep_counters (EPLocalSyncCounters[n_experts]): EP atomic counters for kernel synchronization.
- my_rank (Int32): The rank of the current device.
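
The buffer shapes given for recv_buf_p and recv_count_p can be made concrete with a small indexing sketch in Python. It assumes both buffers are contiguous and row-major, which this page does not state. The helper names recv_slot_offset and recv_count_index are hypothetical and exist only for illustration.

```python
def recv_slot_offset(expert, rank, token, n_ranks, max_tokens_per_rank, msg_bytes):
    """Byte offset of one token message in a buffer of shape
    (n_local_experts, n_ranks, max_tokens_per_rank, msg_bytes),
    assuming a contiguous row-major layout."""
    return ((expert * n_ranks + rank) * max_tokens_per_rank + token) * msg_bytes


def recv_count_index(expert, rank, n_ranks):
    """Element index into a count buffer of shape (n_local_experts, n_ranks),
    under the same row-major assumption."""
    return expert * n_ranks + rank


# Example sizes chosen for illustration only.
offset = recv_slot_offset(expert=1, rank=2, token=5,
                          n_ranks=4, max_tokens_per_rank=8, msg_bytes=16)
count_idx = recv_count_index(expert=1, rank=2, n_ranks=4)
```

Under this layout, each (expert, rank) pair owns a fixed region of max_tokens_per_rank * msg_bytes bytes, so a sender can write its tokens without coordinating offsets with other ranks.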