TensorStruct
class max.pipelines.modeling.base.TensorStruct
Bases: object
Base class for structured tensor containers used as pipeline inputs and outputs.
Enforces at class-definition time (via __init_subclass__) that
every field annotation is Tensor, Buffer, or
Optional[Tensor | Buffer]. Scalars, numpy arrays, ints,
strings, and other non-tensor types are rejected with a
TypeError when the subclass is defined (i.e. at import time).
There is no runtime validation overhead: the frozen dataclass __init__
assigns fields directly, with no extra checks on the hot path.
Subclasses should be decorated with @dataclass(frozen=True):
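```python
from dataclasses import dataclass

from max.pipelines.modeling.base import TensorStruct
# Tensor must also be imported from MAX (its module path is not given here).

@dataclass(frozen=True)
class MyInputs(TensorStruct):
    tokens: Tensor
    latents: Tensor
    image: Tensor | None = None  # optional feature
```
to()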
to(device)
Transfer all present tensors to device, returning a new instance.
None-valued optional fields are left as None.