
Mojo struct

EPLocalSyncCounters

struct EPLocalSyncCounters[n_experts: Int]

Manages atomic counters for EP kernel synchronization within a device.

This struct provides dedicated atomic counter space for each of the four EP kernels: dispatch_async, dispatch_wait, combine_async, and combine_wait. Each kernel has its own memory region to avoid conflicts, with one exception: dispatch_wait and combine_async share a region, because combine_async reads data that dispatch_wait writes.

The counters synchronize thread blocks running on the same device.

Memory Layout (all sizes in Int32 elements):

  • dispatch_async: 2 * n_experts + MAX_GPUS_PER_NODE
  • dispatch_wait/combine_async: 4 * n_experts + 4
  • combine_wait: 2 * n_experts
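
The layout above reduces to simple arithmetic. The following Python sketch models it for illustration only; it is not the Mojo API. The value `MAX_GPUS_PER_NODE = 8` is an assumed placeholder, and computing the total as the sum of the three regions (with the shared region counted once) is an assumption, since this page does not spell out the formula behind total_size().

```python
# Illustrative model of the counter sizes, in Int32 elements.
MAX_GPUS_PER_NODE = 8  # assumption: placeholder; the real constant comes from the library


def dispatch_async_size(n_experts: int) -> int:
    return 2 * n_experts + MAX_GPUS_PER_NODE


def dispatch_wait_size(n_experts: int) -> int:
    return 4 * n_experts + 4


def combine_async_size(n_experts: int) -> int:
    # combine_async reuses the dispatch_wait region, so the sizes must match.
    return dispatch_wait_size(n_experts)


def combine_wait_size(n_experts: int) -> int:
    return 2 * n_experts


def total_size(n_experts: int) -> int:
    # Assumption: the region shared by dispatch_wait and combine_async is counted once.
    return (
        dispatch_async_size(n_experts)
        + dispatch_wait_size(n_experts)
        + combine_wait_size(n_experts)
    )
```

Under these assumptions, n_experts = 64 gives regions of 136, 260, and 128 elements, or 524 Int32 elements in total.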

Fields​

  • ​ptr (UnsafePointer[Int32, MutExternalOrigin]): Base pointer to the allocated atomic counter memory.

Implemented traits​

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members​

device_type​

comptime device_type = EPLocalSyncCounters[n_experts]

Methods​

__init__​

__init__(ptr: UnsafePointer[Int32, address_space=ptr.address_space]) -> Self

__init__(buffer: DeviceBuffer[DType.int32]) -> Self

get_type_name​

static get_type_name() -> String

Returns:

String

dispatch_async_size​

static dispatch_async_size() -> Int

Returns the size in Int32 elements needed by the dispatch_async kernel.

Returns:

Int

dispatch_wait_size​

static dispatch_wait_size() -> Int

Returns the size in Int32 elements needed by the dispatch_wait kernel.

Layout (see EPDispatchKernel for exact offset constants):

  • Region A [0, 2 * n_experts): per expert-rank combine_async compat data
  • Region B [2 * n_experts, 3 * n_experts): within-expert rank prefix sums
  • Region C [3 * n_experts, 4 * n_experts): per-expert work counters (only first n_local_experts entries used; rest unused)
  • Region D [4 * n_experts]: cleanup ref counter
  • Region E [4 * n_experts + 1]: global ready flag
  • Region F [4 * n_experts + 2]: send_buf_ready counter
  • Region G [4 * n_experts + 3]: shared_expert_started counter

Region A is used by the combine_async kernel to track the number of tokens for each expert-rank pair. Regions D, E, F, and G must be reset to 0 once the dispatch_wait kernel is done.

Returns:

Int
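
To make the region boundaries of the dispatch_wait block concrete, the Python sketch below maps each region letter to its index range and zeroes regions D through G. Both function names are hypothetical, and a plain list stands in for device memory; the real kernels operate on atomic Int32 counters on the device.

```python
def dispatch_wait_regions(n_experts: int) -> dict:
    """Map each region letter to its index range within the dispatch_wait block."""
    n = n_experts
    return {
        "A": range(0, 2 * n),              # per expert-rank data, later read by combine_async
        "B": range(2 * n, 3 * n),          # within-expert rank prefix sums
        "C": range(3 * n, 4 * n),          # per-expert work counters
        "D": range(4 * n, 4 * n + 1),      # cleanup ref counter
        "E": range(4 * n + 1, 4 * n + 2),  # global ready flag
        "F": range(4 * n + 2, 4 * n + 3),  # send_buf_ready counter
        "G": range(4 * n + 3, 4 * n + 4),  # shared_expert_started counter
    }


def reset_after_dispatch_wait(counters: list, n_experts: int) -> None:
    """Zero regions D, E, F, and G in place (hypothetical host-side stand-in)."""
    regions = dispatch_wait_regions(n_experts)
    for name in "DEFG":
        for i in regions[name]:
            counters[i] = 0
```

The last index of region G is 4 * n_experts + 3, so the block spans 4 * n_experts + 4 elements, which agrees with the size given in the memory layout.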

combine_async_size​

static combine_async_size() -> Int

Returns the size in Int32 elements needed by the combine_async kernel.

Must match dispatch_wait_size() since combine_async reuses the same memory region.

Returns:

Int

combine_wait_size​

static combine_wait_size() -> Int

Returns the size in Int32 elements needed by the combine_wait kernel.

Returns:

Int

total_size​

static total_size() -> Int

Returns the total size in Int32 elements needed for all counters.

Returns:

Int

get_dispatch_async_ptr​

get_dispatch_async_ptr(self) -> UnsafePointer[Int32, MutExternalOrigin]

Returns a pointer to the dispatch_async kernel's atomic counters.

Layout:

  • [0, n_experts): reserved counters per expert
  • [n_experts, 2 * n_experts): finished counters per expert

Returns:

UnsafePointer[Int32, MutExternalOrigin]
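
As a small illustration of the dispatch_async layout, the Python sketch below computes the index of each expert's two counters relative to the returned pointer. The helper names are hypothetical. The remaining MAX_GPUS_PER_NODE elements of the region are not described on this page, so they are left unmodeled.

```python
def reserved_counter_index(expert: int, n_experts: int) -> int:
    """Index of the reserved counter for `expert`, in [0, n_experts)."""
    if not 0 <= expert < n_experts:
        raise ValueError("expert out of range")
    return expert


def finished_counter_index(expert: int, n_experts: int) -> int:
    """Index of the finished counter for `expert`, in [n_experts, 2 * n_experts)."""
    if not 0 <= expert < n_experts:
        raise ValueError("expert out of range")
    return n_experts + expert
```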

get_dispatch_wait_ptr​

get_dispatch_wait_ptr(self) -> UnsafePointer[Int32, MutExternalOrigin]

Returns a pointer to the dispatch_wait kernel's atomic counters.

Returns:

UnsafePointer[Int32, MutExternalOrigin]

get_combine_async_ptr​

get_combine_async_ptr(self) -> UnsafePointer[Int32, MutExternalOrigin]

Returns a pointer to the combine_async kernel's atomic counters.

Note: Returns the same pointer as get_dispatch_wait_ptr() because combine_async_kernel reads the offset/count data that dispatch_wait_kernel writes.

Returns:

UnsafePointer[Int32, MutExternalOrigin]
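
The aliasing between get_dispatch_wait_ptr() and get_combine_async_ptr() can be pictured with the Python sketch below, which computes a base offset for each kernel's region. It assumes the regions are packed contiguously from the base pointer in the order given in the memory layout, which this page does not confirm. The function name and the default of 8 for `max_gpus_per_node` are likewise placeholders.

```python
def region_base_offsets(n_experts: int, max_gpus_per_node: int = 8) -> dict:
    """Base offset (in Int32 elements) of each kernel's region, assuming contiguous packing."""
    dispatch_async = 0
    dispatch_wait = dispatch_async + 2 * n_experts + max_gpus_per_node
    combine_async = dispatch_wait  # same region: combine_async reads what dispatch_wait writes
    combine_wait = dispatch_wait + 4 * n_experts + 4
    return {
        "dispatch_async": dispatch_async,
        "dispatch_wait": dispatch_wait,
        "combine_async": combine_async,
        "combine_wait": combine_wait,
    }
```

With n_experts = 64 and the placeholder default, this model puts dispatch_wait and combine_async at offset 136 and combine_wait at offset 396.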

get_combine_wait_ptr​

get_combine_wait_ptr(self) -> UnsafePointer[Int32, MutExternalOrigin]

Returns a pointer to the combine_wait kernel's atomic counters.

Returns:

UnsafePointer[Int32, MutExternalOrigin]