
Python package

data_parallelism

Module for data parallelism utilities.

split_batch()

max.nn.data_parallelism.split_batch(devices, input, input_row_offsets, data_parallel_splits)

Split a ragged input batch into data parallel batches.

See split_batch_replicated for a version of this method that takes replicated inputs and input_row_offsets for each device.

Example:

devices = [device_1, device_2]
input = [seq_1, seq_2, seq_3, seq_4]
input_row_offsets = [0, offset_1, offset_2, offset_3, offset_4]
data_parallel_splits = [0, 2, 4]

Outputs:

split_input = [seq_1, seq_2], [seq_3, seq_4]
split_offsets = [0, offset_1, offset_2], [0, new_offset_3, new_offset_4]

After the split, each output is placed on the corresponding device in devices.

The size of data_parallel_splits must be equal to the number of devices + 1.

Parameters:

  • devices (list[DeviceRef]) – List of devices to split the input across.
  • input (TensorValue) – Input tensor of shape [total_seq_len, …].
  • input_row_offsets (TensorValue) – Row offsets tensor indicating batch boundaries.
  • data_parallel_splits (TensorValue) – Buffer containing the batch splits for each device. Must be located on CPU.

Returns:

Tuple of (split_input, split_offsets), where split_input and split_offsets are lists of tensors, one per device.

Return type:

tuple[list[TensorValue], list[TensorValue]]
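The slicing semantics above can be sketched in plain Python. This is an illustrative model using lists instead of TensorValues and ignoring device placement; the function name and variables are hypothetical, not the MAX implementation.

```python
# Sketch of split_batch's slicing semantics, assuming plain Python lists.
# All names here are illustrative; this is not the MAX implementation.

def split_batch_sketch(input_tokens, input_row_offsets, data_parallel_splits):
    split_input, split_offsets = [], []
    for dev in range(len(data_parallel_splits) - 1):
        start_seq = data_parallel_splits[dev]      # first sequence for this device
        end_seq = data_parallel_splits[dev + 1]    # one past the last sequence
        start_tok = input_row_offsets[start_seq]   # first token for this device
        end_tok = input_row_offsets[end_seq]       # one past the last token
        split_input.append(input_tokens[start_tok:end_tok])
        # Re-base the offsets so each device's batch starts at 0.
        split_offsets.append(
            [o - start_tok for o in input_row_offsets[start_seq:end_seq + 1]]
        )
    return split_input, split_offsets

# Four sequences of lengths 3, 2, 1, 4 split evenly across two devices.
tokens = list(range(10))
row_offsets = [0, 3, 5, 6, 10]
splits = [0, 2, 4]
per_device_tokens, per_device_offsets = split_batch_sketch(tokens, row_offsets, splits)
# per_device_tokens  → [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
# per_device_offsets → [[0, 3, 5], [0, 1, 5]]
```

Note how the second device's offsets are re-based to start at 0, matching the new_offset values in the example above.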

split_batch_replicated()

max.nn.data_parallelism.split_batch_replicated(devices, input, input_row_offsets, input_row_offsets_int64, data_parallel_splits)

Split a ragged token batch into data parallel batches.

This version takes input and input_row_offsets as lists, with one replica per device. See split_batch for a version of this function that takes a single ragged token batch.

Example:

devices = [device_1, device_2]
input = [seq_1, seq_2, seq_3, seq_4] (replicated for each device)
input_row_offsets = [0, offset_1, offset_2, offset_3, offset_4] (replicated for each device)
data_parallel_splits = [0, 2, 4]

Outputs:

split_input = [seq_1, seq_2], [seq_3, seq_4]
split_offsets = [0, offset_1, offset_2], [0, new_offset_3, new_offset_4]

After the split, each output is placed on the corresponding device in devices.

The size of data_parallel_splits must be equal to the number of devices + 1.

Parameters:

  • devices (list[DeviceRef]) – List of devices to split the input on.
  • input (list[TensorValue]) – List of input token tensors of shape [total_seq_len]. The list must be the same length as the number of devices.
  • input_row_offsets (list[TensorValue]) – List of row offsets tensors indicating batch boundaries. The list must be the same length as the number of devices.
  • input_row_offsets_int64 (TensorValue) – Row offsets tensor indicating batch boundaries. Must be located on CPU.
  • data_parallel_splits (TensorValue) – Buffer containing batch splits for each device. Must be located on CPU.

Returns:

Tuple of (split_input, split_offsets), where split_input and split_offsets are lists of tensors, one per device.

Return type:

tuple[list[TensorValue], list[TensorValue]]
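The replicated variant can be sketched the same way. The key difference from the single-batch version is that every device already holds a full copy of the tokens and offsets, so each device slices its own replica while the CPU-resident int64 offsets drive the slice bounds. Names are illustrative, not the MAX implementation.

```python
# Sketch of split_batch_replicated's semantics, assuming plain Python lists.
# Illustrative only; not the MAX implementation.

def split_batch_replicated_sketch(replicated_tokens, replicated_offsets,
                                  offsets_int64, data_parallel_splits):
    split_input, split_offsets = [], []
    for dev in range(len(replicated_tokens)):
        start_seq = data_parallel_splits[dev]
        end_seq = data_parallel_splits[dev + 1]
        start_tok = offsets_int64[start_seq]
        end_tok = offsets_int64[end_seq]
        # Slice this device's own replica — no cross-device data movement.
        split_input.append(replicated_tokens[dev][start_tok:end_tok])
        split_offsets.append(
            [o - start_tok for o in replicated_offsets[dev][start_seq:end_seq + 1]]
        )
    return split_input, split_offsets

tokens = list(range(10))
row_offsets = [0, 3, 5, 6, 10]
# Two devices, each holding a full replica of the tokens and offsets.
inp, offs = split_batch_replicated_sketch(
    [tokens, tokens], [row_offsets, row_offsets], row_offsets, [0, 2, 4]
)
# inp  → [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
# offs → [[0, 3, 5], [0, 1, 5]]
```

Because each device reads from its own replica, this avoids the scatter from a single source batch, at the cost of replicating the full input on every device beforehand.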
