
Mojo struct

ReduceScatterConfig

struct ReduceScatterConfig[dtype: DType, ngpus: Int, simd_width: Int = simd_width_of[dtype, get_gpu_target()](), alignment: Int = align_of[SIMD[dtype, simd_width]](), accum_type: DType = get_accum_type[dtype]()]

Configuration for axis-aware reduce-scatter partitioning.

Divides axis_size units evenly across GPUs; lower ranks receive one extra unit when there is a remainder. The 1D case reduces to axis_size = num_elements // simd_width and unit_numel = simd_width.
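The remainder-to-lower-ranks rule above can be sketched in Python. This is a minimal sketch of the arithmetic only, not the Mojo source; the free-function form and formulas are assumptions consistent with the description.

```python
def rank_units(axis_size: int, ngpus: int, rank: int) -> int:
    """Units assigned to `rank`: the even share, plus one while rank < remainder."""
    part = axis_size // ngpus       # even share per GPU
    remainder = axis_size % ngpus   # leftover units, given to the lowest ranks
    return part + (1 if rank < remainder else 0)

def rank_unit_start(axis_size: int, ngpus: int, rank: int) -> int:
    """Start unit index: each lower rank took `part` units, plus one extra
    for every lower rank that was still below the remainder."""
    part = axis_size // ngpus
    remainder = axis_size % ngpus
    return rank * part + min(rank, remainder)

# Example: 10 units over 4 GPUs -> sizes [3, 3, 2, 2], starts [0, 3, 6, 8]
```

Note that the partition is contiguous and exhaustive: the per-rank sizes always sum to axis_size, and each rank's start is the previous rank's start plus its size.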

Fields

  • stride (Int)
  • axis_part (Int)
  • axis_remainder (Int)
  • unit_numel (Int)

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

Methods

__init__

__init__(axis_size: Int, unit_numel: Int, threads_per_gpu: Int) -> Self

General constructor for axis-aware partitioning.

Args:

  • axis_size (Int): Number of units along the scatter axis.
  • unit_numel (Int): Number of elements per unit.
  • threads_per_gpu (Int): Total threads per GPU.

__init__(num_elements: Int, threads_per_gpu: Int) -> Self

1D convenience constructor. Partitions by SIMD vectors.
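A hedged sketch of how the 1D case maps onto the general one: each SIMD vector becomes one unit, so axis_size is the vector count and unit_numel is the vector width. The function name and the alignment check are illustrative assumptions, and SIMD_WIDTH stands in for the compile-time simd_width parameter.

```python
SIMD_WIDTH = 4  # illustrative; the real value comes from simd_width_of[dtype]()

def config_1d(num_elements: int) -> tuple[int, int]:
    """Return (axis_size, unit_numel) for the 1D convenience case."""
    # Assumes the element count is SIMD-aligned; the real constructor's
    # handling of a ragged tail is not shown here.
    assert num_elements % SIMD_WIDTH == 0, "expects a SIMD-aligned element count"
    return num_elements // SIMD_WIDTH, SIMD_WIDTH
```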

rank_unit_start

rank_unit_start(self, rank: Int) -> Int

Start unit index along scatter axis for this rank.

Returns:

Int

rank_units

rank_units(self, rank: Int) -> Int

Number of units for this rank.

Returns:

Int

rank_num_elements

rank_num_elements(self, rank: Int) -> Int

Total elements for this rank.

Returns:

Int

rank_start

rank_start(self, rank: Int) -> Int

Flat element start offset for this rank.

Returns:

Int

rank_end

rank_end(self, rank: Int) -> Int

Flat element end offset for this rank.

Returns:

Int
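The flat offsets follow from the unit partition by scaling unit indices with unit_numel, so consecutive ranks tile the buffer with no gaps or overlap. A minimal Python sketch of that relationship (free functions and formulas are assumptions, not the Mojo source):

```python
def _unit_start(axis_size: int, ngpus: int, rank: int) -> int:
    """Start unit index for `rank` (lower ranks absorb the remainder)."""
    part, rem = divmod(axis_size, ngpus)
    return rank * part + min(rank, rem)

def rank_start(axis_size: int, ngpus: int, unit_numel: int, rank: int) -> int:
    """Flat element start offset: start unit index times elements per unit."""
    return _unit_start(axis_size, ngpus, rank) * unit_numel

def rank_end(axis_size: int, ngpus: int, unit_numel: int, rank: int) -> int:
    """Flat element end offset, i.e. the next rank's start offset."""
    return _unit_start(axis_size, ngpus, rank + 1) * unit_numel
```

With axis_size = 10, ngpus = 4, unit_numel = 8, rank_end(r) equals rank_start(r + 1) for every rank, and the last rank's end offset is axis_size * unit_numel = 80, so the four partitions cover the whole buffer exactly once.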

rank_part

rank_part(self, rank: Int) -> Int

Number of elements for this rank (alias for rank_num_elements).

Returns:

Int

thr_local_start

thr_local_start(self, thread_idx: Int) -> Int

Returns:

Int