TensorStruct

class max.pipelines.modeling.base.TensorStruct

Bases: object

Base class for structured tensor containers used as pipeline inputs and outputs.

Enforces, at class-definition time (via __init_subclass__), that every field annotation is Tensor, Buffer, or Optional[Tensor | Buffer]. Scalars, NumPy arrays, ints, strings, and other non-tensor types are rejected with a TypeError as soon as the subclass is defined, which in practice means when its module is imported.

There is no runtime validation overhead: the frozen dataclass __init__ assigns fields directly, with no extra checks on the hot path.

Subclasses should be decorated with @dataclass(frozen=True):

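```python
from dataclasses import dataclass
from max.pipelines.modeling.base import TensorStruct
# Tensor must also be imported from MAX; its module path depends on your MAX version.

@dataclass(frozen=True)
class MyInputs(TensorStruct):
    tokens: Tensor
    latents: Tensor
    image: Tensor | None = None  # optional feature, defaults to None
```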

to()

to(device)

Transfer all present tensors to device, returning a new instance.

None-valued optional fields are left as None.

Parameters:

device (Device)

Return type:

Self