Python class
GreedyReshard
class max.experimental.sharding.GreedyReshard(on_reshard='silent', allow_partial_to_sharded=False)
Bases: object
Default per-op picker: enumerate → filter-by-budget → cheapest.
This is the shipped default. To override it, plug in your own callable
that matches the Solver protocol.
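As a rough illustration of the enumerate → filter-by-budget → cheapest pattern, here is a minimal, self-contained sketch. It is not the GreedyReshard implementation: Candidate, greedy_pick, and the cost/footprint fields are hypothetical, and the sketch assumes the budget is a single numeric limit (for example, bytes per rank), which this page does not specify.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


# Hypothetical description of one candidate action for an op; NOT a library type.
@dataclass(frozen=True)
class Candidate:
    name: str
    cost: float       # estimated communication cost of the resharding it needs
    footprint: float  # assumed resource checked against the budget (e.g. bytes per rank)


def greedy_pick(candidates: Sequence[Candidate], budget: float) -> Optional[Candidate]:
    """Enumerate -> filter-by-budget -> cheapest (illustrative only)."""
    # Enumerate: the caller passes in every candidate action for this op.
    # Filter: drop candidates whose footprint exceeds the budget.
    feasible = [c for c in candidates if c.footprint <= budget]
    if not feasible:
        return None  # nothing fits; a real picker would need a fallback
    # Cheapest: lowest cost wins; min() keeps the first candidate on ties.
    return min(feasible, key=lambda c: c.cost)


options = [
    Candidate("a", cost=0.0, footprint=8.0),
    Candidate("b", cost=2.0, footprint=4.0),
    Candidate("c", cost=1.0, footprint=1.0),
]
print(greedy_pick(options, budget=5.0).name)  # c
```

A custom Solver would typically change the filter or the ranking step. The Solver protocol's exact call signature is not shown on this page, so the sketch does not try to match it.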
allow_partial_to_sharded
allow_partial_to_sharded: bool = False
When True, the picker may resolve a Partial input by
reduce_scatter to Sharded(d) when that is the locally cheapest
collective. On Megatron-style transformers this tends to produce
sequence parallelism (SP) + TP: the same byte volume as pure TP, but
with a sharded residual stream, two extra collectives per block, and
per-rank symbolic-dim drift. Those costs pay for SP's activation-memory
savings, which only matter for long-sequence training. When False (the
default), the picker may not redistribute a Partial input to
Sharded: it either keeps the Partial (linear passthrough) or
resolves it to Replicated via allreduce, on exactly the mesh axes
that are Partial; other axes (for example, a dp batch shard)
are left untouched. A nonlinear consumer such as rms_norm has no
Partial row, so its input lands on Replicated. This is the textbook
Megatron pure-TP layout and the right default for inference: fewer
collectives per block, a replicated residual stream, and no rebind
needed in the model. Set True to opt into sequence-parallel
discovery.
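The toy model below sketches the per-axis Partial-resolution rule described above for a single op input. The placement classes and resolve_partial are hypothetical stand-ins, not the max.experimental.sharding types. It assumes that keeping a Partial costs nothing (so a linear consumer always passes it through), models "locally cheapest" as a boolean flag, and treats dim 1 as the sequence dim of a [batch, seq, hidden] activation.

```python
from dataclasses import dataclass
from typing import Tuple, Union


# Hypothetical placement types (one entry per mesh axis); NOT the library's classes.
@dataclass(frozen=True)
class Replicated:
    pass


@dataclass(frozen=True)
class Partial:
    pass


@dataclass(frozen=True)
class Sharded:
    dim: int


Placement = Union[Replicated, Partial, Sharded]


def resolve_partial(
    placements: Tuple[Placement, ...],
    consumer_accepts_partial: bool,
    allow_partial_to_sharded: bool = False,
    reduce_scatter_cheaper: bool = True,
    scatter_dim: int = 0,
) -> Tuple[Placement, ...]:
    """Toy per-axis rule for resolving Partial inputs (illustrative only)."""
    resolved = []
    for p in placements:
        if not isinstance(p, Partial):
            # Non-Partial axes (e.g. a dp batch shard) are left untouched.
            resolved.append(p)
        elif consumer_accepts_partial:
            # Linear consumer: pass the Partial through with no collective.
            resolved.append(p)
        elif allow_partial_to_sharded and reduce_scatter_cheaper:
            # Opt-in path: reduce_scatter to Sharded(scatter_dim).
            resolved.append(Sharded(scatter_dim))
        else:
            # Default path: allreduce on this axis only, giving Replicated.
            resolved.append(Replicated())
    return tuple(resolved)


# Mesh axes are (dp, tp); the input to rms_norm is batch-sharded on dp, Partial on tp.
x = (Sharded(0), Partial())
print(resolve_partial(x, consumer_accepts_partial=False))
# (Sharded(dim=0), Replicated())
print(resolve_partial(x, consumer_accepts_partial=False,
                      allow_partial_to_sharded=True, scatter_dim=1))
# (Sharded(dim=0), Sharded(dim=1))
```

The first call is the default pure-TP outcome: only the tp axis is allreduced, and the dp batch shard survives. The second shows the sequence-parallel layout that allow_partial_to_sharded=True can discover.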
on_reshard
on_reshard: Literal['silent', 'warn', 'raise'] = 'silent'
Diagnostic policy applied when the picked action requires a reshard on
any input: "silent" (the default), "warn", or "raise". _local_dispatch
inspects it after the picker has made its choice.
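A sketch of how the three on_reshard policies could behave at dispatch time. check_reshard and ReshardError are hypothetical. The real check lives in _local_dispatch, and this page does not document its messages or exception types.

```python
import warnings
from typing import Literal, Sequence

OnReshard = Literal["silent", "warn", "raise"]


class ReshardError(RuntimeError):
    """Hypothetical exception type for the 'raise' policy."""


def check_reshard(policy: OnReshard, op_name: str, resharded_inputs: Sequence[int]) -> None:
    """Apply an on_reshard-style policy after a picker has chosen an action.

    resharded_inputs holds the (hypothetical) indices of inputs whose layout
    the picked action changes; an empty sequence means no reshard is needed.
    """
    if policy not in ("silent", "warn", "raise"):
        raise ValueError(f"unknown on_reshard policy: {policy!r}")
    if policy == "silent" or not resharded_inputs:
        return  # nothing to report
    msg = f"{op_name}: picked action reshards input(s) {list(resharded_inputs)}"
    if policy == "warn":
        warnings.warn(msg, RuntimeWarning, stacklevel=2)
    else:
        raise ReshardError(msg)


check_reshard("silent", "matmul", [0])  # reshard happens, nothing reported
check_reshard("raise", "matmul", [])    # no reshard, so no error
check_reshard("warn", "matmul", [1])    # emits a RuntimeWarning
```

"warn" is useful while auditing a model's sharding plan. "raise" turns any unexpected reshard into a hard failure, for example in a test that pins a known-good layout.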