struct Tensor
A tensor type designed to extend MAX Engine with custom ops.
Beware that this Tensor is completely different from the Tensor type in the Mojo standard library. Currently, this max.extensibility.Tensor is designed only for use when building custom ops for MAX Engine.
For example, here's how you can define a custom op with this Tensor:
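```mojo
from max.extensibility import Tensor, empty_tensor
from max import register
from math import erf, sqrt

@register.op("my_gelu")
fn gelu[type: DType, rank: Int](x: Tensor[type, rank]) -> Tensor[type, rank]:
    # Allocate an output tensor with the same shape as the input.
    var output = empty_tensor[type](x.shape)

    # Computes GELU for `width` elements starting at indices `i`.
    @always_inline
    @parameter
    fn func[width: Int](i: StaticIntTuple[rank]) -> SIMD[type, width]:
        var tmp = x.simd_load[width](i)
        return tmp / 2 * (1 + erf(tmp / sqrt(2)))

    # Run `func` over every element and write the results into `output`.
    output.for_each[func]()
    # Transfer ownership of the result to the caller.
    return output^
```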
Then, you must create a Mojo package with this op and load it with your model into MAX Engine. For more information, read about MAX extensibility.
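When testing a custom op like my_gelu, a scalar reference implementation is handy for comparing results. The sketch below is plain Python, not Mojo or the MAX API, and evaluates the same exact-erf formula as func above, x / 2 * (1 + erf(x / sqrt(2))), using Python's math.erf. The name gelu_reference is hypothetical.

```python
import math


def gelu_reference(x: float) -> float:
    """Scalar GELU using the same exact-erf formula as the custom op."""
    return x / 2 * (1 + math.erf(x / math.sqrt(2)))


if __name__ == "__main__":
    for v in (-3.0, -1.0, 0.0, 1.0, 3.0):
        print(f"gelu({v:+.1f}) = {gelu_reference(v):.6f}")
```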
Parameters
- type (DType): DType of the underlying data.
- static_rank (Int): The tensor rank.
Fields
- data (DTypePointer[type, 0]): A pointer to the tensor data.
- shape (StaticIntTuple[static_rank]): The shape of the tensor.
- strides (StaticIntTuple[static_rank]): The stride size for each dimension.
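Together, these fields locate every element. Assuming the conventional strided layout (this page does not spell out the addressing rule), the element at indices (i0, ..., iN-1) lives at offset i0 * strides[0] + ... + iN-1 * strides[N-1] from data. The following Python model treats data as a flat list; element_offset is a hypothetical helper, not part of max.extensibility.

```python
from typing import Sequence


def element_offset(index: Sequence[int], shape: Sequence[int], strides: Sequence[int]) -> int:
    """Flat offset of `index` in a strided buffer (assumed addressing rule)."""
    if not (len(index) == len(shape) == len(strides)):
        raise ValueError("index, shape, and strides must all have the tensor's rank")
    for i, extent in zip(index, shape):
        if not 0 <= i < extent:
            raise IndexError(f"index {tuple(index)} out of bounds for shape {tuple(shape)}")
    return sum(i * s for i, s in zip(index, strides))


# A 2x3 tensor stored contiguously in row-major order: strides are (3, 1).
data = [10, 11, 12, 20, 21, 22]
print(data[element_offset((1, 2), (2, 3), (3, 1))])  # 22
```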
Implemented traits
AnyType, Stringable
Methods
__init__
__init__(inout self: Self, ptr: DTypePointer[type, 0], shape: StaticIntTuple[static_rank])
Constructs a new Tensor.
You usually should not instantiate a Tensor directly. Instead, use empty_tensor().
Args:
- ptr (DTypePointer[type, 0]): A pointer to the tensor data.
- shape (StaticIntTuple[static_rank]): The shape of the tensor.
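This constructor takes no strides. If it derives them for a dense, row-major layout, which is an assumption since the page doesn't say, they would follow the rule sketched below in Python. contiguous_strides is a hypothetical helper.

```python
from typing import Sequence, Tuple


def contiguous_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major strides: the last dimension has stride 1, and each earlier
    dimension's stride is the product of all extents after it."""
    strides = [1] * len(shape)
    for d in range(len(shape) - 2, -1, -1):
        strides[d] = strides[d + 1] * shape[d + 1]
    return tuple(strides)


print(contiguous_strides((2, 3, 4)))  # (12, 4, 1)
```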
__init__(inout self: Self, ptr: DTypePointer[type, 0], shape: StaticIntTuple[static_rank], strides: StaticIntTuple[static_rank])
Constructs a new Tensor.
You usually should not instantiate a Tensor directly. Instead, use empty_tensor().
Args:
- ptr (DTypePointer[type, 0]): A pointer to the tensor data.
- shape (StaticIntTuple[static_rank]): The shape of the tensor.
- strides (StaticIntTuple[static_rank]): The stride size for each dimension.
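Passing explicit strides lets a Tensor describe memory that is not laid out row-major. As an illustration of standard strided-view semantics (assumed here, not confirmed by this page), the Python sketch below reads the same six-element buffer both as a 2x3 matrix and, by swapping shape and strides, as its 3x2 transpose, without copying. read_all is a hypothetical helper.

```python
from typing import List, Sequence


def read_all(data: Sequence[int], shape: Sequence[int], strides: Sequence[int]) -> List[List[int]]:
    """Materialize a rank-2 strided view as nested lists (for inspection only)."""
    rows, cols = shape
    return [[data[r * strides[0] + c * strides[1]] for c in range(cols)] for r in range(rows)]


buffer = [1, 2, 3, 4, 5, 6]  # a 2x3 matrix stored row-major: strides (3, 1)
print(read_all(buffer, (2, 3), (3, 1)))  # [[1, 2, 3], [4, 5, 6]]
# Same memory viewed as the 3x2 transpose: swap both shape and strides.
print(read_all(buffer, (3, 2), (1, 3)))  # [[1, 4], [2, 5], [3, 6]]
```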
__moveinit__
__moveinit__(inout self: Self, owned existing: Self)
__del__
__del__(owned self: Self)
nelems
nelems(self: Self) -> Int
Gets the number of elements in the tensor.
rank
rank(self: Self) -> Int
Gets the tensor rank.
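For a tensor whose shape has static_rank entries, rank() equals static_rank and nelems() is the product of the extents. Below is a Python model of both using math.prod (Python 3.8+); the free functions are hypothetical stand-ins for the methods.

```python
from math import prod
from typing import Sequence


def nelems(shape: Sequence[int]) -> int:
    # Product of all extents in the shape.
    return prod(shape)


def rank(shape: Sequence[int]) -> int:
    # One entry in the shape per dimension.
    return len(shape)


shape = (2, 3, 4)
print(rank(shape), nelems(shape))  # 3 24
```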
store
store[width: Int](inout self: Self, index: StaticIntTuple[static_rank], value: SIMD[type, width])
Stores a SIMD vector of width values in the tensor, starting at the specified indices.
Parameters:
- width (Int): The SIMD width.
Args:
- index (StaticIntTuple[static_rank]): The indices at which to store the values.
- value (SIMD[type, width]): The values to store.
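A SIMD store writes width lanes at once. The Python model below assumes the lanes go to consecutive memory starting at the element addressed by index, which holds when the innermost dimension is contiguous (strides[-1] == 1); keeping the lanes within bounds is the caller's job. store here is a hypothetical Python function, not the Mojo method.

```python
from typing import List, Sequence


def store(data: List[float], strides: Sequence[int], index: Sequence[int], value: Sequence[float]) -> None:
    """Model of store[width]: write len(value) lanes to consecutive memory
    starting at the element addressed by `index` (assumes strides[-1] == 1)."""
    start = sum(i * s for i, s in zip(index, strides))
    if start + len(value) > len(data):
        raise IndexError("store would write past the end of the buffer")
    data[start:start + len(value)] = value


buf = [0.0] * 6  # 2x3 tensor, row-major strides (3, 1)
store(buf, (3, 1), (1, 0), [7.0, 8.0, 9.0])  # width = 3, fills row 1
print(buf)  # [0.0, 0.0, 0.0, 7.0, 8.0, 9.0]
```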
store[width: Int](inout self: Self, index: Int, value: SIMD[type, width])
Stores a SIMD vector of width values in a rank-1 tensor, starting at the specified index.
Constraints:
The tensor's static_rank must be 1.
Parameters:
- width (Int): The SIMD width.
Args:
- index (Int): The index at which to store the values.
- value (SIMD[type, width]): The values to store.
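The Int-index overload is a rank-1 shorthand. In the Python model below, which assumes a stride of 1, the rank constraint becomes a run-time check; in Mojo, documented constraints like this are typically enforced at compile time. store_1d is hypothetical.

```python
from typing import List, Sequence


def store_1d(data: List[int], rank: int, index: int, value: Sequence[int]) -> None:
    """Model of the rank-1 store[width] overload (stride 1 assumed)."""
    if rank != 1:
        raise ValueError("the Int-index overload requires a rank-1 tensor")
    if index < 0 or index + len(value) > len(data):
        raise IndexError("store out of bounds")
    data[index:index + len(value)] = value


buf = [0] * 8
store_1d(buf, rank=1, index=4, value=[1, 2, 3, 4])  # width = 4
print(buf)  # [0, 0, 0, 0, 1, 2, 3, 4]
```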
get_nd_indices
get_nd_indices(self: Self) -> StaticIntTuple[static_rank]
Creates empty indices with the same rank as this tensor.
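The page describes the result only as "empty" indices; assuming that means all zeros, they make a natural starting point for walking a tensor. The Python sketch below pairs a hypothetical get_nd_indices with a hypothetical next_index helper that advances an index in row-major order.

```python
from typing import List


def get_nd_indices(rank: int) -> List[int]:
    # Assumed behavior: "empty" indices are all zeros, one per dimension.
    return [0] * rank


def next_index(index: List[int], shape: List[int]) -> bool:
    """Advance `index` in row-major order; return False once it wraps around."""
    for d in reversed(range(len(shape))):
        index[d] += 1
        if index[d] < shape[d]:
            return True
        index[d] = 0
    return False


idx = get_nd_indices(2)
seen = [tuple(idx)]
while next_index(idx, [2, 2]):
    seen.append(tuple(idx))
print(seen)  # [(0, 0), (0, 1), (1, 0), (1, 1)]
```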
simd_load
simd_load[simd_width: Int](self: Self, index: Int) -> SIMD[type, simd_width]
Gets the values stored in the tensor at the given index.
Constraints:
The tensor's static_rank must be 1.
Parameters:
- simd_width (Int): The SIMD width.
Args:
- index (Int): The index where the values are stored.
Returns:
The values as a SIMD vector.
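For a rank-1 tensor, simd_load[simd_width] reads a vector of simd_width elements. The Python model below assumes they are the consecutive elements beginning at index (stride 1); simd_load_1d is hypothetical.

```python
from typing import List, Sequence


def simd_load_1d(data: Sequence[float], rank: int, index: int, simd_width: int) -> List[float]:
    """Model of simd_load[simd_width](index) on a rank-1 tensor."""
    if rank != 1:
        raise ValueError("the Int-index overload requires a rank-1 tensor")
    if index < 0 or index + simd_width > len(data):
        raise IndexError("load out of bounds")
    return list(data[index:index + simd_width])


vec = [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5]
print(simd_load_1d(vec, rank=1, index=4, simd_width=4))  # [4.5, 5.5, 6.5, 7.5]
```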
simd_load[simd_width: Int](self: Self, index: StaticIntTuple[static_rank]) -> SIMD[type, simd_width]
Gets the values stored in the tensor at the given indices.
Parameters:
- simd_width (Int): The SIMD width.
Args:
- index (StaticIntTuple[static_rank]): The indices where the values are stored.
Returns:
The values as a SIMD vector.
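With an N-d index, the load first resolves the indices to a position in memory and then reads the lanes from there. This Python model assumes the strided offset rule and a contiguous innermost dimension; simd_load is a hypothetical free function mirroring the method.

```python
from typing import List, Sequence


def simd_load(data: Sequence[int], strides: Sequence[int], index: Sequence[int], simd_width: int) -> List[int]:
    """Model of simd_load[simd_width](index) for an N-d index: compute the
    element's flat offset, then read simd_width consecutive elements."""
    start = sum(i * s for i, s in zip(index, strides))
    if start + simd_width > len(data):
        raise IndexError("load out of bounds")
    return list(data[start:start + simd_width])


# 2x4 tensor, row-major strides (4, 1); load the second half of row 1.
buf = [0, 1, 2, 3, 10, 11, 12, 13]
print(simd_load(buf, (4, 1), (1, 2), simd_width=2))  # [12, 13]
```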
for_each
for_each[func: fn[Int](StaticIntTuple[static_rank]) capturing -> SIMD[type, $0]](inout self: Self)
Executes a lambda for every element in the tensor.
The lambda function must take an Int parameter for the SIMD width and a StaticIntTuple argument for the tensor indices, and must return a SIMD vector holding the values to write to the tensor at those indices.
Parameters:
- func (fn[Int](StaticIntTuple[static_rank]) capturing -> SIMD[type, $0]): The lambda function to execute for each set of tensor indices.
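The exact traversal order and SIMD-width selection of for_each are not documented here. The Python model below is one plausible sketch under stated assumptions: row-major traversal, vectorization along a contiguous innermost dimension, and a width-1 tail when that dimension isn't a multiple of the vector width. It mirrors how the gelu example fills output, with func returning the lanes to store; all names are hypothetical.

```python
from itertools import product
from typing import Callable, List, Sequence

LaneFunc = Callable[[int, Sequence[int]], List[float]]


def for_each(out: List[float], shape: Sequence[int], strides: Sequence[int],
             func: LaneFunc, simd_width: int = 4) -> None:
    """Model of Tensor.for_each (rank >= 1): call func(width, index) across
    the tensor and store the returned lanes at `index`."""
    *outer, inner = shape
    for prefix in product(*(range(n) for n in outer)):
        col = 0
        while col < inner:
            # Full vectors while they fit, then a scalar tail.
            width = simd_width if col + simd_width <= inner else 1
            index = (*prefix, col)
            start = sum(i * s for i, s in zip(index, strides))
            lanes = func(width, index)
            if len(lanes) != width:
                raise ValueError("func must return exactly `width` lanes")
            out[start:start + width] = lanes
            col += width


x = [float(v) for v in range(10)]  # input, shape (2, 5), strides (5, 1)
y = [0.0] * 10                     # output, playing the role of empty_tensor


def double(width: int, index: Sequence[int]) -> List[float]:
    start = index[0] * 5 + index[1]
    return [2 * v for v in x[start:start + width]]


for_each(y, (2, 5), (5, 1), double)
print(y == [2.0 * v for v in range(10)])  # True
```

Using a width-1 tail keeps every call inside the row; a real implementation may handle remainders differently.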