
Python class

GreedyReshard

class max.experimental.sharding.GreedyReshard(on_reshard='silent', allow_partial_to_sharded=False)

Bases: object

Default per-op picker: enumerate → filter-by-budget → cheapest.

This is the shipped default. To override it, plug in your own callable that matches the Solver protocol.
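The enumerate, filter-by-budget, cheapest pipeline named above can be sketched in plain Python. This is a minimal illustration only: `Candidate`, `pick_cheapest`, the `comm_cost` and `memory_bytes` fields, and the idea that the budget is a memory budget are all assumptions made for the example. It does not reproduce the real Solver protocol signature, which this page does not show.

```python
from dataclasses import dataclass
from typing import Iterable, Optional


@dataclass(frozen=True)
class Candidate:
    """Hypothetical stand-in for one per-op placement choice."""

    name: str
    comm_cost: float   # estimated collective cost, arbitrary units (assumed)
    memory_bytes: int  # estimated per-device footprint (assumed budget metric)


def pick_cheapest(
    candidates: Iterable[Candidate], memory_budget: int
) -> Optional[Candidate]:
    """Enumerate candidates, drop those over budget, return the cheapest."""
    feasible = [c for c in candidates if c.memory_bytes <= memory_budget]
    if not feasible:
        return None  # nothing fits; a real picker would need a fallback
    return min(feasible, key=lambda c: c.comm_cost)


candidates = [
    Candidate("replicated", comm_cost=0.0, memory_bytes=4096),
    Candidate("sharded", comm_cost=2.0, memory_bytes=1024),
    Candidate("partial", comm_cost=1.0, memory_bytes=1024),
]
choice = pick_cheapest(candidates, memory_budget=2048)
```

Here the replicated candidate is the cheapest overall but exceeds the budget, so the filter removes it before the cost comparison runs. The choice is greedy because it is made per op, with no lookahead to downstream consumers.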

Parameters:

  • on_reshard (Literal['silent', 'warn', 'raise'])
  • allow_partial_to_sharded (bool)

allow_partial_to_sharded

allow_partial_to_sharded: bool = False

When True, the picker may resolve a Partial input by reduce_scatter to Sharded(d) when that is the locally cheapest collective. On Megatron-style transformers this tends to land in sequence-parallel + TP: the same byte volume as pure TP, but with a sharded residual stream, two extra collectives per block, and per-rank symbolic-dim drift. That is the price of the SP activation-memory win, which only matters for long-sequence training.

When False (the default), the picker may not redistribute a Partial input to Sharded: it either keeps the Partial (linear passthrough) or resolves it to Replicated via allreduce, on exactly the mesh axes that are Partial. Other axes (for example a dp batch shard) are left untouched. A nonlinear consumer such as rms_norm has no Partial row, so its input lands on Replicated. This is the textbook Megatron pure-TP layout and the right default for inference (fewer collectives per block, a replicated residual stream, no rebind needed in the model).

Set True to opt into sequence-parallel discovery.
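The trade-off described above can be made concrete with a toy model. The sketch below assumes the textbook ring-collective cost, in which each rank sends (n - 1) / n of the payload for a reduce-scatter or an all-gather and twice that for an all-reduce. The function names, option labels, and cost model are hypothetical; the page does not describe the picker's actual cost function.

```python
def ring_bytes(collective: str, payload_bytes: int, n: int) -> float:
    """Per-rank bytes sent under the textbook ring algorithm (assumed model)."""
    step = (n - 1) / n * payload_bytes
    costs = {
        "reduce_scatter": step,
        "all_gather": step,
        "all_reduce": 2 * step,
    }
    return costs[collective]


def resolve_partial(
    payload_bytes: int,
    n: int,
    allow_partial_to_sharded: bool,
    consumer_is_linear: bool,
) -> str:
    """Toy policy: choose the locally cheapest way to handle a Partial input."""
    # All-reduce to Replicated is always available.
    options = {
        "allreduce_to_replicated": ring_bytes("all_reduce", payload_bytes, n)
    }
    if consumer_is_linear:
        # A linear consumer can pass the Partial through with no collective.
        options["keep_partial"] = 0.0
    if allow_partial_to_sharded:
        # Only offered when the flag opts in.
        options["reduce_scatter_to_sharded"] = ring_bytes(
            "reduce_scatter", payload_bytes, n
        )
    return min(options, key=options.get)


payload, ranks = 1 << 20, 4
default_choice = resolve_partial(payload, ranks, False, consumer_is_linear=False)
sp_choice = resolve_partial(payload, ranks, True, consumer_is_linear=False)
# A reduce-scatter followed later by an all-gather moves the same bytes as
# one all-reduce, which is why the total volume matches pure TP.
same_volume = (
    ring_bytes("reduce_scatter", payload, ranks)
    + ring_bytes("all_gather", payload, ranks)
    == ring_bytes("all_reduce", payload, ranks)
)
```

Under this model a reduce-scatter always costs half as much as an all-reduce at the point of decision. That is one way to see why a locally greedy picker drifts toward the sequence-parallel layout once the flag allows it, even though the total byte volume is unchanged.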

on_reshard

on_reshard: Literal['silent', 'warn', 'raise'] = 'silent'

Diagnostic policy applied when the picked action requires a reshard on any input. One of "silent" (default), "warn", or "raise". _local_dispatch inspects this value after the picker has made its choice.
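A three-way policy like this is commonly implemented as a small dispatch on the policy string. The sketch below shows one plausible shape using the standard `warnings` module. `report_reshard` and `ReshardError` are hypothetical names invented for the example. The real `_local_dispatch` logic, its warning category, and its exception type are not shown on this page.

```python
import warnings
from typing import Literal

OnReshard = Literal["silent", "warn", "raise"]


class ReshardError(RuntimeError):
    """Hypothetical exception type used only in this sketch."""


def report_reshard(policy: OnReshard, op_name: str, needs_reshard: bool) -> None:
    """Apply the diagnostic policy after an action has been picked."""
    if not needs_reshard or policy == "silent":
        return  # nothing to report, or reporting is disabled
    message = f"{op_name}: picked action requires resharding an input"
    if policy == "warn":
        warnings.warn(message, stacklevel=2)
    elif policy == "raise":
        raise ReshardError(message)
    else:
        raise ValueError(f"unknown on_reshard policy: {policy!r}")


# "silent" never reports, even when a reshard is needed.
report_reshard("silent", "matmul", needs_reshard=True)
```

Setting the policy to "raise" during development is a reasonable way to surface unexpected reshards early, while "silent" suits production paths where resharding is expected.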