Mojo struct
Struct_ep_init
struct Struct_ep_init
Implemented traits
AnyType,
ImplicitlyDestructible
Methods
execute
static def execute[dispatch_dtype: DType, combine_dtype: DType, hidden_size: Int, top_k: Int, n_experts: Int, max_token_per_rank: Int, n_gpus_per_node: Int, n_nodes: Int, dispatch_scale_dtype: DType, dispatch_fmt_str: StringSlice[StaticConstantOrigin], //, target: StringSlice[StaticConstantOrigin]](dev_ptrs: ManagedTensorSlice[Output, static_spec=dev_ptrs.static_spec], my_rank_tensor: ManagedTensorSlice[Output, static_spec=my_rank_tensor.static_spec], atomic_counters_0: ManagedTensorSlice[MutableInput, static_spec=atomic_counters_0.static_spec], atomic_counters_1: ManagedTensorSlice[MutableInput, static_spec=atomic_counters_1.static_spec], context: DeviceContext)
This kernel initializes the vendor library for Expert Parallelism (EP) on the current GPU device and allocates the symmetric memory buffers used by the dispatch and combine buffer groups.
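Conceptually, "symmetric" memory means every participating GPU performs the same allocations, with the same sizes and in the same order, so a buffer can be located on any peer from a single shared offset. The toy Python model below illustrates only that idea. The `SymmetricHeap` class, the base addresses, the alignment, and the buffer sizes are all hypothetical and do not describe the vendor library's actual allocator.

```python
# Toy model of a symmetric heap. Hypothetical: this is NOT the vendor
# library's API, only an illustration of the "same offset on every rank" idea.

class SymmetricHeap:
    def __init__(self, base_address: int):
        self.base = base_address  # made-up per-rank base address
        self.top = 0              # bump-allocator cursor (offset from base)

    def alloc(self, nbytes: int, align: int = 256) -> int:
        # Round the cursor up to the alignment, then reserve nbytes.
        offset = (self.top + align - 1) // align * align
        self.top = offset + nbytes
        return offset


def remote_address(heaps, peer: int, offset: int) -> int:
    # A symmetric buffer on a peer is found by adding the shared
    # offset to that peer's heap base.
    return heaps[peer].base + offset


# Two ranks whose heaps start at different (made-up) base addresses.
heaps = [SymmetricHeap(0x7F0000000000), SymmetricHeap(0x7F8000000000)]

# Every rank issues the same allocations (e.g. send, recv, recv_count)
# in the same order, so the returned offsets match across ranks.
offsets = [[h.alloc(n) for n in (1000, 4096, 8)] for h in heaps]
assert offsets[0] == offsets[1]

recv_offset = offsets[0][1]
print(hex(remote_address(heaps, peer=1, offset=recv_offset)))
```

Because the offsets agree, rank 0 can address rank 1's copy of a buffer without exchanging per-buffer pointers, which is the property symmetric allocation is meant to provide.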
Arguments:
- dev_ptrs: Output tensor to store device pointers, with shape [2, 3], where:
  - First dimension: buffer group (0 = dispatch, 1 = combine).
  - Second dimension: buffer type (0 = send, 1 = recv, 2 = recv_count).
- my_rank_tensor: Output tensor to store the current device's rank.
- atomic_counters_0: Atomic counters for buffer group 0 (dispatch).
- atomic_counters_1: Atomic counters for buffer group 1 (combine).
- context: GPU device context.
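The [2, 3] layout of `dev_ptrs` can be read as a small lookup table keyed by (buffer group, buffer type). The Python sketch below shows one way a caller might index it. The constant names, the `lookup` helper, and the pointer values are hypothetical placeholders, not values the kernel actually produces.

```python
# Hypothetical index constants mirroring the documented dev_ptrs layout.
DISPATCH, COMBINE = 0, 1          # first dimension: buffer group
SEND, RECV, RECV_COUNT = 0, 1, 2  # second dimension: buffer type


def lookup(dev_ptrs, group: int, buf_type: int) -> int:
    """Return the device pointer stored at dev_ptrs[group][buf_type]."""
    if group not in (DISPATCH, COMBINE):
        raise ValueError(f"invalid buffer group {group}")
    if buf_type not in (SEND, RECV, RECV_COUNT):
        raise ValueError(f"invalid buffer type {buf_type}")
    return dev_ptrs[group][buf_type]


# Placeholder addresses standing in for what the kernel would write.
dev_ptrs = [
    [0x1000, 0x2000, 0x3000],  # dispatch: send, recv, recv_count
    [0x4000, 0x5000, 0x6000],  # combine:  send, recv, recv_count
]

print(hex(lookup(dev_ptrs, COMBINE, RECV)))  # combine group's receive buffer
```

Naming the indices keeps call sites readable and catches out-of-range groups or buffer types early, instead of silently reading the wrong pointer.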
Parameters:
- dispatch_dtype (DType): DType used during token dispatch to experts.
- combine_dtype (DType): DType used when combining expert outputs.
- hidden_size (Int): Size of the model's hidden dimension.
- top_k (Int): Number of experts each token is routed to.
- n_experts (Int): Total number of experts across all GPUs.
- max_token_per_rank (Int): Maximum number of tokens per GPU.
- n_gpus_per_node (Int): Number of GPUs per node.
- n_nodes (Int): Number of physical nodes.
- dispatch_scale_dtype (DType): DType of the dispatch scale.
- dispatch_fmt_str (StringSlice[StaticConstantOrigin]): String indicating the dispatch format.
- target (StringSlice[StaticConstantOrigin]): Target for this kernel.
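Several useful quantities follow directly from the compile-time parameters. The Python sketch below derives them so a configuration can be sanity-checked before launch. The `EPConfig` class and the example values are hypothetical, and two points are assumptions rather than documented behavior: that experts are split evenly across all GPUs, and that each (token, expert) pair carries at most one hidden-state vector.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EPConfig:
    # Hypothetical mirror of a subset of the kernel's parameters.
    hidden_size: int
    top_k: int
    n_experts: int
    max_token_per_rank: int
    n_gpus_per_node: int
    n_nodes: int

    @property
    def world_size(self) -> int:
        # Total number of GPUs (ranks) taking part in expert parallelism.
        return self.n_gpus_per_node * self.n_nodes

    @property
    def experts_per_rank(self) -> int:
        # Assumption: experts are split evenly across all GPUs.
        if self.n_experts % self.world_size != 0:
            raise ValueError("n_experts is not divisible by the number of GPUs")
        return self.n_experts // self.world_size

    @property
    def max_routed_pairs_per_rank(self) -> int:
        # Each local token is routed to top_k experts, so one rank produces
        # at most max_token_per_rank * top_k (token, expert) pairs.
        return self.max_token_per_rank * self.top_k

    @property
    def max_dispatch_elements_per_rank(self) -> int:
        # Assumption: each pair carries at most one hidden-state vector.
        return self.max_routed_pairs_per_rank * self.hidden_size


# Example values chosen only for illustration.
cfg = EPConfig(hidden_size=7168, top_k=8, n_experts=256,
               max_token_per_rank=128, n_gpus_per_node=8, n_nodes=2)
print(cfg.world_size, cfg.experts_per_rank, cfg.max_routed_pairs_per_rank)
```

Multiplying `max_dispatch_elements_per_rank` by the byte width of `dispatch_dtype` gives a rough upper bound on one rank's dispatch payload under these assumptions; the kernel's real buffer sizes may differ.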