
Deploy with Triton

NVIDIA’s Triton Inference Server provides a comprehensive inference serving platform with features such as model management, auto scaling, dynamic batching, and more. If you're already using Triton, then MAX Engine works as a drop-in replacement for your existing inference backend.

Our Triton backend provides full compatibility with the Triton Inference Server:

  • Server-side model configurations work as-is; just change the backend name.
  • Client-side code that sends inference requests works as-is.
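Because the server-side change is confined to a single field, it can be scripted. The following is a minimal sketch, not part of MAX or Triton tooling. The helper name `switch_backend` is hypothetical, and it assumes the existing configuration names its backend either with a `backend:` field or with a `platform:` field such as `"tensorflow_savedmodel"`. In both cases the sketch replaces that field with `backend: "max"` and leaves every other field untouched.

```python
import re

# Hypothetical helper (not part of MAX or Triton): rewrite the first
# `backend:` or `platform:` field of a Triton model config so that it
# selects a different backend, leaving every other field untouched.
_BACKEND_FIELD = re.compile(
    r'^([ \t]*)(?:backend|platform)[ \t]*:[ \t]*"[^"]*"', re.MULTILINE
)


def switch_backend(config_text: str, new_backend: str = "max") -> str:
    new_text, count = _BACKEND_FIELD.subn(
        lambda m: f'{m.group(1)}backend: "{new_backend}"', config_text, count=1
    )
    if count == 0:
        # No backend/platform field found: add one at the top of the file.
        new_text = f'backend: "{new_backend}"\n' + config_text
    return new_text


# Example: a TensorFlow SavedModel config switched over to MAX Engine.
tf_config = (
    'default_model_filename: "bert-base.savedmodel"\n'
    'platform: "tensorflow_savedmodel"\n'
)
print(switch_backend(tf_config))
```

Only the first matching field is rewritten, so a config that already uses `backend:` simply has its value replaced.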

Example configuration

Below is a model configuration file for Triton Inference Server that uses MAX Engine as the compute backend. Compared with a TensorFlow or PyTorch backend configuration, the only difference is the backend name. (In a Triton model repository, this file is saved as config.pbtxt in the model's directory.)

bert-config.pbtxt
default_model_filename: "bert-base.savedmodel"
backend: "max"
input {
  name: "attention_mask"
  data_type: TYPE_INT32
  dims: [-1, -1]
}
input {
  name: "input_ids"
  data_type: TYPE_INT32
  dims: [-1, -1]
}
input {
  name: "token_type_ids"
  data_type: TYPE_INT32
  dims: [-1, -1]
}
output {
  name: "end_logits"
  data_type: TYPE_FP32
  dims: [-1, -1]
}
output {
  name: "start_logits"
  data_type: TYPE_FP32
  dims: [-1, -1]
}
instance_group {
  kind: KIND_CPU
}
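Triton loads models from a model repository. Each model gets its own directory containing config.pbtxt and one numbered subdirectory per model version, and the file named by `default_model_filename` goes inside the version directory. Because the configuration above has no `name` field, Triton takes the model name from the directory name. The sketch below creates that skeleton. The function name `make_model_repository` is hypothetical, and the directory name `"bert-large"` is an assumption based on the model name the client requests. Copying the actual bert-base.savedmodel artifact into place is left out.

```python
import tempfile
from pathlib import Path


def make_model_repository(root: Path, model_name: str, config_text: str,
                          model_filename: str, version: int = 1) -> Path:
    """Create the Triton model-repository skeleton for one model.

    Layout produced:
        <root>/<model_name>/config.pbtxt
        <root>/<model_name>/<version>/          (model artifact goes here)

    Returns the path where `model_filename` should be placed; the artifact
    itself is not created.
    """
    model_dir = root / model_name
    version_dir = model_dir / str(version)
    version_dir.mkdir(parents=True, exist_ok=True)
    (model_dir / "config.pbtxt").write_text(config_text)
    return version_dir / model_filename


# Example: lay out a repository for a MAX-backed BERT model.
with tempfile.TemporaryDirectory() as tmp:
    config = 'default_model_filename: "bert-base.savedmodel"\nbackend: "max"\n'
    artifact = make_model_repository(
        Path(tmp), "bert-large", config, "bert-base.savedmodel"
    )
    print(artifact.relative_to(tmp).as_posix())  # bert-large/1/bert-base.savedmodel
```

With the artifact copied in, the repository root is passed to the server the same way as for any other backend, for example `tritonserver --model-repository=<root>`.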

Example client request

For client programs that send requests to the server, you don’t need to change the code at all.

The following client code sends BERT question-answering (Q&A) inference requests to the Triton server configured above, which uses MAX Engine as its compute backend. Nothing in this code refers to MAX Engine, because the backend change is invisible to the client:

bert-client.py
from transformers import AutoTokenizer
import numpy as np
import tritonclient.grpc as grpcclient


def answer_question(
    triton_client, tokenizer, question, context, timeout=None
):
    # Convert the inputs to BERT tokens. The model config declares
    # TYPE_INT32 inputs, so cast the tokenizer's NumPy output to int32.
    inputs = tokenizer(
        question, context, add_special_tokens=True, return_tensors="np"
    )
    attention_mask = inputs["attention_mask"].astype(np.int32)
    input_ids = inputs["input_ids"].astype(np.int32)
    token_type_ids = inputs["token_type_ids"].astype(np.int32)

    sequence_length = input_ids.shape[1]

    # Describe the gRPC input tensors (names and types match the model config).
    grpc_tensors = [
        grpcclient.InferInput("attention_mask", (1, sequence_length), "INT32"),
        grpcclient.InferInput("input_ids", (1, sequence_length), "INT32"),
        grpcclient.InferInput("token_type_ids", (1, sequence_length), "INT32"),
    ]

    # Tokenized input tensors -> Triton.
    grpc_tensors[0].set_data_from_numpy(attention_mask)
    grpc_tensors[1].set_data_from_numpy(input_ids)
    grpc_tensors[2].set_data_from_numpy(token_type_ids)

    # Get the result from the server.
    result = triton_client.infer("bert-large", grpc_tensors, timeout=timeout)

    # Trim the outputs to `sequence_length`.
    server_start = result.as_numpy("start_logits")[:, :sequence_length]
    server_end = result.as_numpy("end_logits")[:, :sequence_length]

    # Use NumPy to get the predicted start and end positions from the
    # output logits. The end position is made exclusive for slicing.
    predicted_start = np.argmax(server_start, axis=1)[0]
    predicted_end = np.argmax(server_end, axis=1)[0] + 1

    # The answer is expressed as positions in the input, so map those
    # positions back to the corresponding input tokens.
    answer_tokens = tokenizer.convert_ids_to_tokens(
        input_ids[0][predicted_start:predicted_end]
    )

    # Convert the tokens into a human-readable string.
    return tokenizer.convert_tokens_to_string(answer_tokens)


def main(context_filename, host, port):
    with open(context_filename) as f:
        context = f.read()
    print("Context:\n", context)

    # Load the tokenizer once instead of on every question.
    tokenizer = AutoTokenizer.from_pretrained(
        "bert-large-uncased-whole-word-masking-finetuned-squad"
    )

    # Open the Triton server connection.
    url = f"{host}:{port}"
    triton_client = grpcclient.InferenceServerClient(url=url)

    try:
        # Answer questions until the user sends EOF (Ctrl-D) or Ctrl-C.
        while True:
            question = input("> ")
            print(answer_question(triton_client, tokenizer, question, context))
    except (EOFError, KeyboardInterrupt):
        pass
    finally:
        # Close the server connection.
        triton_client.close()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(prog="bert-cli")
    parser.add_argument("-c", "--context", required=True, help="Context file")
    parser.add_argument(
        "-s", "--server", required=True, help="Inference server host"
    )
    parser.add_argument(
        "-p",
        "--port",
        required=False,
        default="8001",  # Triton's default gRPC port
        help="Inference server port",
    )
    args = parser.parse_args()
    main(args.context, args.server, args.port)
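`answer_question` picks the start and end positions with two independent `argmax` calls, so the predicted end can fall before the predicted start, which yields an empty answer. A common refinement in SQuAD-style post-processing, not part of the original client and unrelated to the choice of backend, is to score every valid (start, end) pair jointly. The sketch below assumes 1-D logits for a single sequence (for example `server_start[0]` and `server_end[0]`). The function name `best_answer_span` and the `max_answer_len` limit of 30 tokens are illustrative choices, and the sketch does not exclude question or special tokens from the span.

```python
import numpy as np


def best_answer_span(start_logits, end_logits, max_answer_len=30):
    """Return (start, end_exclusive) maximizing start_logits[i] + end_logits[j]
    over all pairs with i <= j < i + max_answer_len.

    Both inputs are 1-D logits for a single sequence.
    """
    start_logits = np.asarray(start_logits, dtype=np.float64)
    end_logits = np.asarray(end_logits, dtype=np.float64)
    n = start_logits.shape[0]

    # scores[i, j] = start_logits[i] + end_logits[j]
    scores = start_logits[:, None] + end_logits[None, :]

    # Mask out spans that end before they start or are too long.
    i, j = np.indices((n, n))
    valid = (j >= i) & (j - i < max_answer_len)
    scores = np.where(valid, scores, -np.inf)

    # Row-major flat index -> (start, end) pair; ties go to the first pair.
    start, end = divmod(int(np.argmax(scores)), n)
    return start, end + 1


# Independent argmax would pick start index 2 and end index 1 here (an empty
# answer); the joint search picks tokens 2..3 instead.
print(best_answer_span([0.0, 1.0, 6.0, 0.0], [0.0, 7.0, 2.0, 3.0]))  # (2, 4)
```

Inside `answer_question`, the two `argmax` lines could be replaced with `predicted_start, predicted_end = best_answer_span(server_start[0], server_end[0])`. The returned end is already exclusive, so the existing slice works unchanged.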

The bert-client.py code above is identical to what you would use with a Triton Inference Server instance running any other backend, so switching to MAX Engine only requires updating the model configuration.

MAX Serving is coming in Q1 2024.