Python function

num_heads_for_device

max.nn.attention.num_heads_for_device(*, num_heads, device_idx, num_devices)

Computes the number of attention heads assigned to a specific device.

Distributes heads across devices as evenly as possible, including when num_heads is not evenly divisible by num_devices. In that case, each of the first num_heads % num_devices devices (the lowest device indices) receives one extra head.
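The distribution rule can be illustrated with a small stand-alone sketch. The helper `heads_for_device` below is hypothetical: it is a local re-implementation of the rule described above (floor division plus one extra head for the lowest-indexed devices), not the MAX source code, and it does not import `max`.

```python
def heads_for_device(*, num_heads: int, device_idx: int, num_devices: int) -> int:
    """Hypothetical sketch: how many of `num_heads` heads land on `device_idx`."""
    # Every device gets at least `base` heads; `remainder` heads are left over.
    base, remainder = divmod(num_heads, num_devices)
    # The first `remainder` devices (indices 0 .. remainder - 1) take one extra head each.
    return base + (1 if device_idx < remainder else 0)


# Example: 32 heads split across 3 devices (32 = 3 * 10 + 2).
counts = [heads_for_device(num_heads=32, device_idx=i, num_devices=3) for i in range(3)]
print(counts)  # [11, 11, 10]
```

Because the per-device counts differ by at most one and sum to num_heads, no head is dropped or duplicated.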

Parameters:

  • num_heads (int) – Total number of attention heads.
  • device_idx (int) – The index of the current device (0-based).
  • num_devices (int) – Total number of devices.

Returns:

Number of heads assigned to the specified device.

Return type:

int
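The return value gives only a count. A common follow-up is to turn the per-device counts into head index ranges. The sketch below assumes, as an illustration only, that heads are assigned to devices contiguously in device order; this page does not specify which heads each device receives. Both `heads_for_device` and `head_ranges` are hypothetical helpers that mirror the documented counting rule.

```python
def heads_for_device(*, num_heads: int, device_idx: int, num_devices: int) -> int:
    """Hypothetical sketch of the documented counting rule."""
    base, remainder = divmod(num_heads, num_devices)
    return base + (1 if device_idx < remainder else 0)


def head_ranges(num_heads: int, num_devices: int) -> list[range]:
    """Hypothetical helper: contiguous head index ranges, one per device (assumed layout)."""
    ranges = []
    start = 0
    for idx in range(num_devices):
        count = heads_for_device(num_heads=num_heads, device_idx=idx, num_devices=num_devices)
        # Device `idx` owns heads [start, start + count).
        ranges.append(range(start, start + count))
        start += count
    return ranges


print(head_ranges(10, 4))  # [range(0, 3), range(3, 6), range(6, 8), range(8, 10)]
```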