struct Tensor
A tensor type designed to extend MAX Engine with custom ops.
Beware that this Tensor is completely different from the Tensor type in the Mojo standard library. Currently, this max.extensibility.Tensor is designed only for use when building custom ops for MAX Engine.
For example, here's how you can define a custom op with this Tensor:
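```mojo
from max.extensibility import Tensor, empty_tensor
from max import register
from math import erf, sqrt

@register.op("my_gelu")
fn gelu[type: DType, rank: Int](x: Tensor[type, rank]) -> Tensor[type, rank]:
    # Allocate an output tensor with the same shape as the input.
    var output = empty_tensor[type](x.shape)

    # Computes GELU for a SIMD-width chunk of elements starting at index i.
    @always_inline
    @parameter
    fn func[width: Int](i: StaticIntTuple[rank]) -> SIMD[type, width]:
        var tmp = x.simd_load[width](i)
        # Take sqrt of a floating-point SIMD value so that sqrt(2) is not
        # evaluated with integer arithmetic (which would truncate it to 1).
        return tmp / 2 * (1 + erf(tmp / sqrt(SIMD[type, width](2))))

    # Evaluate func across the output tensor's indices.
    output.for_each[func]()
    return output^
```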
Then, you must create a Mojo package with this op and load it with your model into MAX Engine. For more information, read about MAX extensibility.
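When testing a custom op like my_gelu, it can help to compare its output against a scalar reference of the same formula, x/2 * (1 + erf(x/√2)). The sketch below is plain Python, not Mojo and not part of the MAX API; gelu_reference is a hypothetical helper name.

```python
import math

def gelu_reference(x: float) -> float:
    """Exact (erf-based) GELU: x/2 * (1 + erf(x / sqrt(2)))."""
    return x / 2 * (1 + math.erf(x / math.sqrt(2)))

if __name__ == "__main__":
    # Print a few reference values to compare with the custom op's output.
    for v in (-2.0, 0.0, 1.0, 3.0):
        print(v, gelu_reference(v))
```

A useful sanity check is the identity gelu(x) - gelu(-x) = x, which follows from erf being an odd function.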
Parameters
- type (DType): DType of the underlying data.
- static_rank (Int): The tensor rank.
Fields
- data (DTypePointer[type, 0]): A pointer to the tensor data.
- shape (StaticIntTuple[static_rank]): The shape of the tensor.
- strides (StaticIntTuple[static_rank]): The stride size for each dimension.
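Together, these fields describe a strided layout. Assuming the usual strided-array convention (the reference above does not spell it out), the element at multi-index (i0, ..., i_{r-1}) is stored at data[i0*strides[0] + ... + i_{r-1}*strides[r-1]]. The plain-Python sketch below models that arithmetic; flat_offset is a hypothetical helper, not part of max.extensibility.

```python
def flat_offset(index, strides):
    """Offset, in elements, of a multi-index within a strided buffer."""
    if len(index) != len(strides):
        raise ValueError("index and strides must have the same rank")
    # Each coordinate contributes coordinate * stride for its dimension.
    return sum(i * s for i, s in zip(index, strides))

# A 2x3 tensor stored row-major has strides (3, 1).
print(flat_offset((1, 2), (3, 1)))  # 5
```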
Implemented traits
AnyType, Stringable
Methods
__init__
__init__(inout self: Self, ptr: DTypePointer[type, 0], shape: StaticIntTuple[static_rank])
Constructs a new Tensor.
You usually should not instantiate a Tensor directly. Instead, use empty_tensor().
Args:
- ptr (DTypePointer[type, 0]): A pointer to the tensor data.
- shape (StaticIntTuple[static_rank]): The shape of the tensor.
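This overload takes no strides, so they must be derived from the shape. A likely default, and an assumption here since the reference does not say, is contiguous row-major order: the last dimension has stride 1 and each earlier stride is the product of all later dimensions. A plain-Python model (row_major_strides is hypothetical):

```python
def row_major_strides(shape):
    """Contiguous row-major strides for the given shape."""
    strides = [0] * len(shape)
    running = 1
    # Walk from the innermost (last) dimension outward.
    for k in range(len(shape) - 1, -1, -1):
        strides[k] = running
        running *= shape[k]
    return tuple(strides)

print(row_major_strides((2, 3, 4)))  # (12, 4, 1)
```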
__init__(inout self: Self, ptr: DTypePointer[type, 0], shape: StaticIntTuple[static_rank], strides: StaticIntTuple[static_rank])
Constructs a new Tensor.
You usually should not instantiate a Tensor directly. Instead, use empty_tensor().
Args:
- ptr (DTypePointer[type, 0]): A pointer to the tensor data.
- shape (StaticIntTuple[static_rank]): The shape of the tensor.
- strides (StaticIntTuple[static_rank]): The stride size for each dimension.
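Explicit strides let a tensor describe a non-contiguous view of an existing buffer. As an illustration, assuming the standard strided convention, a row-major 2x3 buffer (strides (3, 1)) can be read as its 3x2 transpose by swapping shape and strides, with no data copied. The plain-Python sketch below models this; read_strided is a hypothetical helper.

```python
def read_strided(buffer, index, strides):
    """Read one element of a flat buffer through a strided view."""
    offset = sum(i * s for i, s in zip(index, strides))
    return buffer[offset]

# A 2x3 matrix stored row-major: [[0, 1, 2], [3, 4, 5]].
buffer = [0, 1, 2, 3, 4, 5]
shape, strides = (2, 3), (3, 1)

# Transposed 3x2 view of the same buffer: swap shape and strides.
t_shape, t_strides = (3, 2), (1, 3)

original = [[read_strided(buffer, (r, c), strides) for c in range(shape[1])]
            for r in range(shape[0])]
transposed = [[read_strided(buffer, (r, c), t_strides) for c in range(t_shape[1])]
              for r in range(t_shape[0])]
print(transposed)  # [[0, 3], [1, 4], [2, 5]]
```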
__moveinit__
__moveinit__(inout self: Self, owned existing: Self)
__del__
__del__(owned self: Self)
nelems
nelems(self: Self) -> Int
Gets the number of elements in the tensor.
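The element count is the product of the shape's dimensions; the strides affect where elements live, not how many there are. A plain-Python model (this nelems is a hypothetical stand-in, not the Mojo method):

```python
import math

def nelems(shape):
    # Product of all dimensions; math.prod of an empty shape returns 1.
    return math.prod(shape)

print(nelems((2, 3, 4)))  # 24
```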
rank
rank(self: Self) -> Int
Gets the tensor rank.