Mojo struct
InferenceSession
struct InferenceSession
Holds the context for MAX Engine in which you can load and run models.
For example, you can load a model like this:
var session = engine.InferenceSession()
var model = session.load("bert-base-uncased")
Implemented traits
AnyType, Copyable, Movable, UnknownDestructibility
Methods
__init__
__init__(out self, options: SessionOptions = SessionOptions())
Creates a new inference session.
Args:
- options (SessionOptions): Session options that configure how the session is created.
load
load(self, path: Path, *, custom_ops_paths: List[Path] = List(), input_specs: Optional[List[InputSpec]] = Optional(None)) -> Model
Compile and initialize a model in MAX Engine, with the given model path and config.
Note: PyTorch models must be in TorchScript format.
If you're loading a TorchScript model, you must specify the input_specs argument with a list of InputSpec objects that specify the model's input specs (which may have dynamic shapes). For details, see how to specify input specs.
Args:
- path (Path): Location of the model in the filesystem. You may pass a string here because the Path object supports implicit casting from a string.
- custom_ops_paths (List[Path]): List of paths to Mojo custom op packages, to replace Modular kernels in models with user-defined kernels.
- input_specs (Optional[List[InputSpec]]): Shapes and dtypes for the model inputs. Required for TorchScript models, optional for other input formats.
Returns:
Initialized model ready for inference.
load(self, graph: Graph, *, custom_ops_paths: List[Path] = List(), input_specs: Optional[List[InputSpec]] = Optional(None)) -> Model
Compile and initialize a model in MAX Engine, with the given Graph and config.
Args:
- graph (Graph): MAX Graph.
- custom_ops_paths (List[Path]): List of paths to Mojo custom op packages, to replace Modular kernels in models with user-defined kernels.
- input_specs (Optional[List[InputSpec]]): Shapes and dtypes for the model inputs. Required for TorchScript models, optional for other input formats.
Returns:
Initialized model ready for inference.
get_as_engine_tensor_spec
get_as_engine_tensor_spec(self, name: String, spec: TensorSpec) -> EngineTensorSpec
Gets a TensorSpec compatible with MAX Engine.
Args:
- name (String): Name of the Tensor.
- spec (TensorSpec): Tensor specification in Mojo TensorSpec format.
Returns:
EngineTensorSpec to be used with MAX Engine APIs.
get_as_engine_tensor_spec(self, name: String, shape: Optional[List[Optional[SIMD[int64, 1]]]], dtype: DType) -> EngineTensorSpec
Gets a TensorSpec compatible with MAX Engine.
Args:
- name (String): Name of the Tensor.
- shape (Optional[List[Optional[SIMD[int64, 1]]]]): Shape of the Tensor. Represent a dynamic dimension with None as that dimension's entry. For a tensor of dynamic rank, pass None as the value of shape itself.
- dtype (DType): Data type of the Tensor.
Returns:
EngineTensorSpec to be used with MAX Engine APIs.
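The shape-based overload above can be sketched as follows. This is a minimal, untested sketch that relies only on the signature listed on this page. The tensor name "input_ids", the shape values, and the `from max import engine` import path are assumptions chosen for illustration.

```mojo
from collections import List, Optional
from max import engine  # assumed import path for the engine module


def main():
    var session = engine.InferenceSession()

    # Rank-2 shape: the first dimension is dynamic (None),
    # the second dimension is fixed at 128.
    var shape = List[Optional[Int64]]()
    shape.append(Optional[Int64](None))
    shape.append(Optional[Int64](Int64(128)))

    # "input_ids" is a hypothetical input name, not one taken from this page.
    var spec = session.get_as_engine_tensor_spec(
        "input_ids", shape, DType.int64
    )
    _ = spec^
```

Int64 is the standard alias for SIMD[int64, 1], so the list matches the element type in the signature. Under the same reading of the signature, a tensor of dynamic rank would pass None in place of the whole shape list.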
new_tensor_map
new_tensor_map(self) -> TensorMap
Gets a new TensorMap, which you can use to pass inputs to a model.
Returns:
A new instance of TensorMap.
new_borrowed_tensor_value
new_borrowed_tensor_value[type: DType](self, tensor: Tensor[type]) -> Value
Create a new Value that borrows the data of the given tensor as read-only.
The user must ensure the tensor stays live through the lifetime of the value.
Parameters:
- type (DType): Data type of the tensor to turn into a Value.
Args:
- tensor (Tensor[type]): Tensor to borrow into a value.
Returns:
A value borrowing the tensor.
new_bool_value
new_bool_value(self, value: Bool) -> Value
Create a new Value representing a Bool.
Args:
- value (Bool): Boolean to wrap into a value.
Returns:
Value representing the given boolean.
new_list_value
new_list_value(self) -> Value
Create a new Value representing an empty list.
Returns:
A new value containing an empty list.
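The factory methods above all hang off the session, so a typical setup creates the session once and then asks it for containers and values. The following is a minimal sketch that uses only the signatures listed on this page. The `from max import engine` import path is an assumption, and the variable names are illustrative.

```mojo
from max import engine  # assumed import path for the engine module


def main():
    var session = engine.InferenceSession()

    # Empty map that can later hold named model inputs.
    var inputs = session.new_tensor_map()

    # Value wrapping a boolean.
    var flag = session.new_bool_value(True)

    # Value holding an empty list.
    var items = session.new_list_value()

    # Explicitly discard the results so the sketch compiles without
    # unused-variable warnings.
    _ = inputs^
    _ = flag^
    _ = items^
```

new_borrowed_tensor_value is left out of the sketch because the borrowed tensor must outlive the returned Value, which depends on how the caller manages the tensor's lifetime.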
set_debug_print_options
set_debug_print_options(mut self, style: PrintStyle = PrintStyle(SIMD(0)), precision: UInt = 6, output_directory: String = String(""))
Sets the debug print options on the context.
This affects debug printing across all model execution using the same InferenceSession.
Warning: Even with style set to NONE, debug print ops in the graph can stop optimizations. If you see performance issues, try fully removing debug print ops.
Args:
- style (PrintStyle): How the values will be printed.
- precision (UInt): If the style is FULL, the digits of precision in the output.
- output_directory (String): If the style is BINARY, the directory to store output tensors.
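As an illustration of the arguments above, the following sketch turns on full-precision debug printing for every model run through one session. It is untested. This page names the styles NONE, FULL, and BINARY, but spelling them as `PrintStyle.FULL` and importing PrintStyle from `max.engine` are assumptions.

```mojo
from max import engine
from max.engine import PrintStyle  # assumed location of PrintStyle


def main():
    # The method takes `mut self`, so the session must be a mutable variable.
    var session = engine.InferenceSession()

    # Assumption: the FULL style is exposed as PrintStyle.FULL.
    # precision applies only to the FULL style.
    session.set_debug_print_options(style=PrintStyle.FULL, precision=4)
```

Because the options are stored on the session, they affect all models executed with that InferenceSession, not just the next one.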