IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Python class

InferenceSession

class max.engine.InferenceSession(devices=(), num_threads=None, *, custom_extensions=None)

Bases: object

Manages an inference session in which you can load and run models.

You need an InferenceSession instance to load a model as a Model object. For example:

from max import engine
from max.driver import CPU
session = engine.InferenceSession(devices=[CPU()])
model = session.load(model_path)  # model_path: a Graph or a saved model file

For workflows that need to separate compilation from weight binding, use compile() followed by init() or init_all(). For example:

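from max import engine
from max.driver import CPU
session = engine.InferenceSession(devices=[CPU()])
compiled = session.compile(model_path)  # compile only; no weights bound yet
model = session.init(compiled)  # pass weights_registry=... to bind weights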

Parameters:

  • devices (Iterable[Device]) – A list of devices on which to run inference. The host CPU is always included automatically.
  • num_threads (int | None) – The number of execution threads. Defaults to None, which lets the runtime choose automatically.
  • custom_extensions (CustomExtensionsType | None) – The extensions to load for the model. Supports paths to a .mojoc/.mojopkg custom ops library or a .mojo source file.

compile()

compile(model, *, custom_extensions=None)

Compiles a model without binding weights or device memory.

Use this to separate compilation from initialization, for example to populate a compile cache ahead of time, including when cross-compiling for a target device that isn't attached. The returned CompiledModel must be initialized before it can execute: pass it to init() to get a Model, or to init_all() to get one Model per graph in the artifact.

Parameters:

  • model (str | Path | Module | Graph) – A Graph instance, a max.graph.Module containing one or more mo.graph ops, or the path to a saved model file (for example, a .mef file).
  • custom_extensions (CustomExtensionsType | None) – The extensions to load for the model. Supports paths to .mojopkg custom ops.

Returns:

A CompiledModel artifact ready to be initialized.

Raises:

RuntimeError – If the path provided is invalid or compilation fails.

Return type:

CompiledModel

debug

debug: DebugConfig = <max.engine.DebugConfig object>

devices

property devices: list[Device]

The devices available to the session, including the host CPU.

enable_per_tensor_fp8_quantize()

enable_per_tensor_fp8_quantize(mode)

Enables per-tensor FP8 quantization.

Parameters:

mode (str) – The enable/disable flag. Accepts "false", "off", "no", or "0" to disable. Any other value enables per-tensor FP8 quantization.
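The documented rule for this `mode` string is a denylist: only the four listed values disable the feature, so anything else, including "disabled" or a typo such as "of", turns it on. The same rule is documented for use_fi_topk_kernel() and use_old_top_k_kernel(). The sketch below restates that rule as a hypothetical helper, `flag_enables()`, which is not part of the MAX API; it assumes exact, case-sensitive matching, which the docs do not specify.

```python
# Hypothetical helper that restates the documented rule for string mode flags
# (enable_per_tensor_fp8_quantize, use_fi_topk_kernel, use_old_top_k_kernel).
# NOT part of MAX. Assumption: exact, case-sensitive matching; the real runtime
# may normalize case or whitespace.

DISABLE_VALUES = frozenset({"false", "off", "no", "0"})


def flag_enables(mode: str) -> bool:
    """Return True if `mode` would enable the feature under the documented rule."""
    return mode not in DISABLE_VALUES


if __name__ == "__main__":
    for mode in ["off", "0", "on", "true", "disabled", "of"]:
        state = "enabled" if flag_enables(mode) else "disabled"
        print(f"{mode!r:12} -> {state}")
```

To turn one of these features off, pass one of the four disable strings exactly; any other value is treated as a request to enable it.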

Return type:

None

gpu_profiling()

gpu_profiling(mode)

Enables GPU profiling instrumentation for the session.

Works with NVIDIA Nsight Systems and Nsight Compute. When enabled, the runtime adds CUDA driver calls and NVTX markers that allow profiling tools to correlate GPU kernel executions with host-side code.

For example, to enable detailed profiling for Nsight Systems analysis, call gpu_profiling() before load():

from max.engine import InferenceSession
from max.driver import Accelerator

session = InferenceSession(devices=[Accelerator()])
session.gpu_profiling("detailed")
model = session.load(my_graph)  # my_graph: a max.graph.Graph you built earlier

Then run it with nsys:

nsys profile --trace=cuda,nvtx python example.py

Instead of calling gpu_profiling() in code, you can set the MODULAR_ENABLE_PROFILING environment variable when you call nsys profile:

MODULAR_ENABLE_PROFILING=detailed nsys profile --trace=cuda,nvtx python example.py

If you use both, the mode passed to gpu_profiling() takes precedence over the MODULAR_ENABLE_PROFILING environment variable.
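The precedence between the method call and the environment variable can be summarized as a small resolver. This is a hypothetical sketch built only from the documented behavior (an explicit gpu_profiling() call wins, otherwise the environment variable applies, otherwise profiling is "off"); `effective_profiling_mode()` is not a MAX function, and raising ValueError for an unrecognized value is an assumption about how invalid modes might be handled.

```python
# Hypothetical resolver for the effective GPU profiling mode, based only on the
# documented precedence. NOT MAX's implementation; rejecting unknown values
# with ValueError is an assumption.
import os
from typing import Literal, Mapping, Optional, get_args

ProfilingMode = Literal["off", "on", "detailed"]
VALID_MODES = get_args(ProfilingMode)  # ("off", "on", "detailed")


def effective_profiling_mode(
    call_mode: Optional[str],
    env: Mapping[str, str] = os.environ,
) -> str:
    """Return the mode that applies, given an optional gpu_profiling() argument."""
    if call_mode is not None:
        mode = call_mode  # an explicit gpu_profiling() call wins
    else:
        mode = env.get("MODULAR_ENABLE_PROFILING", "off")  # documented default
    if mode not in VALID_MODES:
        raise ValueError(f"unknown profiling mode {mode!r}; expected one of {VALID_MODES}")
    return mode


if __name__ == "__main__":
    # The environment asks for "detailed", but the explicit call asks for "on".
    print(effective_profiling_mode("on", {"MODULAR_ENABLE_PROFILING": "detailed"}))  # prints: on
```

In practice this means you can leave MODULAR_ENABLE_PROFILING set in a launch script and still force a specific mode for one session from code.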

Learn more in GPU profiling with Nsight Systems.

Parameters:

mode (Literal['off', 'on', 'detailed']) –

The profiling mode to set. One of:

  • off: Disable profiling (default).
  • on: Enable basic profiling with NVTX markers for kernel correlation.
  • detailed: Enable detailed profiling with additional Python-level NVTX markers.

Return type:

None

init()

init(compiled, *, weights_registry=None)

Initializes a compiled model with weights for execution.

Use this to complete the second half of a compile()/init() pair when the artifact contains a single model. For artifacts with more than one model, use init_all().

Parameters:

  • compiled (CompiledModel) – The compiled artifact returned by compile().
  • weights_registry (Mapping[str, DLPackArray] | None) – A mapping from model weight names to their values. The values should be DLPack arrays. If an array is a read-only NumPy array, you must ensure that its lifetime extends beyond the lifetime of the model. Although weights_registry is technically optional, you’ll always need to load weights in practice.
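A common way to end up with read-only NumPy arrays is memory-mapping weight files with `np.load(..., mmap_mode="r")`, which is exactly the case the lifetime note above covers. The sketch below assumes a hypothetical layout of one `.npy` file per weight and hypothetical weight names; real names come from your graph's weights. The MAX call is shown only as a comment because it needs an installed runtime and a model.

```python
# Minimal sketch: build a weights_registry from read-only NumPy memmaps.
# The weight names and one-.npy-file-per-weight layout are hypothetical.
import tempfile
from pathlib import Path

import numpy as np

WEIGHT_NAMES = ("linear.weight", "linear.bias")  # hypothetical names


def build_weights_registry(weights_dir: Path) -> dict[str, np.ndarray]:
    """Memory-map one .npy file per weight; each array is a read-only view of its file."""
    registry: dict[str, np.ndarray] = {}
    for name in WEIGHT_NAMES:
        array = np.load(weights_dir / f"{name}.npy", mmap_mode="r")
        if not hasattr(array, "__dlpack__"):  # NumPy >= 1.22 implements DLPack
            raise TypeError(f"{name} does not support the DLPack protocol")
        registry[name] = array
    return registry


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        weights_dir = Path(tmp)
        np.save(weights_dir / "linear.weight.npy", np.ones((4, 4), dtype=np.float32))
        np.save(weights_dir / "linear.bias.npy", np.zeros(4, dtype=np.float32))

        weights = build_weights_registry(weights_dir)
        print({name: a.flags.writeable for name, a in weights.items()})

        # With MAX installed, you would bind the weights (not run here):
        #   model = session.init(compiled, weights_registry=weights)
        # Keep `weights` referenced for as long as `model` is in use, as the
        # docs require for read-only arrays.
        del weights  # demo only: release the memmaps before the files are removed
```

Holding the registry on the same object that owns the model (for example, as an attribute of a wrapper class) is one simple way to tie the two lifetimes together.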

Returns:

The initialized Model, ready to execute.

Return type:

Model

init_all()

init_all(compiled, *, weights_registry=None)

Initializes all models in a compiled artifact for execution.

Use this to complete the second half of a compile()/init_all() pair. Returns one Model per top-level graph in the artifact, keyed by sym_name.

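Parameters:

  • compiled (CompiledModel) – The compiled artifact returned by compile().
  • weights_registry (Mapping[str, DLPackArray] | None) – A mapping from model weight names to their values. The values should be DLPack arrays. If an array is a read-only NumPy array, you must ensure that its lifetime extends beyond the lifetime of the model. Although weights_registry is technically optional, you'll always need to load weights in practice.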

Returns:

A mapping from each model’s sym_name to its initialized Model, ready to execute.

Return type:

dict[str, Model]

load()

load(model, *, custom_extensions=None, weights_registry=None)

Loads a trained model and compiles it for inference.

Parameters:

  • model (str | Path | Graph) – A Graph instance, or the path to a saved model file (for example, a .mef file).
  • custom_extensions (CustomExtensionsType | None) – The extensions to load for the model. Supports paths to .mojoc/.mojopkg custom ops.
  • weights_registry (Mapping[str, DLPackArray] | None) – A mapping from model weight names to their values. The values should be DLPack arrays. If an array is a read-only NumPy array, you must ensure that its lifetime extends beyond the lifetime of the model. Although weights_registry is technically optional, you’ll always need to load weights in practice.

Returns:

The loaded model, compiled and ready to execute.

Raises:

RuntimeError – If the path provided is invalid.

Return type:

Model

load_all()

load_all(model, *, custom_extensions=None, weights_registry=None)

Loads one or more models and compiles them for inference.

A compiled .mef artifact may contain more than one model (for example, a vision encoder and a language model compiled together). This method returns one Model per model encoded in the artifact, keyed by the sym_name of the corresponding mo.graph op (preserved through MEF serialization). For single-model artifacts, the returned dict has exactly one entry.
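Because load_all() (and init_all()) return a dict keyed by sym_name, callers usually either unpack the single entry or look up specific models by name. The helpers below are hypothetical conveniences, not part of MAX, and the sym_names "vision_encoder" and "language_model" are made up for illustration; the demo uses placeholder strings where real code would receive Model objects.

```python
# Hypothetical helpers for the dict returned by load_all() / init_all().
# The sym_names in the demo are made up; real keys are the sym_name
# attributes of the mo.graph ops in your artifact.
from typing import Mapping, TypeVar

M = TypeVar("M")


def only_model(models: Mapping[str, M]) -> M:
    """Return the single model from a single-model artifact."""
    if len(models) != 1:
        raise ValueError(f"expected exactly one model, got {sorted(models)}")
    (model,) = models.values()
    return model


def require_models(models: Mapping[str, M], *names: str) -> tuple[M, ...]:
    """Look up models by sym_name, reporting every missing name at once."""
    missing = [name for name in names if name not in models]
    if missing:
        raise KeyError(f"missing {missing}; available: {sorted(models)}")
    return tuple(models[name] for name in names)


if __name__ == "__main__":
    # Placeholder values; with MAX these would be Model objects.
    models = {"vision_encoder": "<Model A>", "language_model": "<Model B>"}
    encoder, decoder = require_models(models, "vision_encoder", "language_model")
    print(encoder, decoder)
    print(only_model({"main": "<Model>"}))
```

Failing with the list of available keys makes a mismatched sym_name easy to diagnose, since the names are preserved from the original mo.graph ops rather than chosen at load time.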

Parameters:

  • model (str | Path | Module | Graph) – A max.graph.Module containing one or more mo.graph ops, the path to a saved multi-model file (for example, a .mef file), or a single Graph.
  • custom_extensions (CustomExtensionsType | None) – The extensions to load for the model. Supports paths to .mojoc/.mojopkg custom ops.
  • weights_registry (Mapping[str, DLPackArray] | None) – A mapping from model weight names to their values. The values should be DLPack arrays. If an array is a read-only NumPy array, you must ensure that its lifetime extends beyond the lifetime of the model. Although weights_registry is technically optional, you’ll always need to load weights in practice.

Returns:

A mapping from each model’s sym_name to its loaded Model, ready to execute.

Raises:

RuntimeError – If the path provided is invalid.

Return type:

dict[str, Model]

set_debug_print_options()

set_debug_print_options(style=PrintStyle.COMPACT, precision=6, output_directory=None)

Sets the debug print options.

These options apply to debug printing in every model executed through the same InferenceSession. See print().

Tensors saved with BINARY can be loaded with max.driver.Buffer.mmap(), but you must supply the expected dtype and shape. Tensors saved with BINARY_MAX_CHECKPOINT store their shape and dtype alongside the data and can be loaded with max.driver.buffer.load_max_buffer().
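The reason BINARY output needs a dtype and shape at load time is that raw bytes carry no metadata. The sketch below illustrates this with NumPy as a stand-in: `ndarray.tofile()` plays the role of BINARY and `np.save()` plays the role of BINARY_MAX_CHECKPOINT. It does not use MAX, its file formats, or the Buffer.mmap()/load_max_buffer() signatures; it only shows the general trade-off.

```python
# NumPy analogy for BINARY vs. BINARY_MAX_CHECKPOINT. This does not use MAX
# or its file formats; tofile() and np.save() are stand-ins.
import tempfile
from pathlib import Path

import numpy as np

tensor = np.arange(6, dtype=np.float32).reshape(2, 3)  # 24 bytes of data

with tempfile.TemporaryDirectory() as tmp:
    raw_path = Path(tmp) / "tensor.bin"
    npy_path = Path(tmp) / "tensor.npy"

    tensor.tofile(raw_path)  # raw bytes only: no dtype, no shape
    np.save(npy_path, tensor)  # header + bytes: dtype and shape included

    # Raw bytes: the reader must supply dtype and shape up front.
    raw = np.memmap(raw_path, dtype=np.float32, mode="r", shape=(2, 3))
    raw_matches = bool(np.array_equal(raw, tensor))
    del raw  # release the mapping before the directory is deleted

    # Guessing the wrong dtype silently reinterprets the same 24 bytes.
    wrong = np.fromfile(raw_path, dtype=np.float64)

    # Self-describing file: dtype and shape come back automatically.
    restored = np.load(npy_path)

print(raw_matches, wrong.shape, restored.dtype, restored.shape)
```

The wrong-dtype read does not fail; it just returns three float64 values instead of six float32 values, which is why recording the expected dtype and shape next to BINARY dumps is worthwhile.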

Parameters:

  • style (str | PrintStyle) – The print style for tensor values. One of COMPACT, FULL, BINARY, BINARY_MAX_CHECKPOINT, or NONE.
  • precision (int) – The number of digits of precision in the output, used when style is FULL.
  • output_directory (str | Path | None) – The directory to store output tensors, used when style is BINARY or BINARY_MAX_CHECKPOINT.

Raises:

  • TypeError – If style is not a valid PrintStyle, if precision is not an int when style is FULL, or if output_directory is not a str or Path.
  • ValueError – If output_directory is empty when style is BINARY or BINARY_MAX_CHECKPOINT.
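The error conditions listed above can be checked before calling the method. The sketch below is a hypothetical pre-flight validator that mirrors those documented rules; it is not MAX's implementation. `PrintStyle` here is a local stand-in enum with the documented member names (its values are invented), string styles are assumed to match member names exactly, and an "empty" output_directory is assumed to mean None or "".

```python
# Hypothetical pre-flight check mirroring the documented validation rules of
# set_debug_print_options(). NOT MAX's implementation. PrintStyle is a local
# stand-in with the documented member names; exact name matching and treating
# None or "" as "empty" are assumptions.
from enum import Enum
from pathlib import Path
from typing import Optional, Union


class PrintStyle(Enum):
    COMPACT = "COMPACT"
    FULL = "FULL"
    BINARY = "BINARY"
    BINARY_MAX_CHECKPOINT = "BINARY_MAX_CHECKPOINT"
    NONE = "NONE"


def check_debug_print_options(
    style: Union[str, PrintStyle] = PrintStyle.COMPACT,
    precision: int = 6,
    output_directory: Optional[Union[str, Path]] = None,
) -> PrintStyle:
    """Return the style as a PrintStyle, raising the documented error types."""
    if isinstance(style, str):
        try:
            style = PrintStyle[style]  # look up by member name
        except KeyError:
            raise TypeError(f"invalid print style: {style!r}") from None
    elif not isinstance(style, PrintStyle):
        raise TypeError(f"style must be str or PrintStyle, got {type(style).__name__}")

    if style is PrintStyle.FULL and not isinstance(precision, int):
        raise TypeError("precision must be an int when style is FULL")

    if output_directory is not None and not isinstance(output_directory, (str, Path)):
        raise TypeError("output_directory must be a str or Path")

    binary_styles = (PrintStyle.BINARY, PrintStyle.BINARY_MAX_CHECKPOINT)
    if style in binary_styles and not output_directory:
        raise ValueError(f"output_directory is required when style is {style.name}")

    return style


if __name__ == "__main__":
    print(check_debug_print_options("BINARY", output_directory="/tmp/tensors"))
```

Running a check like this early, for example while parsing command-line flags, surfaces a missing output directory before a long model load rather than after it.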

Return type:

None

set_mojo_assert_level()

set_mojo_assert_level(level)

Sets which Mojo asserts are kept in the compiled model.

Parameters:

level (AssertLevel) – The assert level to use. One of AssertLevel.NONE, AssertLevel.WARN, AssertLevel.SAFE, or AssertLevel.ALL.

Return type:

None

set_mojo_log_level()

set_mojo_log_level(level)

Sets the verbosity of Mojo logging in the compiled model.

Parameters:

level (str | LogLevel) – The log level to use, given as a LogLevel member or its name as a string.

Raises:

TypeError – If level is not a valid LogLevel member or name.

Return type:

None

set_split_k_reduction_precision()

set_split_k_reduction_precision(precision)

Sets the accumulation precision for split-k reductions in large matmuls.

Parameters:

precision (str | SplitKReductionPrecision) – The accumulation precision to use, given as a SplitKReductionPrecision member or its name as a string.

Raises:

TypeError – If precision is not a valid SplitKReductionPrecision member or name.

Return type:

None

use_fi_topk_kernel()

use_fi_topk_kernel(mode)

Enables the fused-inference top-k kernel.

Parameters:

mode (str) – The enable/disable flag. Accepts "false", "off", "no", or "0" to disable. Any other value enables the fused-inference top-k kernel.

Return type:

None

use_old_top_k_kernel()

use_old_top_k_kernel(mode)

Falls back to the previous top-k kernel implementation.

By default, the session uses a newer top-k kernel. Use this fallback only if you encounter correctness or performance issues with the default kernel.

Parameters:

mode (str) – The enable/disable flag. Accepts "false", "off", "no", or "0" to disable. Any other value enables the old top-k kernel.

Return type:

None