MAX Engine Python API

Preview

The MAX Engine Python API reference.

MAX Engine is coming in Q1 2024.

You can run an inference with our Python API in just a few lines of code:

  1. Create an InferenceSession.

  2. Load a TensorFlow or PyTorch model with InferenceSession.load(), which returns a Model.

  3. Run the model by passing your input to Model.execute(), which returns the output.

That’s it! For more detail, see the Python get started guide.
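The three steps above can be condensed into one helper. This is a minimal sketch: `run_inference` is a hypothetical name, not part of the MAX Engine API. It accepts any object that behaves like `InferenceSession`, so it can be exercised without MAX Engine installed.

```python
from pathlib import Path
from typing import Any, Dict, Union


def run_inference(session: Any, model_path: Union[str, Path], **inputs: Any) -> Dict[str, Any]:
    """Load a model and execute it once (hypothetical helper).

    `session` is expected to behave like modular.engine.InferenceSession:
    session.load() returns a Model whose execute() takes keyword inputs
    and returns a dict of outputs keyed by tensor name.
    """
    model = session.load(model_path)  # Step 2: load and compile the model.
    return model.execute(**inputs)    # Step 3: run it and return the outputs.


# Usage with MAX Engine installed (not executed here):
# from modular import engine
# session = engine.InferenceSession()  # Step 1
# outputs = run_inference(session, "path/to/model", input=input_tensor)
```

Keeping the session as a parameter lets you reuse one session for several models instead of creating a new one per call.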

class modular.engine.InferenceSession(num_threads: Optional[int] = None, device: Optional[str] = None)

Manages an inference session in which you can load and run models.

You need an instance of this to load a model as a Model object. For example:

from pathlib import Path

from modular import engine

session = engine.InferenceSession()
model_path = Path('bert-base-uncased')
# Load the model and compile it for inference.
model = session.load(model_path)

Parameters:

num_threads (Optional[int]) – Number of threads to use for the inference session. This parameter defaults to the number of physical cores on your machine.
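If you want to leave some CPUs free for other work, you can pass an explicit `num_threads`. The sketch below shows one possible policy; `pick_num_threads` and its `reserve` parameter are hypothetical. Note that `os.cpu_count()` reports logical CPUs, which can exceed the number of physical cores that the session uses by default.

```python
import os


def pick_num_threads(reserve: int = 1) -> int:
    """Return a thread count that leaves `reserve` logical CPUs free (hypothetical policy)."""
    logical_cpus = os.cpu_count() or 1  # cpu_count() may return None
    return max(1, logical_cpus - reserve)


# Usage with MAX Engine installed (not executed here):
# from modular import engine
# session = engine.InferenceSession(num_threads=pick_num_threads())
```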

load(model_path: Union[str, Path], *options: Union[TensorFlowLoadOptions, CommonLoadOptions], **kwargs) → Model

Loads a trained model and compiles it for inference.

Parameters:
  • model_path (Union[str, pathlib.Path]) – Path to a model. May be a TensorFlow model in the SavedModel format or a traceable PyTorch model.

  • *options (Union[TensorFlowLoadOptions, CommonLoadOptions]) – Load options for configuring how the model should be compiled.

Returns:

The loaded model, compiled and ready to execute.

Return type:

Model

Raises:

RuntimeError – If the path provided is invalid.
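Because `load()` raises `RuntimeError` for an invalid path, it can help to check the path first and report which path failed. This is a sketch under assumptions: `load_model` is a hypothetical wrapper, `session` is assumed to be an `InferenceSession`, and any load options are forwarded to `load()` unchanged.

```python
from pathlib import Path
from typing import Any, Union


def load_model(session: Any, model_path: Union[str, Path], *options: Any) -> Any:
    """Load a model, failing early with a clear message if the path is missing."""
    path = Path(model_path)
    if not path.exists():
        # Catch the most common mistake before handing the path to the engine.
        raise FileNotFoundError(f"model path does not exist: {path}")
    try:
        return session.load(path, *options)
    except RuntimeError as err:
        # The engine rejected the path; keep its message and add the path.
        raise RuntimeError(f"could not load model from {path}: {err}") from err


# Usage with MAX Engine installed (not executed here):
# from modular import engine
# session = engine.InferenceSession()
# model = load_model(session, "bert-base-uncased",
#                    engine.TensorFlowLoadOptions(exported_name="serving_default"))
```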

class modular.engine.Model

A loaded model that you can execute.

Do not instantiate this class directly. Instead, create it with InferenceSession.load().

execute(*args, **kwargs) → Dict[str, ndarray]

Executes the model with the provided input and returns outputs.

For example, if the model has one input tensor named “input”:

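import numpy as np

# Random data with shape (1, 224, 224, 3): a batch of one 224x224 input with 3 channels.
input_tensor = np.random.rand(1, 224, 224, 3)
outputs = model.execute(input=input_tensor)

Parameters: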

kwargs – The input tensors, each passed as an ndarray with the corresponding tensor name as its keyword. You can find the model’s input tensor names with Model.input_metadata.

Returns:

A dictionary of output tensors, each as an ndarray identified by its tensor name.

Return type:

Dict[str, ndarray]

Raises:
  • RuntimeError – If the given input tensors’ names or shapes don’t match what the model expects.

  • TypeError – If the given input tensors’ dtype cannot be cast to what the model expects.
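Since execute() raises RuntimeError when input names or shapes don’t match, you may want to check inputs yourself first and get a readable report. The sketch below compares input shapes against TensorSpec-like objects, meaning anything with `name` and `shape` attributes, such as the items of Model.input_metadata. The function names are hypothetical. The sketch treats a `None` dimension as matching any size and does not check dtypes.

```python
from typing import Any, Dict, Iterable, List, Optional, Sequence


def dims_match(expected: Sequence[Optional[int]], actual: Sequence[int]) -> bool:
    """True if the ranks agree and every known dimension matches; None matches any size."""
    return len(expected) == len(actual) and all(
        e is None or e == a for e, a in zip(expected, actual)
    )


def find_input_problems(specs: Iterable[Any], input_shapes: Dict[str, Sequence[int]]) -> List[str]:
    """Compare input shapes with TensorSpec-like objects that have .name and .shape."""
    problems: List[str] = []
    expected_names = set()
    for spec in specs:
        expected_names.add(spec.name)
        if spec.name not in input_shapes:
            problems.append(f"missing input {spec.name!r}")
        elif not dims_match(spec.shape, input_shapes[spec.name]):
            problems.append(
                f"input {spec.name!r}: expected shape {list(spec.shape)}, "
                f"got {list(input_shapes[spec.name])}"
            )
    # Extra names also count as a name mismatch.
    for name in sorted(set(input_shapes) - expected_names):
        problems.append(f"unexpected input {name!r}")
    return problems


# Usage with a loaded model (not executed here):
# problems = find_input_problems(model.input_metadata, {"input": input_tensor.shape})
```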

property input_metadata: List[TensorSpec]

Metadata about the model’s input tensors, as a list of TensorSpec objects.

For example, you can print the input tensor names, shapes, and dtypes:

for tensor in model.input_metadata:
    print(f'name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}')
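Before building inputs, it helps to know which dimensions are dynamic, because you must choose a size for each of them. A minimal sketch; `dynamic_dims` is a hypothetical helper that only assumes each spec has `name` and `shape` attributes, with `None` marking a dynamic dimension.

```python
from typing import Any, Dict, Iterable, List


def dynamic_dims(specs: Iterable[Any]) -> Dict[str, List[int]]:
    """Map each tensor name to the indices of its dynamic (None) dimensions."""
    return {
        spec.name: [axis for axis, size in enumerate(spec.shape) if size is None]
        for spec in specs
    }


# Usage with a loaded model (not executed here):
# print(dynamic_dims(model.input_metadata))  # axes that need a concrete size, such as a dynamic batch axis
```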

property output_metadata: List[TensorSpec]

Metadata about the model’s output tensors, as a list of TensorSpec objects.

For example, you can print the output tensor names, shapes, and dtypes:

for tensor in model.output_metadata:
    print(f'name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}')
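Because execute() returns a dictionary keyed by tensor name, code that expects positional outputs can use output_metadata to fix the order. A sketch; `ordered_outputs` is hypothetical and only assumes each spec has a `name` attribute.

```python
from typing import Any, Dict, Iterable, List


def ordered_outputs(outputs: Dict[str, Any], specs: Iterable[Any]) -> List[Any]:
    """Return execute() outputs as a list, in the order given by output_metadata.

    Raises KeyError naming the first output tensor that is missing.
    """
    result = []
    for spec in specs:
        if spec.name not in outputs:
            raise KeyError(f"output {spec.name!r} not found in execute() results")
        result.append(outputs[spec.name])
    return result


# Usage with a loaded model (not executed here):
# values = ordered_outputs(model.execute(input=input_tensor), model.output_metadata)
```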

class modular.engine.TensorSpec(shape: List[Optional[int]], dtype: DType, name: str)

Defines the properties of a tensor, including its name, shape and data type.

For usage examples, see Model.input_metadata and Model.output_metadata.

property dtype: DType

The tensor’s data type.

property name: str

The tensor’s name.

property shape: List[Optional[int]]

The shape of the tensor as a list of dimension sizes.

If a dimension size is unknown/dynamic (such as the batch size), its value is None.
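A dynamic dimension must be given a concrete size before you can allocate an input. A minimal sketch, assuming you fill every `None` with the same value (typically the batch size); `concrete_shape` and `num_elements` are hypothetical helpers.

```python
import math
from typing import List, Optional, Sequence


def concrete_shape(shape: Sequence[Optional[int]], fill: int = 1) -> List[int]:
    """Replace each dynamic (None) dimension with `fill`, keeping known sizes."""
    if fill < 1:
        raise ValueError("fill must be a positive dimension size")
    return [fill if size is None else size for size in shape]


def num_elements(shape: Sequence[int]) -> int:
    """Number of elements in a tensor with a fully known shape."""
    return math.prod(shape)


# Usage with a loaded model (not executed here; assumes a float32 input):
# shape = concrete_shape(model.input_metadata[0].shape, fill=8)
# input_tensor = np.zeros(shape, dtype=np.float32)
```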

class modular.engine.TensorFlowLoadOptions(exported_name: str = 'serving_default', type: str = 'tf')

Configures how to load TensorFlow models.

exported_name: str = 'serving_default'

The name of the exported signature to load from the TensorFlow model.

type: str = 'tf'

class modular.engine.CommonLoadOptions(custom_ops_path: str = '')

Common options for how to load models.

custom_ops_path: str = ''

The path from which to load custom ops.
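Both option classes are passed positionally to InferenceSession.load() through `*options`. The sketch below assembles them. `load_options_for` is hypothetical, and it assumes that a directory containing `saved_model.pb` is a TensorFlow SavedModel, in which case it adds TensorFlowLoadOptions. The `engine` argument is meant to be the `modular.engine` module.

```python
from pathlib import Path
from typing import Any, List, Union


def load_options_for(engine: Any, model_path: Union[str, Path],
                     exported_name: str = "serving_default",
                     custom_ops_path: str = "") -> List[Any]:
    """Build the *options list for InferenceSession.load() (hypothetical helper)."""
    options: List[Any] = []
    # A binary TensorFlow SavedModel directory contains a saved_model.pb file.
    if (Path(model_path) / "saved_model.pb").is_file():
        options.append(engine.TensorFlowLoadOptions(exported_name=exported_name))
    if custom_ops_path:
        options.append(engine.CommonLoadOptions(custom_ops_path=custom_ops_path))
    return options


# Usage with MAX Engine installed (not executed here; the paths are hypothetical):
# from modular import engine
# session = engine.InferenceSession()
# path = "path/to/saved_model"
# model = session.load(path, *load_options_for(engine, path, custom_ops_path="path/to/custom_ops"))
```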

class modular.engine.DType(value)

The tensor data type.

bool = 0
int8 = 1
int16 = 2
int32 = 3
int64 = 4
uint8 = 5
uint16 = 6
uint32 = 7
uint64 = 8
float16 = 9
float32 = 10
float64 = 11
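The DType member names coincide with NumPy dtype names, which makes it easy to estimate how much memory a tensor needs. A sketch, assuming one byte per `bool` element as in NumPy and assuming you pass the member name (for example `spec.dtype.name`, if DType is a standard Python enum); `DTYPE_ITEMSIZE` and `tensor_nbytes` are hypothetical.

```python
import math
from typing import Dict, Sequence

# Element size in bytes for each DType member, keyed by member name.
DTYPE_ITEMSIZE: Dict[str, int] = {
    "bool": 1,
    "int8": 1, "int16": 2, "int32": 4, "int64": 8,
    "uint8": 1, "uint16": 2, "uint32": 4, "uint64": 8,
    "float16": 2, "float32": 4, "float64": 8,
}


def tensor_nbytes(shape: Sequence[int], dtype_name: str) -> int:
    """Memory needed for a tensor's elements, given a fully known shape."""
    return math.prod(shape) * DTYPE_ITEMSIZE[dtype_name]
```

For example, a float32 input of shape (1, 224, 224, 3) holds 150,528 elements, or 602,112 bytes.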