Python function
split_batch_replicated
max.nn.split_batch_replicated(devices, input, input_row_offsets, input_row_offsets_int64, data_parallel_splits, prefix='')
Split a ragged token batch into data-parallel batches.
This version takes input and input_row_offsets as lists of tensors, with one replica of each on every device. See also split_input for a version of this method that takes a single ragged token batch.
# Inputs
devices = [device_1, device_2]
input = [seq_1, seq_2, seq_3, seq_4]  (replicated on each device)
input_row_offsets = [0, offset_1, offset_2, offset_3, offset_4]  (replicated on each device)
data_parallel_splits = [0, 2, 4]

# Outputs
split_input = [seq_1, seq_2], [seq_3, seq_4]
split_offsets = [0, offset_1, offset_2], [0, new_offset_3, new_offset_4]
# where new_offset_3 = offset_3 - offset_2 and new_offset_4 = offset_4 - offset_2

This method places the outputs on the devices specified in devices.
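The splitting shown in the example can be modeled in plain Python to check expected outputs by hand. This is a minimal sketch, not the MAX implementation: it operates on Python lists instead of TensorValue graph values, ignores device placement, and the name split_ragged_batch is hypothetical. It shows how each device's token slice is selected through the row offsets and how that device's offsets are rebased to start at 0 (for example, new_offset_3 = offset_3 - offset_2).

```python
def split_ragged_batch(tokens, row_offsets, splits):
    """Hypothetical reference model of splitting a ragged batch.

    tokens:      flat list of all tokens; len(tokens) == row_offsets[-1]
    row_offsets: [0, end_of_seq_1, ..., end_of_seq_n]
    splits:      sequence indices, len(splits) == num_devices + 1
    Returns (split_tokens, split_offsets), one entry per device.
    """
    split_tokens, split_offsets = [], []
    for start_seq, end_seq in zip(splits[:-1], splits[1:]):
        # Token range covered by sequences [start_seq, end_seq).
        tok_start = row_offsets[start_seq]
        tok_end = row_offsets[end_seq]
        split_tokens.append(tokens[tok_start:tok_end])
        # Rebase this device's offsets so they start at 0.
        split_offsets.append(
            [o - tok_start for o in row_offsets[start_seq:end_seq + 1]]
        )
    return split_tokens, split_offsets


tokens = [10, 11, 12, 20, 21, 30, 40, 41, 42, 43]
row_offsets = [0, 3, 5, 6, 10]  # four sequences of lengths 3, 2, 1, 4
splits = [0, 2, 4]              # sequences 0-1 -> device 0, 2-3 -> device 1

split_tokens, split_offsets = split_ragged_batch(tokens, row_offsets, splits)
print(split_tokens)   # [[10, 11, 12, 20, 21], [30, 40, 41, 42, 43]]
print(split_offsets)  # [[0, 3, 5], [0, 1, 5]]
```

Each device's last offset equals the length of its token slice, so the rebased offsets index directly into the local tokens on that device.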
Parameters:
- devices (list[DeviceRef]) – List of devices to split the input across.
- input (list[TensorValue]) – List of input token tensors, each of shape [total_seq_len]. The list must have one entry per device.
- input_row_offsets (list[TensorValue]) – List of row offsets tensors indicating batch boundaries. The list must have one entry per device.
- input_row_offsets_int64 (TensorValue) – Row offsets tensor indicating batch boundaries. Must be located on CPU.
- data_parallel_splits (TensorValue) – Buffer containing the batch split points for each device. Must be located on CPU. Its size must equal the number of devices + 1.
- prefix (str)
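The host-side values for data_parallel_splits can be prepared and sanity-checked before they are fed into the graph. This is a minimal sketch under stated assumptions: both helper names are hypothetical, the even split by sequence count (rather than by token count) is one illustrative balancing choice, and only the length check (number of devices + 1) is documented above. The other checks (starting at 0, ending at the batch size, non-decreasing) are assumptions inferred from the example.

```python
def even_data_parallel_splits(batch_size: int, num_devices: int) -> list[int]:
    """Hypothetical helper: contiguous, near-equal sequence ranges per device."""
    base, extra = divmod(batch_size, num_devices)
    splits = [0]
    for device_idx in range(num_devices):
        # The first `extra` devices each take one additional sequence.
        splits.append(splits[-1] + base + (1 if device_idx < extra else 0))
    return splits


def check_data_parallel_splits(splits: list[int], batch_size: int, num_devices: int) -> None:
    """Hypothetical validation of split boundaries."""
    # Documented: one boundary per device plus a trailing end boundary.
    if len(splits) != num_devices + 1:
        raise ValueError(f"expected {num_devices + 1} entries, got {len(splits)}")
    # Assumed from the example: boundaries cover the whole batch.
    if splits[0] != 0 or splits[-1] != batch_size:
        raise ValueError("splits must start at 0 and end at the batch size")
    # Assumed from the example: boundaries are in order.
    if any(a > b for a, b in zip(splits, splits[1:])):
        raise ValueError("splits must be non-decreasing")


splits = even_data_parallel_splits(batch_size=5, num_devices=2)
check_data_parallel_splits(splits, batch_size=5, num_devices=2)
print(splits)  # [0, 3, 5]
```

Balancing by sequence count keeps the helper simple; with highly variable sequence lengths, balancing by token count (using the row offsets) may spread work more evenly.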
Returns:
Tuple of (split_input, split_offsets), where split_input and split_offsets are lists of tensors, one per device.
Return type:
tuple[list[TensorValue], list[TensorValue]]