IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

ReduceScatterConfig

struct ReduceScatterConfig[dtype: DType, ngpus: Int, simd_width: Int = simd_width_of[dtype, get_gpu_target()](), alignment: Int = align_of[SIMD[dtype, simd_width]](), accum_type: DType = get_accum_type[dtype]()]

Configuration for axis-aware reduce-scatter partitioning.

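Divides axis_size units as evenly as possible across the ngpus GPUs. When axis_size is not a multiple of ngpus, the lowest-numbered ranks each receive one extra unit until the remainder is used up. The 1D case is a special case in which axis_size = num_elements // simd_width and unit_numel = simd_width, so each unit is one SIMD vector.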
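The split rule can be checked with a few lines of arithmetic. The sketch below is a Python model of the behavior described above, not the Mojo implementation. `units_per_rank` is a hypothetical helper, and it assumes each GPU's base share is `axis_size // ngpus`, with the first `axis_size % ngpus` ranks taking one extra unit.

```python
def units_per_rank(axis_size: int, ngpus: int) -> list[int]:
    """Hypothetical model: split axis_size units across ngpus ranks,
    giving the lowest ranks one extra unit when the split is uneven."""
    base, remainder = divmod(axis_size, ngpus)
    return [base + (1 if rank < remainder else 0) for rank in range(ngpus)]


# 10 units over 4 GPUs: ranks 0 and 1 absorb the 2 leftover units.
print(units_per_rank(10, 4))  # [3, 3, 2, 2]

# 1D special case: 1024 elements with an assumed simd_width of 8 gives
# axis_size = 1024 // 8 = 128 units, each holding 8 elements.
simd_width = 8
print(units_per_rank(1024 // simd_width, 4))  # [32, 32, 32, 32]
```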

Fields

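  • stride (Int):
  • axis_part (Int): Base number of units assigned to every rank.
  • axis_remainder (Int): Number of units left over after the even split; each of the lowest ranks receives one of them.
  • unit_numel (Int): Number of elements per unit.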

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

Methods

__init__

def __init__(axis_size: Int, unit_numel: Int, threads_per_gpu: Int) -> Self

General constructor for axis-aware partitioning.

Args:

  • axis_size (Int): Number of units along the scatter axis.
  • unit_numel (Int): Number of elements per unit.
  • threads_per_gpu (Int): Total threads per GPU.
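A rough Python model of what the general constructor might derive from its arguments. `PartitionModel` and `make_config` are hypothetical names, and the assumption that `axis_part` and `axis_remainder` come from dividing `axis_size` by `ngpus` is inferred from the partitioning description rather than taken from the source. In the real struct `ngpus` is a compile-time parameter, so it is passed explicitly here; `threads_per_gpu` and the `stride` field are left out because their exact role is not documented on this page.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PartitionModel:
    """Hypothetical Python stand-in for the partition fields of
    ReduceScatterConfig. The stride field is omitted."""
    axis_part: int       # assumed: units every rank receives
    axis_remainder: int  # assumed: leftover units given to the lowest ranks
    unit_numel: int      # elements per unit


def make_config(axis_size: int, unit_numel: int, ngpus: int) -> PartitionModel:
    # Assumed field derivation: an even split plus a remainder.
    axis_part, axis_remainder = divmod(axis_size, ngpus)
    return PartitionModel(axis_part, axis_remainder, unit_numel)


# Scatter a [6, 4096] row-major tensor along axis 0 across 4 GPUs:
# each unit is one row of 4096 elements.
cfg = make_config(axis_size=6, unit_numel=4096, ngpus=4)
print(cfg)  # PartitionModel(axis_part=1, axis_remainder=2, unit_numel=4096)
```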

def __init__(num_elements: Int, threads_per_gpu: Int) -> Self

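1D convenience constructor. Partitions by SIMD vectors: equivalent to the general constructor with axis_size = num_elements // simd_width and unit_numel = simd_width.

Args:

  • num_elements (Int): Total number of elements to partition.
  • threads_per_gpu (Int): Total threads per GPU.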
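The 1D mapping from element count to units can be sketched in Python as follows. `one_d_params` is a hypothetical helper, and the `simd_width` of 16 is an assumed value; the real width depends on the dtype and GPU target.

```python
def one_d_params(num_elements: int, simd_width: int) -> tuple[int, int]:
    """Hypothetical model of the 1D mapping: each unit is one SIMD vector."""
    axis_size = num_elements // simd_width  # whole SIMD vectors only
    unit_numel = simd_width                 # elements per unit
    return axis_size, unit_numel


# 1000 elements with an assumed simd_width of 16.
axis_size, unit_numel = one_d_params(1000, 16)
print(axis_size, unit_numel)   # 62 16
print(axis_size * unit_numel)  # 992
```

Because of the floor division, the units in this example cover 992 of the 1000 elements; how a tail that is not a whole SIMD vector gets handled is not described on this page.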

rank_unit_start

def rank_unit_start(self, rank: Int) -> Int

Index of the first unit along the scatter axis assigned to rank.

Returns:

Int
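One formula consistent with the documented split is sketched below in Python. It assumes units are assigned contiguously in rank order, which the page implies but does not state outright; `rank_unit_start` here is a hypothetical stand-in for the Mojo method, taking the field values as plain arguments.

```python
def rank_unit_start(rank: int, axis_part: int, axis_remainder: int) -> int:
    """Hypothetical model: first unit owned by `rank`, assuming contiguous
    assignment in rank order with the lowest ranks holding the extras."""
    # Every earlier rank holds axis_part units, and min(rank, axis_remainder)
    # of them also hold one extra unit.
    return rank * axis_part + min(rank, axis_remainder)


# axis_size = 10 over 4 GPUs -> axis_part = 2, axis_remainder = 2.
print([rank_unit_start(r, 2, 2) for r in range(4)])  # [0, 3, 6, 8]
```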

rank_units

def rank_units(self, rank: Int) -> Int

Number of units along the scatter axis assigned to rank.

Returns:

Int
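A Python sketch of the per-rank unit count, under the assumption that ranks below `axis_remainder` take one extra unit each. The function is a hypothetical model, not the Mojo method; the loop checks the two properties such a split should have: the counts add up to `axis_size`, and no two ranks differ by more than one unit.

```python
def rank_units(rank: int, axis_part: int, axis_remainder: int) -> int:
    """Hypothetical model: number of units owned by `rank`."""
    # Ranks below axis_remainder each take one of the leftover units.
    return axis_part + (1 if rank < axis_remainder else 0)


# Check the balance properties over many sizes.
for ngpus in (2, 4, 8):
    for axis_size in range(100):
        part, rem = divmod(axis_size, ngpus)
        counts = [rank_units(r, part, rem) for r in range(ngpus)]
        assert sum(counts) == axis_size        # counts add up to axis_size
        assert max(counts) - min(counts) <= 1  # balanced to within one unit
print("ok")
```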

rank_num_elements

def rank_num_elements(self, rank: Int) -> Int

Total number of elements assigned to rank.

Returns:

Int
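Since every unit holds `unit_numel` elements, a rank's element count is presumably its unit count times `unit_numel`. The Python sketch below models that assumption; `rank_num_elements` is a hypothetical stand-in, and the [6, 4096] tensor is an illustrative example.

```python
def rank_num_elements(rank: int, axis_part: int, axis_remainder: int,
                      unit_numel: int) -> int:
    """Hypothetical model: elements owned by `rank`."""
    units = axis_part + (1 if rank < axis_remainder else 0)
    return units * unit_numel


# A [6, 4096] tensor scattered along axis 0 over 4 GPUs:
# axis_part = 1, axis_remainder = 2, unit_numel = 4096 (one row per unit).
sizes = [rank_num_elements(r, 1, 2, 4096) for r in range(4)]
print(sizes)       # [8192, 8192, 4096, 4096]
print(sum(sizes))  # 24576
```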

rank_start

def rank_start(self, rank: Int) -> Int

Flat element start offset for this rank.

Returns:

Int
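A Python sketch of how a flat start offset could follow from the unit start. It assumes each unit is a contiguous run of `unit_numel` elements and that units sit back to back in memory, as in the 1D case or when scattering along the outermost axis of a row-major tensor; the function is a hypothetical model, not the Mojo method.

```python
def rank_start(rank: int, axis_part: int, axis_remainder: int,
               unit_numel: int) -> int:
    """Hypothetical model: flat offset of the first element owned by `rank`."""
    # First unit index, assuming contiguous assignment in rank order.
    unit_start = rank * axis_part + min(rank, axis_remainder)
    # Units are assumed to be back-to-back runs of unit_numel elements.
    return unit_start * unit_numel


# [6, 4096] tensor over 4 GPUs: axis_part = 1, axis_remainder = 2.
print([rank_start(r, 1, 2, 4096) for r in range(4)])  # [0, 8192, 16384, 20480]
```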

rank_end

def rank_end(self, rank: Int) -> Int

Flat element end offset for this rank.

Returns:

Int
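Under the same contiguity assumption, and assuming the end offset is exclusive (the page does not say), each rank's range should end exactly where the next rank's begins. The Python sketch below checks that; `rank_bounds` is a hypothetical helper combining the start and end computations.

```python
def rank_bounds(rank: int, axis_part: int, axis_remainder: int,
                unit_numel: int) -> tuple[int, int]:
    """Hypothetical model: (start, end) flat offsets for `rank`, with end
    assumed to be exclusive."""
    unit_start = rank * axis_part + min(rank, axis_remainder)
    units = axis_part + (1 if rank < axis_remainder else 0)
    start = unit_start * unit_numel
    end = start + units * unit_numel
    return start, end


# axis_size = 10 units of 8 elements over 4 GPUs: axis_part = 2, axis_remainder = 2.
bounds = [rank_bounds(r, 2, 2, 8) for r in range(4)]
print(bounds)  # [(0, 24), (24, 48), (48, 64), (64, 80)]

# Each rank's end is the next rank's start, so the ranges tile [0, 80).
assert all(bounds[i][1] == bounds[i + 1][0] for i in range(3))
```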

rank_part

def rank_part(self, rank: Int) -> Int

Number of elements for this rank (alias for rank_num_elements).

Returns:

Int

thr_local_start

def thr_local_start(self, thread_idx: Int) -> Int

Returns:

Int