Mojo function
combine_async_kernel
combine_async_kernel[input_type: DType, num_threads: Int, input_tokens_layout: TensorLayout, src_info_layout: TensorLayout, n_sms: Int, top_k: Int, n_experts: Int, n_ranks: Int, msg_bytes: Int, max_tokens_per_rank: Int, p2p_world_size: Int, use_shmem: Bool = True](input_tokens: TileTensor[input_type, input_tokens_layout, ImmutExternalOrigin], src_info: TileTensor[DType.int32, src_info_layout, ImmutExternalOrigin], send_buf_p: UnsafePointer[UInt8, MutExternalOrigin], recv_buf_ptrs: InlineArray[UnsafePointer[UInt8, MutExternalOrigin], p2p_world_size], recv_count_ptrs: InlineArray[UnsafePointer[UInt64, MutExternalOrigin], p2p_world_size], ep_counters: EPLocalSyncCounters[n_experts], my_rank: Int32)
Sends tokens back to their original rank based on the src_info tensor. This kernel uses the non-blocking SHMEM API and returns immediately after initiating the communication. The communication is considered complete only after combine_wait_kernel has been called.
Parameters:
- input_type (DType): The type of the input tokens.
- num_threads (Int): The number of threads in the block.
- input_tokens_layout (TensorLayout): The layout of the input tokens.
- src_info_layout (TensorLayout): The layout of the source token info.
- n_sms (Int): The total number of SMs in the device.
- top_k (Int): The number of selected experts per token.
- n_experts (Int): The total number of experts in the model.
- n_ranks (Int): The number of all devices participating in the communication.
- msg_bytes (Int): The total number of bytes to send for each token.
- max_tokens_per_rank (Int): The maximum number of tokens per rank.
- p2p_world_size (Int): The size of a high-speed GPU interconnect group.
- use_shmem (Bool): Whether to use the SHMEM API for the communication.
Args:
- input_tokens (TileTensor[input_type, input_tokens_layout, ImmutExternalOrigin]): The tokens to be sent back to the original rank.
- src_info (TileTensor[DType.int32, src_info_layout, ImmutExternalOrigin]): The source token info tensor of shape (n_local_experts * n_ranks * max_tokens_per_rank, 2). The first column stores a token's position in the original rank's tensor, and the second column stores the top-k ID for the token.
- send_buf_p (UnsafePointer[UInt8, MutExternalOrigin]): The pointer to the send buffer. Must be allocated using shmem_alloc if use_shmem is True. The underlying buffer is of shape (n_local_experts * n_ranks * max_tokens_per_rank, msg_bytes).
- recv_buf_ptrs (InlineArray[UnsafePointer[UInt8, MutExternalOrigin], p2p_world_size]): An array of pointers to the receive buffers for each device in the p2p world. Each buffer is of shape (max_tokens_per_rank, top_k, msg_bytes). Must be allocated using shmem_alloc if use_shmem is True.
- recv_count_ptrs (InlineArray[UnsafePointer[UInt64, MutExternalOrigin], p2p_world_size]): An array of pointers to the receive count buffers for each device in the p2p world. Each buffer is of shape (n_local_experts, n_ranks).
- ep_counters (EPLocalSyncCounters[n_experts]): EP atomic counters for kernel synchronization.
- my_rank (Int32): The rank of the current device.
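The buffer shapes listed under Args translate directly into byte sizes and offsets. The sketch below is a plain-Python, host-side model of that arithmetic, not Mojo kernel code and not part of the library. It rests on two assumptions that this page does not state: that n_local_experts equals n_experts // n_ranks, and that the receive buffer is laid out contiguously in row-major order. The CombineConfig class and all function names are hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CombineConfig:
    """Hypothetical host-side mirror of the kernel's compile-time parameters."""

    n_experts: int
    n_ranks: int
    top_k: int
    msg_bytes: int
    max_tokens_per_rank: int

    @property
    def n_local_experts(self) -> int:
        # Assumption: experts are split evenly across ranks.
        return self.n_experts // self.n_ranks


def send_buf_bytes(cfg: CombineConfig) -> int:
    # UInt8 buffer of shape
    # (n_local_experts * n_ranks * max_tokens_per_rank, msg_bytes).
    rows = cfg.n_local_experts * cfg.n_ranks * cfg.max_tokens_per_rank
    return rows * cfg.msg_bytes


def recv_buf_bytes(cfg: CombineConfig) -> int:
    # UInt8 buffer of shape (max_tokens_per_rank, top_k, msg_bytes).
    return cfg.max_tokens_per_rank * cfg.top_k * cfg.msg_bytes


def recv_count_bytes(cfg: CombineConfig) -> int:
    # UInt64 buffer of shape (n_local_experts, n_ranks); 8 bytes per entry.
    return cfg.n_local_experts * cfg.n_ranks * 8


def recv_slot_offset(cfg: CombineConfig, token_pos: int, top_k_id: int) -> int:
    """Byte offset of one token's slot in a receive buffer.

    token_pos and top_k_id correspond to the two columns of a src_info row.
    Assumption: the (max_tokens_per_rank, top_k, msg_bytes) buffer is
    contiguous and row-major.
    """
    if not 0 <= token_pos < cfg.max_tokens_per_rank:
        raise IndexError("token_pos out of range")
    if not 0 <= top_k_id < cfg.top_k:
        raise IndexError("top_k_id out of range")
    return (token_pos * cfg.top_k + top_k_id) * cfg.msg_bytes


if __name__ == "__main__":
    cfg = CombineConfig(
        n_experts=8, n_ranks=4, top_k=2, msg_bytes=16, max_tokens_per_rank=4
    )
    print(send_buf_bytes(cfg), recv_buf_bytes(cfg), recv_count_bytes(cfg))
    print(recv_slot_offset(cfg, token_pos=3, top_k_id=1))
```

A helper like this can be useful for sizing allocations before launching the kernel, but the actual layout used on the device should be confirmed against the library source.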