# @compiler.register

The `@compiler.register` decorator registers a custom operation for use with the Graph API. For more information on custom operations, see Intro to custom ops.
To define a custom operation:

- Import the `compiler` package.
- Create a struct that implements the `execute()` and (optionally) `shape()` static methods.
- Register it using the `@compiler.register` decorator.
The following snippet shows the outline of a custom operation:
@compiler.register("add_vectors_custom")
struct AddVectorsCustom:
@staticmethod
fn execute[...](...):
pass
@staticmethod
fn shape(...) raises -> IndexList:
pass
@compiler.register("add_vectors_custom")
struct AddVectorsCustom:
@staticmethod
fn execute[...](...):
pass
@staticmethod
fn shape(...) raises -> IndexList:
pass
The `@compiler.register` decorator takes a single argument: the name of the custom operation, as a string. This name is used to load the custom op into your graph, as shown in the sketch below.
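For example, here's a minimal sketch of loading the op into a graph from the Python side of the Graph API. The graph name, shapes, and the exact `ops.custom()` and `TensorType` signatures are illustrative and may vary between MAX versions:

```python
from max.dtype import DType
from max.graph import Graph, TensorType, ops

# Build a graph whose forward function calls the custom op by the
# name it was registered under with @compiler.register.
input_type = TensorType(DType.float32, shape=[10])
graph = Graph(
    "vector_addition",
    forward=lambda x, y: ops.custom(
        name="add_vectors_custom",  # must match the registered name
        values=[x, y],
        out_types=[TensorType(x.tensor.dtype, shape=x.tensor.shape)],
    )[0].tensor,
    input_types=[input_type, input_type],
)
```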
Output from the `execute()` method is usually returned using one or more destination-passing style (DPS) output tensors. Destination-passing style means that the calling function passes in pre-allocated storage space for the output value(s). This allows for more efficient memory management. For example, the graph compiler can optimize memory use by allocating output tensors on the stack, instead of requiring custom ops to allocate heap storage for return values.
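To make the contrast concrete, here's a minimal standalone sketch (hypothetical helper functions, not part of the API) of value-returning style versus destination-passing style:

```mojo
# Value-returning style: the callee allocates storage for its result.
fn scale_returning(x: List[Float64], factor: Float64) -> List[Float64]:
    var result = List[Float64](capacity=len(x))
    for i in range(len(x)):
        result.append(x[i] * factor)
    return result


# Destination-passing style: the caller passes in pre-allocated storage
# and the callee writes its result into it.
fn scale_dps(mut out: List[Float64], x: List[Float64], factor: Float64):
    for i in range(len(x)):
        out[i] = x[i] * factor
```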
Destination-passing style requires the graph compiler to determine the dimensions of the output tensor(s) before executing the operation. It uses the operation's `shape()` function to determine the dimensions if they can't be determined statically.
The following sections describe the `execute()` and `shape()` functions.
## `execute()` function
The `execute()` function performs the actual work of the custom op. It takes the following parameter:

- `target` (`StringLiteral`): Indicates the device the operation is running on; currently either `"cpu"` or `"gpu"`.
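Because `target` is a compile-time parameter, an implementation can select per-device code paths with a parametric `if`. A minimal sketch (the op name is hypothetical and the branch bodies are placeholders):

```mojo
import compiler
from max.tensor import OutputTensor, InputTensor
from runtime.asyncrt import DeviceContextPtr


@compiler.register("device_dispatch_example")  # hypothetical op name
struct DeviceDispatchExample:
    @staticmethod
    fn execute[
        target: StringLiteral,
    ](
        out: OutputTensor,
        x: InputTensor[type = out.type, rank = out.rank],
        ctx: DeviceContextPtr,
    ):
        @parameter
        if target == "cpu":
            # CPU-specific implementation goes here.
            pass
        else:
            # GPU-specific implementation goes here, typically using the
            # device context carried by ctx.
            pass
```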
Graph output and input tensors are passed to the `execute()` function as instances of `OutputTensor` and `InputTensor`, respectively. These are both aliases for specific configurations of `ManagedTensorSlice`, so they both have the same API.
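For instance, a hypothetical op could read inputs and write outputs through the same `load()`/`store()` methods rather than the `foreach` helper used below. This sketch assumes rank-1 tensors, that `OutputTensor` accepts a `rank` parameter binding, and that `ManagedTensorSlice` exposes `dim_size()` and `store()` with the signatures shown:

```mojo
import compiler
from utils.index import IndexList
from max.tensor import OutputTensor, InputTensor
from runtime.asyncrt import DeviceContextPtr


@compiler.register("double_vector")  # hypothetical op name
struct DoubleVector:
    @staticmethod
    fn execute[
        target: StringLiteral,
    ](
        out: OutputTensor[rank=1],
        x: InputTensor[type = out.type, rank = out.rank],
        ctx: DeviceContextPtr,
    ):
        # Reads and writes both go through the shared ManagedTensorSlice API.
        for i in range(out.dim_size(0)):
            var idx = IndexList[1](i)
            out.store[1](idx, x.load[1](idx) * 2)
```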
In addition to input and output tensors, the function can take the following arguments:

- Any arguments of type `Scalar`.
- A single argument of type `DeviceContextPtr`. This opaque pointer is currently required for GPU support.

For example, the following custom op adds two vectors elementwise:
```mojo
import compiler
from utils.index import IndexList
from max.tensor import OutputTensor, InputTensor, foreach, ManagedTensorSlice
from runtime.asyncrt import DeviceContextPtr


@compiler.register("add_vectors_custom")
struct AddVectorsCustom:
    @staticmethod
    fn execute[
        # "gpu" or "cpu"
        target: StringLiteral,
    ](
        # the first argument is the output
        out: OutputTensor,
        # starting here is the list of inputs
        x: InputTensor[type = out.type, rank = out.rank],
        y: InputTensor[type = out.type, rank = out.rank],
        # the context is needed for some GPU calls
        ctx: DeviceContextPtr,
    ):
        @parameter
        @always_inline
        fn func[width: Int](idx: IndexList[x.rank]) -> SIMD[x.type, width]:
            return x.load[width](idx) + y.load[width](idx)

        foreach[func, target=target](out, ctx)
```
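In this example, the `foreach` utility applies the elementwise function `func` to every element of the output tensor, choosing a SIMD width and traversal strategy appropriate for the target device.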
## `shape()` function
The `shape()` function returns the dimensions of the output tensor(s).

The `shape()` function is required only if the graph compiler can't statically determine the shape of the output tensor(s), and you don't manually annotate the output shapes when building a graph.
The function takes the same arguments as the `execute()` function, minus the output tensors and `DeviceContextPtr`. It must return an `IndexList` specifying the dimensions of the output tensor.
For example, if the operation takes two input tensors, and the shape of the output tensor matches the first input tensor, you could use the following `shape()` function:
```mojo
@staticmethod
fn shape(
    in1: InputTensor,
    in2: InputTensor,
) raises -> IndexList[in1.rank]:
    return in1.get_runtime_spec().shape
```