
Python class

TaylorSeerCache

class max.pipelines.diffusion.TaylorSeerCache(config, dtype, device, session)

Bases: object

High-level TaylorSeer cache for executor pipelines, operating on Buffer objects.

Compiles predict and update graphs through the executor’s shared InferenceSession at construction time. All runtime methods accept and return Buffer objects, matching the executor’s driver-level API.

Parameters:

  • config (DenoisingCacheConfig) – Denoising cache configuration (must have taylorseer=True and resolved non-None fields for interval/warmup/order).
  • dtype (DType) – Model compute dtype (e.g. DType.bfloat16).
  • device (Device) – Target device for graph execution.
  • session (InferenceSession) – The executor’s shared inference session.
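The requirement on config can be made concrete with a small validation sketch. HypotheticalCacheConfig and check_taylorseer_config are illustrative only, not part of MAX, and the field names interval, warmup, and order are assumed from the wording above; the real DenoisingCacheConfig attributes may be named differently.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class HypotheticalCacheConfig:
    """Hypothetical stand-in for DenoisingCacheConfig (field names assumed)."""

    taylorseer: bool = False
    interval: Optional[int] = None  # steps between full transformer passes
    warmup: Optional[int] = None  # initial steps that always run the transformer
    order: Optional[int] = None  # highest Taylor order kept in the cache


def check_taylorseer_config(config: HypotheticalCacheConfig) -> None:
    """Raise ValueError unless TaylorSeer is enabled and every field is resolved."""
    if not config.taylorseer:
        raise ValueError("TaylorSeerCache requires taylorseer=True")
    unresolved = [
        name for name in ("interval", "warmup", "order")
        if getattr(config, name) is None
    ]
    if unresolved:
        raise ValueError(f"unresolved TaylorSeer fields: {', '.join(unresolved)}")


# A fully resolved configuration passes without raising.
check_taylorseer_config(
    HypotheticalCacheConfig(taylorseer=True, interval=3, warmup=4, order=2)
)
```

Checking up front means an unresolved field fails at construction time, before any graph is compiled, rather than partway through a request.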

create_state()

create_state(batch_size, seq_len, output_dim)

Allocate fresh per-request TaylorSeer state buffers.

Parameters:

  • batch_size (int) – Batch dimension.
  • seq_len (int) – Sequence length (packed latent tokens).
  • output_dim (int) – Channel dimension of noise_pred.

Returns:

A new TaylorSeerBufferState with zero-initialized factor buffers on the target device.

Return type:

TaylorSeerBufferState
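For illustration, the per-request state can be pictured as one zero-filled buffer per Taylor term. This sketch is hypothetical: HypotheticalTaylorState and create_state_sketch are not MAX APIs, flattened Python lists stand in for device Buffers, and it assumes order + 1 factor buffers, each of shape (batch_size, seq_len, output_dim).

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class HypotheticalTaylorState:
    """Illustrative per-request state; not the real TaylorSeerBufferState."""

    shape: Tuple[int, int, int]  # (batch_size, seq_len, output_dim)
    factors: List[List[float]]  # factors[i]: flattened i-th order Taylor factor
    last_update_step: Optional[int] = None  # step of the most recent full pass


def create_state_sketch(
    batch_size: int, seq_len: int, output_dim: int, order: int
) -> HypotheticalTaylorState:
    """Allocate zero-initialized factor buffers for Taylor orders 0..order."""
    numel = batch_size * seq_len * output_dim
    factors = [[0.0] * numel for _ in range(order + 1)]
    return HypotheticalTaylorState((batch_size, seq_len, output_dim), factors)


state = create_state_sketch(batch_size=1, seq_len=4, output_dim=16, order=2)
print(len(state.factors), len(state.factors[0]))  # 3 64
```

Keeping the factors in a per-request object rather than on the cache itself presumably lets one TaylorSeerCache, with its graphs compiled once, serve requests whose batch size or sequence length differ.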

predict()

predict(state, step)

Predict noise_pred from cached Taylor factors.

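Parameters:

  • state (TaylorSeerBufferState) – Current per-request TaylorSeer state.
  • step (int) – Current denoising step index.
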
Returns:

Predicted noise_pred buffer, shape (B, seq, C).

Return type:

Buffer
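The prediction can be sketched as evaluating a truncated Taylor series around the step of the last full transformer pass: noise_pred(step) ≈ sum over i of factors[i] * k**i / i!, with k = step - last_update_step. This is a pure-Python illustration under that assumption, not MAX's compiled predict graph; taylor_predict is a hypothetical helper and lists stand in for Buffers.

```python
import math
from typing import List


def taylor_predict(
    factors: List[List[float]], last_update_step: int, step: int
) -> List[float]:
    """Hypothetical helper (not part of MAX): sum_i factors[i] * k**i / i!."""
    k = step - last_update_step  # steps elapsed since the last full pass
    out = [0.0] * len(factors[0])
    for i, factor in enumerate(factors):
        coeff = k**i / math.factorial(i)  # Taylor coefficient k^i / i!
        for j, value in enumerate(factor):
            out[j] += coeff * value
    return out


# Two elements, order 2, factors captured at step 3, predicted at step 5 (k = 2).
pred = taylor_predict([[1.0, 2.0], [0.5, -1.0], [0.2, 0.0]], last_update_step=3, step=5)
print(pred)  # approximately [2.4, 0.0]
```

At k = 0 the series collapses to the order-0 factor, so predicting at the update step simply returns the cached noise_pred.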

should_skip()

should_skip(step)

Return True when the full transformer pass can be skipped.

Parameters:

step (int) – Current denoising step index.

Return type:

bool
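The exact schedule behind should_skip() is not specified on this page. A common TaylorSeer-style policy, assumed here, runs the full transformer for the first warmup steps and then once every interval steps, skipping the steps in between; should_skip_sketch and its arguments are hypothetical.

```python
def should_skip_sketch(step: int, warmup: int, interval: int) -> bool:
    """Hypothetical schedule: always compute during warmup, then once per interval."""
    if step < warmup:
        return False  # warmup steps always run the full transformer
    # After warmup, run the transformer on every interval-th step and skip the rest.
    return (step - warmup) % interval != 0


skipped = [s for s in range(9) if should_skip_sketch(s, warmup=2, interval=3)]
print(skipped)  # [3, 4, 6, 7]
```

In a denoising loop, a step for which the check returns True would presumably call predict() in place of the transformer, while every other step runs the transformer and passes its output to update().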

update()

update(state, noise_pred, step)

Update Taylor factors from a full transformer computation.

Mutates state in place with the new factor values.

Parameters:

  • state (TaylorSeerBufferState) – Current per-request TaylorSeer state.
  • noise_pred (Buffer) – Fresh noise_pred from the transformer, shape (B, seq, C).
  • step (int) – Current denoising step index.

Return type:

None
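The refresh can be sketched with the finite-difference rule from the TaylorSeer method: the new order-0 factor is the fresh noise_pred, and each higher-order factor is the change in the next-lower factor divided by the number of steps since the previous full pass. This is a hypothetical pure-Python illustration, not MAX's compiled update graph: SketchState, taylor_update, and the num_valid bookkeeping are assumptions, and lists stand in for device Buffers.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SketchState:
    """Hypothetical per-request state; not the real TaylorSeerBufferState."""

    factors: List[List[float]]  # factors[i]: flattened i-th order Taylor factor
    last_update_step: Optional[int] = None  # step of the previous full pass
    num_valid: int = 0  # leading factors that hold real estimates


def taylor_update(state: SketchState, noise_pred: List[float], step: int) -> None:
    """Refresh the Taylor factors from a fresh noise_pred, mutating state."""
    order = len(state.factors) - 1
    new = [list(noise_pred)]  # order-0 term is the transformer output itself
    if state.last_update_step is not None:
        delta = step - state.last_update_step  # steps since the previous full pass
        # Backward difference: order i+1 is the per-step change of order i.
        for i in range(min(state.num_valid, order)):
            new.append([(a - b) / delta for a, b in zip(new[i], state.factors[i])])
    # Orders that cannot be estimated yet stay at zero.
    padding = [[0.0] * len(noise_pred) for _ in range(order + 1 - len(new))]
    state.factors = new + padding
    state.num_valid = len(new)
    state.last_update_step = step


# Linear signal f(t) = 2t + 1 with order-1 factors, full passes at steps 0 and 3.
state = SketchState(factors=[[0.0], [0.0]])
taylor_update(state, [1.0], step=0)
taylor_update(state, [7.0], step=3)
print(state.factors)  # [[7.0], [2.0]]
```

With these factors, the series 7.0 + 2.0 * k reproduces f(3 + k) exactly for a linear signal; for curved signals the backward differences are only approximations of the true derivatives.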