
Python function

num_heads_for_device

num_heads_for_device()

max.nn.attention.num_heads_for_device(*, num_heads, device_idx, num_devices)


Computes the number of attention heads assigned to a specific device.

Distributes the heads across devices as evenly as possible. When num_heads is not evenly divisible by num_devices, the lowest-indexed devices each receive one extra head, so the per-device counts differ by at most one and sum to num_heads.
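The rule above can be written as a formula. This is a formalization of the described behavior rather than a quote from the implementation; the symbols $H$ (num_heads), $N$ (num_devices), $d$ (device_idx), and $h_d$ (the returned count) are introduced here for convenience:

$$
h_d = \left\lfloor \frac{H}{N} \right\rfloor +
\begin{cases}
1 & \text{if } d < H \bmod N \\
0 & \text{otherwise}
\end{cases}
$$

For example, with $H = 10$ and $N = 4$: $\lfloor 10/4 \rfloor = 2$ and $10 \bmod 4 = 2$, so devices 0 and 1 get 3 heads, devices 2 and 3 get 2 heads, and the counts sum to 10.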

Parameters:

  • num_heads (int) – Total number of attention heads.
  • device_idx (int) – The index of the current device (0-based).
  • num_devices (int) – Total number of devices.

Returns:

Number of heads assigned to the specified device.

Return type:

int
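To make the splitting rule concrete, here is a minimal pure-Python sketch that mirrors the documented behavior. The name `heads_for_device_sketch` is hypothetical; this is not the MAX source, only an illustration assuming the "lowest-indexed devices get the extra heads" rule described above. It uses keyword-only parameters to match the documented signature.

```python
def heads_for_device_sketch(*, num_heads: int, device_idx: int, num_devices: int) -> int:
    """Hypothetical re-implementation of the documented head-splitting rule.

    Assumes valid inputs (num_devices > 0 and 0 <= device_idx < num_devices);
    this sketch does not replicate any input validation MAX may perform.
    """
    # Every device gets `base` heads; `remainder` heads are left over.
    base, remainder = divmod(num_heads, num_devices)
    # The first `remainder` devices each take one of the leftover heads.
    return base + (1 if device_idx < remainder else 0)


# 10 heads over 4 devices: 10 = 4 * 2 + 2, so devices 0 and 1 get 3 heads.
counts = [
    heads_for_device_sketch(num_heads=10, device_idx=i, num_devices=4)
    for i in range(4)
]
print(counts)  # [3, 3, 2, 2]
```

Summing the result over every device_idx recovers num_heads, which is a quick sanity check when laying out attention weights across devices.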