AI Engine Python demo

A Jupyter notebook that runs a TensorFlow BERT model and a PyTorch DLRM model.

MAX Engine is coming in Q1 2024. Sign up for updates.

The Modular AI Engine is the world’s fastest unified inference engine, designed to run any TensorFlow or PyTorch model on any hardware backend.

This page is built from the same Jupyter notebook that Nick Kreeger presented in our launch keynote video, in which he shows how easy it is to load trained TensorFlow and PyTorch models and run them with our Python API on a variety of CPU backends. We're sharing this executed version of the notebook so you can look closely at the code from the video. Note that this notebook is out of date with our current Python API and will not be maintained; it still reflects the code we showed in the launch keynote video from May 2023.

Below, you can see how we load one TensorFlow model (a BERT model) and one PyTorch model (a DLRM model) into the Modular AI Engine, and then print some model metadata and execute each one.

If you’d like to see performance benchmarks with more real-world models, see our performance dashboard.

Notebook code

import numpy as np
from pathlib import Path
from bert_utils import convert_to_tokens, convert_to_string

TensorFlow BERT-Base Model

tf_bert_model = Path("models/tensorflow/bert")

PyTorch DLRM Recommender Model

pt_dlrm_model = Path("models/pytorch/")

Virtual Machine Information in AWS

Check the concrete machine configuration.

print("="*40, "Processor Information", "="*40, "\n")
!lscpu | grep "Model name"
!lscpu | grep Architecture
======================================== Processor Information ======================================== 

Model name:                      Neoverse-N1
Architecture:                    aarch64
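The same information can be read without shelling out to lscpu; a minimal sketch using Python's standard library (note that platform.processor() may return an empty string on some systems, so we rely on platform.machine() here):

```python
import platform

# Pure-Python equivalent of the lscpu checks above: report the CPU
# architecture the notebook is running on (e.g. "aarch64" on Neoverse-N1).
print("Architecture:", platform.machine())
```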

Import the Modular Python API

By default, the inference engine is small and has very few dependencies. It automatically loads the TensorFlow and PyTorch dependencies only when they are needed.

from modular import engine
session = engine.InferenceSession()

Load and initialize both the TensorFlow and PyTorch models

This process handles loading all framework dependencies for you; the models are ready for inference once loaded.

tf_bert_session = session.load(tf_bert_model)
pt_dlrm_session = session.load(pt_dlrm_model)

Run inference on both the TensorFlow BERT and PyTorch DLRM models

The Modular Python API interoperates with libraries such as NumPy, making it easy to prepare inputs for models.

# Run BERT TensorFlow model with a given question.
question = "When did Copenhagen become the capital of Denmark?"
attention_mask, input_ids, token_type_ids = convert_to_tokens(question)
bert_outputs = tf_bert_session.execute(attention_mask, input_ids, token_type_ids)

# Run the DLRM PyTorch model with random 0/1 (one-hot style) rows of suggested items and features.
# Note: np.random.rand(...).astype(np.int32) would truncate every value to 0, so use randint instead.
recommended_items = np.random.randint(0, 2, size=(4, 8, 100)).astype(np.int32)
dense_features = np.random.rand(4, 256).astype(np.float32)
dlrm_outputs = pt_dlrm_session.execute(dense_features, recommended_items)
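Because the engine consumes NumPy arrays directly, it helps to validate dtypes and shapes before calling execute; a small sketch mirroring the DLRM dense-feature input above (the expected dtype and shape are taken from this example, not a general requirement):

```python
import numpy as np

# Build the dense-feature input and sanity-check it before inference:
# the model expects float32 values with a (batch, features) layout.
dense_features = np.random.rand(4, 256).astype(np.float32)

assert dense_features.dtype == np.float32
assert dense_features.shape == (4, 256)
print(dense_features.dtype, dense_features.shape)
```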

Inspecting the output of BERT

The Modular Python API provides access to shapes, dtypes, and tensor output values. This example takes the outputs from BERT and converts the output tokens to strings.

print("Number of output tensors:", len(bert_outputs))
print(bert_outputs[0].shape, bert_outputs[0].dtype)
print(bert_outputs[1].shape, bert_outputs[1].dtype)

print("Answer:", convert_to_string(input_ids, bert_outputs))
Number of output tensors: 2
(1, 192) float32
(1, 192) float32
Answer: Copenhagen became the capital of Denmark in the early 15th century
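The two (1, 192) float32 outputs are the start- and end-position logits typical of a BERT question-answering head. As a hypothetical sketch of what a helper like convert_to_string might do internally (the real implementation lives in bert_utils and is not shown here), the answer span can be recovered by taking the argmax of each logit vector:

```python
import numpy as np

# Hypothetical span decoding for a BERT QA model: the answer runs from
# argmax(start_logits) to argmax(end_logits), inclusive. Random logits
# stand in for real model outputs here.
start_logits = np.random.rand(1, 192).astype(np.float32)
end_logits = np.random.rand(1, 192).astype(np.float32)

start = int(start_logits[0].argmax())
end = int(end_logits[0].argmax())
span = (min(start, end), max(start, end))  # guard against an inverted span
print("Answer token span:", span)
```

The token IDs inside that span would then be mapped back to strings with the tokenizer's vocabulary.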

Inspecting the output of DLRM

The PyTorch DLRM model's output is accessed through the same API as the BERT example above.

print("Number of output tensors:", len(dlrm_outputs))
print(dlrm_outputs[0].shape, dlrm_outputs[0].dtype)

dlrm_suggested_items = ["dog", "cat", "rabbit", "snake"]
dlrm_recommended_index = dlrm_outputs[0].argmax()
print("Recommend item index:", dlrm_recommended_index)
print("Recommend item:", dlrm_suggested_items[dlrm_recommended_index])
Number of output tensors: 1
(4, 1) float32
Recommend item index: 1
Recommend item: cat