Mojo struct

VarlenSelectiveStateUpdate

struct VarlenSelectiveStateUpdate[dt_softplus: Bool = False]

Varlen selective state update for autoregressive inference.

Performs a single step of the SSM recurrence for incremental token generation with multi-head support.

Tensor Shapes:

  • state: (batch, nheads, dim, dstate) - SSM state (in/out)
  • output: (batch, nheads, dim) - Output tensor
  • x: (batch, nheads, dim) - Input tensor
  • dt: (batch, nheads, dim) - Time delta tensor
  • A: (nheads, dim, dstate) - State transition matrix
  • B: (batch, ngroups, dstate) - Input matrix
  • C: (batch, ngroups, dstate) - Output matrix
  • D: (nheads, dim) - Skip connection (optional, can be empty)
  • z: (batch, nheads, dim) - Gating tensor (optional, can be empty)
  • dt_bias: (nheads, dim) - Time delta bias (optional, can be empty)
  • state_batch_indices: (batch,) - Indices into state batch (optional)
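This page does not spell out the arithmetic of the update, so the following is only a reference sketch in plain Python, not the Mojo kernel. It assumes the operation follows the single-step selective state update commonly used for Mamba-style models: the state decays by exp(dt * A), accumulates dt * B * x, and is then projected through C. The function name `selective_state_update_ref`, the use of `None` to stand in for an "empty" optional tensor, the order in which `dt_bias` and softplus are applied, the SiLU gating by `z`, and the mapping of heads to groups are all assumptions made for illustration.

```python
import math


def softplus(v):
    # Numerically stable log(1 + exp(v)).
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def silu(v):
    # SiLU gate: v * sigmoid(v).
    return v / (1.0 + math.exp(-v))


def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None,
                               dt_bias=None, state_batch_indices=None,
                               dt_softplus=False):
    """Hypothetical reference for one recurrence step on nested lists.

    Shapes follow the tensor shapes listed above. `state` is updated in
    place and the output of shape (batch, nheads, dim) is returned.
    Assumes nheads is a multiple of ngroups.
    """
    batch, nheads, dim = len(x), len(x[0]), len(x[0][0])
    ngroups, dstate = len(B[0]), len(B[0][0])
    heads_per_group = nheads // ngroups
    out = [[[0.0] * dim for _ in range(nheads)] for _ in range(batch)]
    for b in range(batch):
        # Assumed meaning: which row of `state` belongs to sequence b.
        s = state_batch_indices[b] if state_batch_indices is not None else b
        for h in range(nheads):
            g = h // heads_per_group  # group whose B and C this head shares
            for d in range(dim):
                step = dt[b][h][d]
                if dt_bias is not None:
                    step += dt_bias[h][d]
                if dt_softplus:
                    step = softplus(step)
                acc = 0.0
                for n in range(dstate):
                    decay = math.exp(step * A[h][d][n])
                    new = (state[s][h][d][n] * decay
                           + step * B[b][g][n] * x[b][h][d])
                    state[s][h][d][n] = new
                    acc += new * C[b][g][n]
                if D is not None:
                    acc += D[h][d] * x[b][h][d]  # skip connection
                if z is not None:
                    acc *= silu(z[b][h][d])  # gating
                out[b][h][d] = acc
    return out
```

With every dimension set to 1, a state of 1.0, x = 2.0, dt = 0.5, A = 0.0, B = 1.0, and C = 3.0, the sketch updates the state to 1.0 * exp(0) + 0.5 * 1.0 * 2.0 = 2.0 and returns 2.0 * 3.0 = 6.0.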

Parameters

  • dt_softplus (Bool): If True, applies softplus activation to dt values.

Implemented traits

AnyType, ImplicitlyDestructible

Methods

execute

static execute[dtype: DType, target: StringSlice[StaticConstantOrigin]](state: ManagedTensorSlice[Output, static_spec=state.static_spec], output: ManagedTensorSlice[Output, static_spec=output.static_spec], x: ManagedTensorSlice[Input, static_spec=x.static_spec], dt: ManagedTensorSlice[Input, static_spec=dt.static_spec], A: ManagedTensorSlice[Input, static_spec=A.static_spec], B: ManagedTensorSlice[Input, static_spec=B.static_spec], C: ManagedTensorSlice[Input, static_spec=C.static_spec], D: ManagedTensorSlice[Input, static_spec=D.static_spec], z: ManagedTensorSlice[Input, static_spec=z.static_spec], dt_bias: ManagedTensorSlice[Input, static_spec=dt_bias.static_spec], state_batch_indices: ManagedTensorSlice[Input, static_spec=state_batch_indices.static_spec], ctx: DeviceContext)

shape

static shape[dtype: DType](x: ManagedTensorSlice[Input, static_spec=x.static_spec], dt: ManagedTensorSlice[Input, static_spec=dt.static_spec], A: ManagedTensorSlice[Input, static_spec=A.static_spec], B: ManagedTensorSlice[Input, static_spec=B.static_spec], C: ManagedTensorSlice[Input, static_spec=C.static_spec], D: ManagedTensorSlice[Input, static_spec=D.static_spec], z: ManagedTensorSlice[Input, static_spec=z.static_spec], dt_bias: ManagedTensorSlice[Input, static_spec=dt_bias.static_spec], state_batch_indices: ManagedTensorSlice[Input, static_spec=state_batch_indices.static_spec]) -> Tuple[IndexList[4], IndexList[3]]

Returns:

Tuple[IndexList[4], IndexList[3]]
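The page does not say what the two returned index lists contain. Their ranks (4 and 3) match the documented shapes of state and output, so a plausible reading is that `shape` reports those two output shapes. The Python sketch below illustrates that reading only; the helper name `infer_output_shapes` is hypothetical, and deriving the sizes from `x` and `A` is an assumption. In particular, when `state_batch_indices` is used, the leading dimension of the real state tensor may differ from the batch size of `x`.

```python
def infer_output_shapes(x_shape, a_shape):
    """Hypothetical shape inference from x (batch, nheads, dim) and A (nheads, dim, dstate)."""
    batch, nheads, dim = x_shape
    a_heads, a_dim, dstate = a_shape
    if (a_heads, a_dim) != (nheads, dim):
        raise ValueError("A must have shape (nheads, dim, dstate) matching x")
    state_shape = (batch, nheads, dim, dstate)  # rank 4, like IndexList[4]
    output_shape = (batch, nheads, dim)         # rank 3, like IndexList[3]
    return state_shape, output_shape
```

For example, `infer_output_shapes((2, 4, 8), (4, 8, 16))` yields `((2, 4, 8, 16), (2, 4, 8))`.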