Mojo struct
EPLocalSyncCounters
struct EPLocalSyncCounters[n_experts: Int]
Manages atomic counters used to synchronize expert-parallel (EP) kernels within a device.
This struct provides dedicated atomic counter space for each of the four EP kernels: dispatch_async, dispatch_wait, combine_async, and combine_wait. Each kernel has its own memory region to avoid conflicts. The exception is dispatch_wait and combine_async, which must share memory because combine_async reads data that dispatch_wait writes.
The struct is used to synchronize between thread blocks within the same device.
Memory Layout (all sizes in Int32 elements):
- dispatch_async: 2 * n_experts + MAX_GPUS_PER_NODE
- dispatch_wait/combine_async: 4 * n_experts + 4
- combine_wait: 2 * n_experts
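As a sanity check on the sizes listed above, the following sketch recomputes each region's size in Int32 elements. It is written in Python purely to show the arithmetic and is not the Mojo implementation. `max_gpus_per_node` is a parameter because this page does not state the value of `MAX_GPUS_PER_NODE`; the example value of 8 is an assumption. The total counts the shared dispatch_wait/combine_async region once, on the assumption that `total_size()` simply adds the distinct regions.

```python
def region_sizes(n_experts: int, max_gpus_per_node: int) -> dict[str, int]:
    """Per-region sizes in Int32 elements, following the layout above."""
    return {
        "dispatch_async": 2 * n_experts + max_gpus_per_node,
        # dispatch_wait and combine_async share this one region.
        "dispatch_wait/combine_async": 4 * n_experts + 4,
        "combine_wait": 2 * n_experts,
    }


def total_elements(n_experts: int, max_gpus_per_node: int) -> int:
    # The shared region is a single dict entry, so it is counted only once.
    return sum(region_sizes(n_experts, max_gpus_per_node).values())


# Example: 256 experts, assuming MAX_GPUS_PER_NODE == 8 (hypothetical value).
print(region_sizes(256, 8))
# → {'dispatch_async': 520, 'dispatch_wait/combine_async': 1028, 'combine_wait': 512}
print(total_elements(256, 8))  # → 2060 Int32 elements (8240 bytes)
```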
Fields
- ptr (UnsafePointer[Int32, MutExternalOrigin]): Base pointer to the allocated atomic counter memory.
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
device_type
comptime device_type = EPLocalSyncCounters[n_experts]
Methods
__init__
__init__(ptr: UnsafePointer[Int32, address_space=ptr.address_space]) -> Self
__init__(buffer: DeviceBuffer[DType.int32]) -> Self
get_type_name
dispatch_async_size
static dispatch_async_size() -> Int
Returns the size in Int32 elements needed by the dispatch_async kernel.
Returns:
dispatch_wait_size
static dispatch_wait_size() -> Int
Returns the size in Int32 elements needed by the dispatch_wait kernel.
Layout (see EPDispatchKernel for exact offset constants):
- Region A [0, 2*n_experts): per expert-rank combine_async compat data
- Region B [2*n_experts, 3*n_experts): within-expert rank prefix sums
- Region C [3*n_experts, 4*n_experts): per-expert work counters (only the first n_local_experts entries are used; the rest are unused)
- Region D [4*n_experts]: cleanup ref counter
- Region E [4*n_experts + 1]: global ready flag
- Region F [4*n_experts + 2]: send_buf_ready counter
- Region G [4*n_experts + 3]: shared_expert_started counter
The combine_async kernel uses Region A to track the number of tokens for each expert-rank pair. Regions D, E, F, and G must be reset to 0 once the dispatch_wait kernel is done.
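The region boundaries above can be written out as half-open element ranges, together with the reset of the four single-element counters. This is a minimal Python sketch that uses a plain list as a stand-in for device memory. The real kernels operate atomically on GPU memory, and this page does not say which kernel performs the reset. The helper names `dispatch_wait_offsets` and `reset_scalar_counters` are hypothetical.

```python
def dispatch_wait_offsets(n_experts: int) -> dict[str, tuple[int, int]]:
    """Half-open [start, end) element ranges for regions A-G (hypothetical helper)."""
    n = n_experts
    return {
        "A": (0, 2 * n),              # per expert-rank combine_async data
        "B": (2 * n, 3 * n),          # within-expert rank prefix sums
        "C": (3 * n, 4 * n),          # per-expert work counters
        "D": (4 * n, 4 * n + 1),      # cleanup ref counter
        "E": (4 * n + 1, 4 * n + 2),  # global ready flag
        "F": (4 * n + 2, 4 * n + 3),  # send_buf_ready counter
        "G": (4 * n + 3, 4 * n + 4),  # shared_expert_started counter
    }


def reset_scalar_counters(region: list[int], n_experts: int) -> None:
    """Zero regions D-G in place; region A (read later by combine_async) is untouched."""
    start = 4 * n_experts
    region[start:start + 4] = [0, 0, 0, 0]


n = 4
buf = [7] * (4 * n + 4)  # pretend dispatch_wait left non-zero values everywhere
reset_scalar_counters(buf, n)
print(buf[4 * n:])  # → [0, 0, 0, 0]
print(buf[:2 * n])  # → [7, 7, 7, 7, 7, 7, 7, 7]
```

The ranges tile the region without gaps, so the last counter (G) ends exactly at `dispatch_wait_size()` = 4 * n_experts + 4.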
Returns:
combine_async_size
static combine_async_size() -> Int
Returns the size in Int32 elements needed by the combine_async kernel.
Must match dispatch_wait_size() since combine_async reuses the same memory region.
Returns:
combine_wait_size
static combine_wait_size() -> Int
Returns the size in Int32 elements needed by the combine_wait kernel.
Returns:
total_size
static total_size() -> Int
Returns the total size in Int32 elements needed for all counters.
Returns:
get_dispatch_async_ptr
get_dispatch_async_ptr(self) -> UnsafePointer[Int32, MutExternalOrigin]
Returns a pointer to the dispatch_async kernel's atomic counters.
Layout:
- [0, n_experts): reserved counters per expert
- [n_experts, 2*n_experts): finished counters per expert
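The two halves of this layout map an expert index to a counter index. The Python sketch below shows that index arithmetic. It pairs the indexing with a single-threaded stand-in for an atomic fetch-add, on the assumption that the counters are updated with fetch-add style operations; this page does not spell out the exact update semantics. `reserved_index`, `finished_index`, and `fetch_add` are hypothetical names used only for illustration.

```python
def reserved_index(expert: int, n_experts: int) -> int:
    """Element index of the reserved counter for `expert` (hypothetical helper)."""
    assert 0 <= expert < n_experts
    return expert


def finished_index(expert: int, n_experts: int) -> int:
    """Element index of the finished counter for `expert` (hypothetical helper)."""
    assert 0 <= expert < n_experts
    return n_experts + expert


def fetch_add(counters: list[int], index: int, value: int) -> int:
    """Single-threaded stand-in for an atomic fetch-add: returns the old value."""
    old = counters[index]
    counters[index] = old + value
    return old


n = 3
counters = [0] * (2 * n)
print(fetch_add(counters, reserved_index(1, n), 5))  # → 0
print(fetch_add(counters, reserved_index(1, n), 2))  # → 5
fetch_add(counters, finished_index(1, n), 1)
print(counters)  # → [0, 7, 0, 0, 1, 0]
```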
Returns:
get_dispatch_wait_ptr
get_dispatch_wait_ptr(self) -> UnsafePointer[Int32, MutExternalOrigin]
Returns a pointer to the dispatch_wait kernel's atomic counters.
Returns:
get_combine_async_ptr
get_combine_async_ptr(self) -> UnsafePointer[Int32, MutExternalOrigin]
Returns a pointer to the combine_async kernel's atomic counters.
Note: Returns the same pointer as get_dispatch_wait_ptr() because combine_async_kernel reads the offset/count data that dispatch_wait_kernel writes.
Returns:
get_combine_wait_ptr
get_combine_wait_ptr(self) -> UnsafePointer[Int32, MutExternalOrigin]
Returns a pointer to the combine_wait kernel's atomic counters.
Returns:
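Each accessor above returns the base pointer `ptr` advanced by some element offset. The Python sketch below computes plausible offsets under a loud assumption: the regions are packed back to back in the order listed in the memory layout (dispatch_async, then dispatch_wait/combine_async, then combine_wait). This page does not specify the actual ordering. `counter_offsets` is a hypothetical helper, and `max_gpus_per_node` stands in for `MAX_GPUS_PER_NODE`, whose value is not given here.

```python
def counter_offsets(n_experts: int, max_gpus_per_node: int) -> dict[str, int]:
    """Element offsets from the base pointer (hypothetical helper).

    ASSUMES back-to-back packing in the order dispatch_async,
    dispatch_wait/combine_async, combine_wait.
    """
    dispatch_async = 0
    dispatch_wait = dispatch_async + 2 * n_experts + max_gpus_per_node
    combine_wait = dispatch_wait + 4 * n_experts + 4
    return {
        "dispatch_async": dispatch_async,
        "dispatch_wait": dispatch_wait,
        "combine_async": dispatch_wait,  # aliases dispatch_wait by design
        "combine_wait": combine_wait,
        "end": combine_wait + 2 * n_experts,  # one past the last counter
    }


# Example: 256 experts, assuming MAX_GPUS_PER_NODE == 8 (hypothetical value).
print(counter_offsets(256, 8))
# → {'dispatch_async': 0, 'dispatch_wait': 520, 'combine_async': 520,
#    'combine_wait': 1548, 'end': 2060}
```

Whatever the real ordering is, two properties follow from this page: the combine_async and dispatch_wait offsets must be equal, and the buffer passed to `__init__` must hold at least `total_size()` Int32 elements.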