Mojo function
combine_wait_kernel
combine_wait_kernel[output_type: DType, num_threads: Int, output_tokens_layout: TensorLayout, n_sms: Int, top_k: Int, n_experts: Int, n_ranks: Int, msg_bytes: Int, max_tokens_per_rank: Int, router_weights_wrapper: Optional[def[width: Int](token_idx: Int, topk_id: Int) capturing -> SIMD[DType.float32, width]] = None, elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](output_tokens: TileTensor[output_type, output_tokens_layout, MutExternalOrigin], recv_buf_p: UnsafePointer[UInt8, MutExternalOrigin], recv_count_p: UnsafePointer[UInt64, MutExternalOrigin], ep_counters: EPLocalSyncCounters[n_experts], my_rank: Int32)
This kernel is called after combine_kernel to complete the communication. It polls the receive count buffer; once a count is no longer MAX_FINITE, the kernel can confirm that the communication from a remote rank is complete.
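The polling behaviour described above can be modelled on the host with a short Python sketch. This is not the kernel's implementation: the flat `recv_count` list, the `remote_writer` thread, the `timeout` guard, and the value chosen for `MAX_FINITE` (the largest UInt64) are all assumptions made for illustration.

```python
import threading
import time

# Assumption: the sentinel is the largest UInt64 value, modelling MAX_FINITE.
MAX_FINITE = 2**64 - 1


def wait_for_counts(recv_count, poll_interval=0.0005, timeout=5.0):
    """Spin until no slot holds the sentinel, then return a snapshot."""
    deadline = time.monotonic() + timeout
    while any(c == MAX_FINITE for c in recv_count):
        # The timeout is an addition for this sketch so it cannot hang.
        if time.monotonic() > deadline:
            raise TimeoutError("a remote rank never signalled completion")
        time.sleep(poll_interval)
    return list(recv_count)


def demo():
    # Hypothetical flat count buffer: one slot per (local expert, rank) pair.
    n_local_experts, n_ranks = 2, 2
    recv_count = [MAX_FINITE] * (n_local_experts * n_ranks)

    def remote_writer():
        # Stands in for remote ranks overwriting the sentinel with a count.
        for slot in range(len(recv_count)):
            time.sleep(0.002)
            recv_count[slot] = slot + 1

    writer = threading.Thread(target=remote_writer)
    writer.start()
    counts = wait_for_counts(recv_count)
    writer.join()
    return counts
```

Calling `demo()` returns only after every slot has been overwritten, which mirrors how a sentinel value lets the waiting side distinguish "not yet arrived" from any real count.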
Parameters:
- output_type (DType): The type of the output tokens.
- num_threads (Int): The number of threads in the block.
- output_tokens_layout (TensorLayout): The layout of the output tokens.
- n_sms (Int): The total number of SMs in the device.
- top_k (Int): The number of selected experts per token.
- n_experts (Int): The number of experts in the device.
- n_ranks (Int): The number of ranks.
- msg_bytes (Int): The number of bytes in the message for each token.
- max_tokens_per_rank (Int): The maximum number of tokens per rank.
- router_weights_wrapper (Optional[def[width: Int](token_idx: Int, topk_id: Int) capturing -> SIMD[DType.float32, width]]): The wrapper for the router weights. If provided, all routed experts' outputs for a token will be weighted and summed.
- elementwise_lambda_fn (Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None]): Optional output lambda function.
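As a rough illustration of what "weighted and summed" means for router_weights_wrapper, the following Python sketch combines the top_k expert outputs of one token. The names `combine_token` and `router_weight` are hypothetical, and the sketch uses scalar weights rather than the SIMD vectors the real wrapper returns.

```python
def combine_token(expert_outputs, router_weight, token_idx):
    """Weighted sum over the top_k expert outputs received for one token.

    expert_outputs: list of top_k vectors, one per routed expert.
    router_weight:  hypothetical callable (token_idx, topk_id) -> float,
                    standing in for router_weights_wrapper.
    """
    hidden = len(expert_outputs[0])
    out = [0.0] * hidden
    for topk_id, vec in enumerate(expert_outputs):
        w = router_weight(token_idx, topk_id)
        for i, v in enumerate(vec):
            out[i] += w * v
    return out


# Example: one token routed to top_k = 2 experts.
weights = [[0.25, 0.75]]
outputs = [[1.0, 2.0], [3.0, 4.0]]
combined = combine_token(outputs, lambda t, k: weights[t][k], token_idx=0)
```

Here `combined` holds `0.25 * [1.0, 2.0] + 0.75 * [3.0, 4.0]`, one output vector per token.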
Args:
- output_tokens (TileTensor[output_type, output_tokens_layout, MutExternalOrigin]): The tensor to store the output tokens.
- recv_buf_p (UnsafePointer[UInt8, MutExternalOrigin]): The pointer to the receive buffer. Needs to be allocated using shmem_alloc. The underlying buffer is of shape (max_tokens_per_rank, top_k, msg_bytes).
- recv_count_p (UnsafePointer[UInt64, MutExternalOrigin]): The pointer to the receive count buffer. Needs to be allocated using shmem_alloc if use_shmem is True. The underlying buffer is of shape (n_local_experts, n_ranks).
- ep_counters (EPLocalSyncCounters[n_experts]): EP atomic counters for kernel synchronization.
- my_rank (Int32): The rank of the current device.
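To make the receive buffer shape (max_tokens_per_rank, top_k, msg_bytes) concrete, this Python sketch computes where one message would start in a flat byte buffer. It assumes a contiguous row-major layout, which this page does not state; `message_offset` and `message_view` are hypothetical helpers, not part of the Mojo API.

```python
def message_offset(token_idx, topk_id, top_k, msg_bytes):
    """Byte offset of one message, assuming a contiguous row-major layout."""
    return (token_idx * top_k + topk_id) * msg_bytes


def message_view(recv_buf, token_idx, topk_id, top_k, msg_bytes):
    """Zero-copy view of the msg_bytes belonging to (token_idx, topk_id)."""
    start = message_offset(token_idx, topk_id, top_k, msg_bytes)
    return memoryview(recv_buf)[start:start + msg_bytes]


# Example sizes chosen for illustration only.
max_tokens_per_rank, top_k, msg_bytes = 4, 2, 8
recv_buf = bytearray(max_tokens_per_rank * top_k * msg_bytes)

# Simulate a message arriving for the last token's second expert slot.
offset = message_offset(3, 1, top_k, msg_bytes)
recv_buf[offset:offset + msg_bytes] = bytes(range(8))
view = message_view(recv_buf, 3, 1, top_k, msg_bytes)
```

Under this assumption each (token, expert slot) pair owns a fixed msg_bytes region, so a message can be located without any extra index.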