Mojo struct
VarlenSelectiveStateUpdate
struct VarlenSelectiveStateUpdate[dt_softplus: Bool = False]
Varlen selective state update for autoregressive inference.
Performs a single step of the SSM recurrence for incremental token generation with multi-head support.
Tensor Shapes:
- state: (batch, nheads, dim, dstate) - SSM state (in/out)
- output: (batch, nheads, dim) - Output tensor
- x: (batch, nheads, dim) - Input tensor
- dt: (batch, nheads, dim) - Time delta tensor
- A: (nheads, dim, dstate) - State transition matrix
- B: (batch, ngroups, dstate) - Input matrix
- C: (batch, ngroups, dstate) - Output matrix
- D: (nheads, dim) - Skip connection (optional, can be empty)
- z: (batch, nheads, dim) - Gating tensor (optional, can be empty)
- dt_bias: (nheads, dim) - Time delta bias (optional, can be empty)
- state_batch_indices: (batch,) - Indices into state batch (optional)
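This page does not spell out the recurrence itself. The following is a minimal, loop-based Python sketch of what a single step could look like, assuming the kernel follows the selective state update used in Mamba-style reference implementations. The function names (`selective_state_update_ref`, `softplus`, `silu`), the head-to-group mapping `h // (nheads // ngroups)`, the SiLU gating applied to `z`, and the meaning of `state_batch_indices` as a per-sequence slot index are all assumptions made for illustration, not confirmed behavior of this struct. Nested lists stand in for tensors with the shapes listed above.

```python
import math


def softplus(v):
    # Numerically stable log(1 + exp(v)).
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def silu(v):
    # SiLU gate: v * sigmoid(v). Assumed gating function for z.
    return v / (1.0 + math.exp(-v))


def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None,
                               dt_bias=None, state_batch_indices=None,
                               dt_softplus=False):
    """Hypothetical reference for one decode step; updates `state` in place.

    Returns the output as nested lists of shape (batch, nheads, dim).
    Optional inputs may be None or empty.
    """
    batch, nheads, dim = len(x), len(x[0]), len(x[0][0])
    dstate = len(A[0][0])
    ngroups = len(B[0])
    heads_per_group = nheads // ngroups  # assumed head-to-group mapping
    output = [[[0.0] * dim for _ in range(nheads)] for _ in range(batch)]
    for b in range(batch):
        # Assumption: state_batch_indices[b] picks the state slot for sequence b.
        s = state_batch_indices[b] if state_batch_indices else b
        for h in range(nheads):
            g = h // heads_per_group
            for d in range(dim):
                step = dt[b][h][d]
                if dt_bias:
                    step += dt_bias[h][d]
                if dt_softplus:
                    step = softplus(step)
                row = state[s][h][d]
                acc = 0.0
                for n in range(dstate):
                    # state <- state * exp(dt * A) + dt * B * x
                    row[n] = (row[n] * math.exp(step * A[h][d][n])
                              + step * B[b][g][n] * x[b][h][d])
                    # Project the updated state through C.
                    acc += row[n] * C[b][g][n]
                if D:
                    acc += D[h][d] * x[b][h][d]  # skip connection
                if z:
                    acc *= silu(z[b][h][d])  # gating
                output[b][h][d] = acc
    return output


# Tiny example: batch=1, nheads=2, dim=1, dstate=2, ngroups=1.
state = [[[[1.0, 2.0]], [[0.0, 0.0]]]]
x = [[[1.0], [2.0]]]
dt = [[[0.0], [1.0]]]
A = [[[-1.0, -1.0]], [[0.0, 0.0]]]
B = [[[1.0, 1.0]]]
C = [[[1.0, 1.0]]]
print(selective_state_update_ref(state, x, dt, A, B, C))  # [[[3.0], [4.0]]]
```

In the example, head 0 has `dt = 0`, so its state is unchanged and the output is just the state projected through `C`. Head 1 starts from a zero state with `A = 0`, so the new state is `dt * B * x` in every slot.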
Parameters
- dt_softplus (Bool): If True, applies softplus activation to dt values.
Implemented traits
AnyType, ImplicitlyDestructible
Methods
execute
static execute[dtype: DType, target: StringSlice[StaticConstantOrigin]](state: ManagedTensorSlice[Output, static_spec=state.static_spec], output: ManagedTensorSlice[Output, static_spec=output.static_spec], x: ManagedTensorSlice[Input, static_spec=x.static_spec], dt: ManagedTensorSlice[Input, static_spec=dt.static_spec], A: ManagedTensorSlice[Input, static_spec=A.static_spec], B: ManagedTensorSlice[Input, static_spec=B.static_spec], C: ManagedTensorSlice[Input, static_spec=C.static_spec], D: ManagedTensorSlice[Input, static_spec=D.static_spec], z: ManagedTensorSlice[Input, static_spec=z.static_spec], dt_bias: ManagedTensorSlice[Input, static_spec=dt_bias.static_spec], state_batch_indices: ManagedTensorSlice[Input, static_spec=state_batch_indices.static_spec], ctx: DeviceContext)
shape
static shape[dtype: DType](x: ManagedTensorSlice[Input, static_spec=x.static_spec], dt: ManagedTensorSlice[Input, static_spec=dt.static_spec], A: ManagedTensorSlice[Input, static_spec=A.static_spec], B: ManagedTensorSlice[Input, static_spec=B.static_spec], C: ManagedTensorSlice[Input, static_spec=C.static_spec], D: ManagedTensorSlice[Input, static_spec=D.static_spec], z: ManagedTensorSlice[Input, static_spec=z.static_spec], dt_bias: ManagedTensorSlice[Input, static_spec=dt_bias.static_spec], state_batch_indices: ManagedTensorSlice[Input, static_spec=state_batch_indices.static_spec]) -> Tuple[IndexList[4], IndexList[3]]
Returns:
Tuple[IndexList[4], IndexList[3]]
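The page does not say which inputs the returned shapes are derived from. A plausible reading, given the tensor shapes listed for this struct, is that the rank-4 entry is the state shape and the rank-3 entry is the output shape. The Python sketch below illustrates that reading; `infer_shapes` is a hypothetical helper, and taking `dstate` from `A` and the rest from `x` is an assumption.

```python
def infer_shapes(x_shape, a_shape):
    """Hypothetical helper: return (state_shape, output_shape).

    x_shape is (batch, nheads, dim); a_shape is (nheads, dim, dstate).
    """
    batch, nheads, dim = x_shape
    a_heads, a_dim, dstate = a_shape
    if (a_heads, a_dim) != (nheads, dim):
        raise ValueError("A must have shape (nheads, dim, dstate) matching x")
    state_shape = (batch, nheads, dim, dstate)  # rank 4, like IndexList[4]
    output_shape = (batch, nheads, dim)         # rank 3, like IndexList[3]
    return state_shape, output_shape


print(infer_shapes((2, 4, 8), (4, 8, 16)))  # ((2, 4, 8, 16), (2, 4, 8))
```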