Mojo struct

ReduceScatterConfig

struct ReduceScatterConfig[dtype: DType, ngpus: Int, simd_width: Int = simd_width_of[dtype, get_gpu_target()](), alignment: Int = align_of[SIMD[dtype, simd_width]](), accum_type: DType = get_accum_type[dtype]()]

Configuration for axis-aware reduce-scatter partitioning.

Divides axis_size units evenly across GPUs; lower ranks receive one extra unit when there is a remainder. The 1D layout is the special case where axis_size = num_elements // simd_width and unit_numel = simd_width.
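The division rule above can be modeled with a short Python sketch (illustrative only; the helper name partition_units is hypothetical and not part of this API):

```python
def partition_units(axis_size: int, ngpus: int) -> list[int]:
    """Units per GPU: even split, with lower ranks absorbing the remainder."""
    base, rem = divmod(axis_size, ngpus)
    return [base + (1 if rank < rem else 0) for rank in range(ngpus)]

# 10 units across 4 GPUs: ranks 0 and 1 each get one extra unit.
print(partition_units(10, 4))  # [3, 3, 2, 2]
```

Every unit is assigned to exactly one rank, so the per-rank counts always sum back to axis_size.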

Fields

  • stride (Int):
  • axis_part (Int): Base number of units per GPU (axis_size // ngpus).
  • axis_remainder (Int): Leftover units (axis_size % ngpus), absorbed one each by the lowest ranks.
  • unit_numel (Int): Number of elements per unit.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

Methods

__init__

__init__(axis_size: Int, unit_numel: Int, threads_per_gpu: Int) -> Self

General constructor for axis-aware partitioning.

Args:

  • axis_size (Int): Number of units along the scatter axis.
  • unit_numel (Int): Number of elements per unit.
  • threads_per_gpu (Int): Total threads per GPU.

__init__(num_elements: Int, threads_per_gpu: Int) -> Self

1D convenience constructor. Partitions by SIMD vectors.
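The 1D constructor's derived values can be sketched in Python (illustrative; make_1d_config is a hypothetical stand-in for the Mojo constructor):

```python
def make_1d_config(num_elements: int, simd_width: int) -> tuple[int, int]:
    """1D case: each partition unit is one SIMD vector."""
    axis_size = num_elements // simd_width  # number of SIMD vectors
    unit_numel = simd_width                  # elements per unit
    return axis_size, unit_numel

print(make_1d_config(1024, 8))  # (128, 8)
```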

rank_unit_start

rank_unit_start(self, rank: Int) -> Int

Start unit index along scatter axis for this rank.

Returns:

Int
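Because lower ranks each hold one extra unit when there is a remainder, a rank's start index shifts forward by one unit per earlier oversized rank. A Python sketch of this arithmetic (the standalone helper is hypothetical; the real method reads the struct's stored fields):

```python
def rank_unit_start(axis_size: int, ngpus: int, rank: int) -> int:
    """Start unit index along the scatter axis for this rank."""
    base, rem = divmod(axis_size, ngpus)
    # min(rank, rem) counts the earlier ranks that carry an extra unit.
    return rank * base + min(rank, rem)

# 10 units, 4 GPUs: ranks 0-1 hold 3 units, ranks 2-3 hold 2.
print([rank_unit_start(10, 4, r) for r in range(4)])  # [0, 3, 6, 8]
```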

rank_units

rank_units(self, rank: Int) -> Int

Number of units for this rank.

Returns:

Int

rank_num_elements

rank_num_elements(self, rank: Int) -> Int

Total elements for this rank.

Returns:

Int
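A rank's element count follows from its unit count. The sketch below assumes a contiguous layout along the scatter axis (elements = units x unit_numel); the real method may also account for the struct's stride field:

```python
def rank_num_elements(axis_size: int, ngpus: int, unit_numel: int, rank: int) -> int:
    """Total elements for this rank: owned units times elements per unit."""
    base, rem = divmod(axis_size, ngpus)
    units = base + (1 if rank < rem else 0)
    return units * unit_numel

# 10 units of 8 elements across 4 GPUs: rank 0 holds 3 units = 24 elements.
print(rank_num_elements(10, 4, 8, 0))  # 24
```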

rank_start

rank_start(self, rank: Int) -> Int

Flat element start offset for this rank.

Returns:

Int

rank_end

rank_end(self, rank: Int) -> Int

Flat element end offset for this rank.

Returns:

Int
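Under the same contiguous-layout assumption, rank_start and rank_end bound a half-open [start, end) range, and adjacent ranks' ranges tile the whole buffer with no gaps or overlaps. An illustrative Python check (rank_bounds is a hypothetical helper, not part of this API):

```python
def rank_bounds(axis_size: int, ngpus: int, unit_numel: int, rank: int) -> tuple[int, int]:
    """Flat [start, end) element range for a rank, contiguous along the axis."""
    base, rem = divmod(axis_size, ngpus)
    start_units = rank * base + min(rank, rem)
    n_units = base + (1 if rank < rem else 0)
    start = start_units * unit_numel
    return start, start + n_units * unit_numel

# Each rank's end is the next rank's start; the last end covers the buffer.
bounds = [rank_bounds(10, 4, 8, r) for r in range(4)]
print(bounds)  # [(0, 24), (24, 48), (48, 64), (64, 80)]
```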

rank_part

rank_part(self, rank: Int) -> Int

Number of elements for this rank (alias for rank_num_elements).

Returns:

Int

thr_local_start

thr_local_start(self, thread_idx: UInt) -> Int

Returns:

Int
