
Mojo struct

Struct_ep_init

struct Struct_ep_init

Implemented traits

AnyType, ImplicitlyDestructible

Methods

execute

static def execute[dispatch_dtype: DType, combine_dtype: DType, hidden_size: Int, top_k: Int, n_experts: Int, max_token_per_rank: Int, n_gpus_per_node: Int, n_nodes: Int, dispatch_scale_dtype: DType, dispatch_fmt_str: StringSlice[StaticConstantOrigin], //, target: StringSlice[StaticConstantOrigin]](dev_ptrs: ManagedTensorSlice[Output, static_spec=dev_ptrs.static_spec], my_rank_tensor: ManagedTensorSlice[Output, static_spec=my_rank_tensor.static_spec], atomic_counters_0: ManagedTensorSlice[MutableInput, static_spec=atomic_counters_0.static_spec], atomic_counters_1: ManagedTensorSlice[MutableInput, static_spec=atomic_counters_1.static_spec], context: DeviceContext)

This kernel initializes the vendor library for Expert Parallelism on the current GPU device. It also allocates symmetric memory buffers.

Arguments:

  • dev_ptrs: Output tensor to store device pointers. Shape [2, 3], where the first dimension selects the buffer group (0=dispatch, 1=combine) and the second dimension selects the buffer type (0=send, 1=recv, 2=recv_count).
  • my_rank_tensor: Output tensor to store the current device's rank.
  • atomic_counters_0: Atomic counters for buffer group 0.
  • atomic_counters_1: Atomic counters for buffer group 1.
  • context: GPU device context.
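The [2, 3] layout of dev_ptrs can be easier to follow with named indices. The sketch below is an illustration in Python, not part of the Mojo API: the names `BufferGroup`, `BufferType`, `flat_index`, and `describe` are hypothetical, and the row-major flattening is an assumption made only for this example.

```python
from enum import IntEnum


class BufferGroup(IntEnum):
    """First dimension of dev_ptrs (names are hypothetical)."""
    DISPATCH = 0
    COMBINE = 1


class BufferType(IntEnum):
    """Second dimension of dev_ptrs (names are hypothetical)."""
    SEND = 0
    RECV = 1
    RECV_COUNT = 2


# Shape documented for dev_ptrs: [2, 3].
DEV_PTRS_SHAPE = (len(BufferGroup), len(BufferType))


def flat_index(group: int, kind: int) -> int:
    """Offset of dev_ptrs[group][kind], ASSUMING a row-major flattening."""
    return int(group) * len(BufferType) + int(kind)


def describe(group: int, kind: int) -> str:
    """Human-readable label for one slot, e.g. 'dispatch/recv'."""
    return f"{BufferGroup(group).name.lower()}/{BufferType(kind).name.lower()}"


if __name__ == "__main__":
    # Print every slot with its label and assumed flat offset.
    for g in BufferGroup:
        for t in BufferType:
            print(f"dev_ptrs[{int(g)}][{int(t)}] = {describe(g, t)} (offset {flat_index(g, t)})")
```

Under these assumptions, dev_ptrs[1][2] would hold the recv_count pointer of the combine buffer group.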

Parameters:

  • dispatch_dtype (DType): DType used during token dispatch to experts.
  • combine_dtype (DType): DType used when combining expert outputs.
  • hidden_size (Int): Size of the model's hidden dimension.
  • top_k (Int): Number of experts each token is routed to.
  • n_experts (Int): Total number of experts across all GPUs.
  • max_token_per_rank (Int): Maximum number of tokens per GPU.
  • n_gpus_per_node (Int): Number of GPUs per node.
  • n_nodes (Int): Number of physical nodes.
  • dispatch_scale_dtype (DType): DType of the dispatch scale.
  • dispatch_fmt_str (StringSlice[StaticConstantOrigin]): String indicating the dispatch format.
  • target (StringSlice[StaticConstantOrigin]): Target for this kernel.
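A few quantities follow arithmetically from the integer parameters above. The Python sketch below is illustrative only: `EPConfig` and its properties are hypothetical names, the example values are made up, and the even split of experts across ranks is an assumption that this page does not state. It treats one GPU as one rank, which matches the description of max_token_per_rank.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EPConfig:
    """Hypothetical container mirroring the integer parameters of execute."""
    hidden_size: int
    top_k: int
    n_experts: int
    max_token_per_rank: int
    n_gpus_per_node: int
    n_nodes: int

    @property
    def n_ranks(self) -> int:
        # One rank per GPU, so the rank count is GPUs per node times nodes.
        return self.n_gpus_per_node * self.n_nodes

    @property
    def experts_per_rank(self) -> int:
        # ASSUMPTION: experts are split evenly across ranks.
        if self.n_experts % self.n_ranks != 0:
            raise ValueError("n_experts is not divisible by the number of ranks")
        return self.n_experts // self.n_ranks

    @property
    def max_routed_copies_per_rank(self) -> int:
        # Each token goes to top_k experts, so one rank sends at most
        # max_token_per_rank * top_k token-to-expert assignments.
        return self.max_token_per_rank * self.top_k


if __name__ == "__main__":
    # Made-up values chosen only to exercise the arithmetic.
    cfg = EPConfig(hidden_size=4096, top_k=8, n_experts=256,
                   max_token_per_rank=128, n_gpus_per_node=8, n_nodes=2)
    print(cfg.n_ranks, cfg.experts_per_rank, cfg.max_routed_copies_per_rank)
```

Checking divisibility up front is a design choice of this sketch; the actual kernel may impose different or additional constraints on these parameters.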