
Mojo struct

VarlenSelectiveScanFwd

struct VarlenSelectiveScanFwd[delta_softplus: Bool = False]

Variable-length selective scan forward pass.

Performs the selective scan for variable-length sequences that are concatenated along a single total_length axis. Sequence boundaries are taken from the cumulative sequence lengths in query_start_loc.
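
To illustrate how the packed layout is indexed, here is a minimal Python sketch (Python rather than Mojo, purely for illustration). It assumes that query_start_loc holds batch + 1 offsets starting at 0. Under that assumption, sequence b occupies columns query_start_loc[b] up to but not including query_start_loc[b + 1] of the total_length axis, and the last entry equals total_length. The function name sequence_bounds is hypothetical.

```python
def sequence_bounds(query_start_loc):
    """Yield (batch_index, start, end) for each packed sequence.

    Assumption: sequence b occupies columns [query_start_loc[b],
    query_start_loc[b + 1]) of the total_length axis.
    """
    for b in range(len(query_start_loc) - 1):
        yield b, query_start_loc[b], query_start_loc[b + 1]


# Three sequences of lengths 3, 1 and 4 packed into total_length = 8.
query_start_loc = [0, 3, 4, 8]
for b, start, end in sequence_bounds(query_start_loc):
    print(f"sequence {b}: columns {start}..{end - 1} (length {end - start})")
```

Storing offsets instead of per-sequence lengths lets each sequence find its slice in constant time, without summing the lengths that come before it.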

Tensor shapes:

  • output (dim, total_length): Output tensor. If z is present, the result is written to z instead.
  • ssm_states (batch, dim, dstate): SSM states, read and written in place.
  • u (dim, total_length): Input tensor.
  • delta (dim, total_length): Time-step tensor.
  • A (dim, dstate): State transition matrix.
  • B (ngroups, dstate, total_length): Input projection.
  • C (ngroups, dstate, total_length): Output projection.
  • D (dim,): Skip connection (optional; may be empty).
  • z (dim, total_length): Gating tensor (optional; may be empty).
  • delta_bias (dim,): Delta bias (optional; may be empty).
  • query_start_loc (batch + 1,): Cumulative sequence lengths.
  • cache_indices (batch,): Indices into ssm_states (optional).
  • has_initial_state (batch,): Whether each sequence starts from its stored state (optional).
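
The shape list above can be encoded as a consistency check. The following Python sketch is a hypothetical helper and is not part of the Mojo API. It takes a dict of shape tuples, treats omitted or zero-sized optional tensors as empty, and returns the expected output shape.

Two points in it are assumptions:

  • Only the trailing (dim, dstate) dimensions of ssm_states are checked. When cache_indices is used, the leading dimension may count cache slots rather than sequences.
  • dim is assumed to divide evenly across the ngroups groups of B and C.

```python
def check_varlen_scan_shapes(shapes):
    """Validate tensor shapes against the documented layout.

    Hypothetical helper: `shapes` maps tensor names to shape tuples.
    Optional tensors may be omitted or given a zero-sized shape.
    Returns the expected output shape, (dim, total_length).
    """
    dim, total_length = shapes["u"]
    dstate = shapes["A"][1]
    ngroups = shapes["B"][0]
    batch = shapes["query_start_loc"][0] - 1

    required = {
        "delta": (dim, total_length),
        "A": (dim, dstate),
        "B": (ngroups, dstate, total_length),
        "C": (ngroups, dstate, total_length),
    }
    optional = {
        "D": (dim,),
        "z": (dim, total_length),
        "delta_bias": (dim,),
        "cache_indices": (batch,),
        "has_initial_state": (batch,),
    }
    for name, want in required.items():
        got = tuple(shapes[name])
        if got != want:
            raise ValueError(f"{name}: expected {want}, got {got}")
    for name, want in optional.items():
        got = tuple(shapes.get(name, (0,)))
        # A zero-sized dimension marks an empty (absent) optional tensor.
        if 0 not in got and got != want:
            raise ValueError(f"{name}: expected {want} or empty, got {got}")
    # Assumption: only the trailing (dim, dstate) of ssm_states is fixed.
    got = tuple(shapes["ssm_states"])
    if got[1:] != (dim, dstate):
        raise ValueError(f"ssm_states: expected (*, {dim}, {dstate}), got {got}")
    # Assumption: channels are split evenly across the ngroups B/C groups.
    if dim % ngroups != 0:
        raise ValueError(f"dim={dim} is not divisible by ngroups={ngroups}")
    return (dim, total_length)


shapes = {
    "u": (4, 8), "delta": (4, 8), "A": (4, 2),
    "B": (2, 2, 8), "C": (2, 2, 8), "query_start_loc": (4,),
    "ssm_states": (3, 4, 2), "D": (4,), "z": (0,),
}
print(check_varlen_scan_shapes(shapes))
```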

Parameters

  • delta_softplus (Bool): If True, applies the softplus activation to the delta values.
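
Softplus is log(1 + exp(x)), which maps any real delta to a positive step size. Below is a minimal Python sketch of an overflow-safe version. The cutoff of 20 mirrors the default threshold of PyTorch's softplus. Two details are assumptions based on common Mamba implementations, not facts from this page:

  • whether this kernel uses such a cutoff at all;
  • whether delta_bias is added before the activation, as the usage lines show.

```python
import math


def softplus(x, threshold=20.0):
    """Return log(1 + exp(x)) without overflowing for large x."""
    # Above the threshold, log(1 + exp(x)) differs from x by less than 1e-8.
    if x > threshold:
        return x
    return math.log1p(math.exp(x))


# Hypothetical values for one channel and one timestep.
delta_value, delta_bias_value = -0.5, 0.25
dt = softplus(delta_value + delta_bias_value)  # always > 0
print(dt)
```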

Implemented traits

AnyType, ImplicitlyDestructible

Methods

execute

static execute[dtype: DType, target: StringSlice[StaticConstantOrigin]](output: ManagedTensorSlice[Output, static_spec=output.static_spec], ssm_states: ManagedTensorSlice[Output, static_spec=ssm_states.static_spec], z: ManagedTensorSlice[Output, static_spec=z.static_spec], u: ManagedTensorSlice[Input, static_spec=u.static_spec], delta: ManagedTensorSlice[Input, static_spec=delta.static_spec], A: ManagedTensorSlice[Input, static_spec=A.static_spec], B: ManagedTensorSlice[Input, static_spec=B.static_spec], C: ManagedTensorSlice[Input, static_spec=C.static_spec], D: ManagedTensorSlice[Input, static_spec=D.static_spec], delta_bias: ManagedTensorSlice[Input, static_spec=delta_bias.static_spec], query_start_loc: ManagedTensorSlice[Input, static_spec=query_start_loc.static_spec], cache_indices: ManagedTensorSlice[Input, static_spec=cache_indices.static_spec], has_initial_state: ManagedTensorSlice[Input, static_spec=has_initial_state.static_spec], ctx: DeviceContext)
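
For reference, the recurrence that execute is expected to compute can be written as a slow, pure-Python sketch. This is a Mamba-style formulation under stated assumptions, not the kernel's actual implementation. For each sequence b and channel d:

  • The state h starts from ssm_states[cache_indices[b], d] when has_initial_state[b] is set, and from zeros otherwise.
  • Each step computes dt = delta + delta_bias, optionally passed through softplus.
  • It then updates h[n] = exp(dt * A[d, n]) * h[n] + dt * B[g, n, t] * u[d, t] and emits y = sum over n of h[n] * C[g, n, t].
  • Channel d is assumed to use group g = d // (dim // ngroups).
  • D adds a skip term D[d] * u[d, t], and z gates the output with SiLU.
  • The final state is written back into ssm_states.

The function name varlen_selective_scan_ref is hypothetical. The shape list says the kernel writes its result to z when z is present; this sketch instead always returns a separate output.

```python
import math


def softplus(x):
    # Overflow-safe log(1 + exp(x)); the cutoff of 20 is an assumption.
    return x if x > 20.0 else math.log1p(math.exp(x))


def silu(x):
    # x * sigmoid(x), with sigmoid written via tanh to avoid overflow.
    return x * 0.5 * (1.0 + math.tanh(0.5 * x))


def varlen_selective_scan_ref(u, delta, A, B, C, query_start_loc, ssm_states,
                              D=None, z=None, delta_bias=None,
                              cache_indices=None, has_initial_state=None,
                              delta_softplus=False):
    """Hypothetical pure-Python reference for a Mamba-style varlen scan.

    Arguments are nested lists shaped as in the tensor shapes list; None or
    an empty list means the optional tensor is absent. Returns the
    (dim, total_length) output and updates ssm_states in place.
    """
    dim, total_length = len(u), len(u[0])
    dstate = len(A[0])
    channels_per_group = dim // len(B)  # assumes dim % ngroups == 0
    out = [[0.0] * total_length for _ in range(dim)]

    for b in range(len(query_start_loc) - 1):
        start, end = query_start_loc[b], query_start_loc[b + 1]
        slot = cache_indices[b] if cache_indices else b
        use_init = bool(has_initial_state[b]) if has_initial_state else False
        for d in range(dim):
            g = d // channels_per_group
            # Each sequence starts from its cached state or from zeros.
            h = list(ssm_states[slot][d]) if use_init else [0.0] * dstate
            for t in range(start, end):
                dt = delta[d][t] + (delta_bias[d] if delta_bias else 0.0)
                if delta_softplus:
                    dt = softplus(dt)
                y = 0.0
                for n in range(dstate):
                    # Discretized update: h <- exp(dt * A) * h + dt * B * u
                    h[n] = math.exp(dt * A[d][n]) * h[n] + dt * B[g][n][t] * u[d][t]
                    y += h[n] * C[g][n][t]
                if D:
                    y += D[d] * u[d][t]  # skip connection
                if z:
                    y *= silu(z[d][t])  # gating
                out[d][t] = y
            ssm_states[slot][d] = h  # write back the final state
    return out


# Two sequences of length 2 packed into total_length = 4, with A = 0 so the
# state simply accumulates dt * B * u.
u = [[1.0, 2.0, 3.0, 4.0]]
delta = [[1.0] * 4]
A = [[0.0]]
B = [[[1.0] * 4]]
C = [[[1.0] * 4]]
states = [[[0.0]], [[0.0]]]
print(varlen_selective_scan_ref(u, delta, A, B, C, [0, 2, 4], states))
# [[1.0, 3.0, 3.0, 7.0]]
```

The state h is reset at every boundary in query_start_loc. This lets the packed sequences share one total_length axis without leaking state into each other. In the usage lines, the second sequence restarts from zero at column 2.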

shape

static shape[dtype: DType](u: ManagedTensorSlice[Input, static_spec=u.static_spec], delta: ManagedTensorSlice[Input, static_spec=delta.static_spec], A: ManagedTensorSlice[Input, static_spec=A.static_spec], B: ManagedTensorSlice[Input, static_spec=B.static_spec], C: ManagedTensorSlice[Input, static_spec=C.static_spec], D: ManagedTensorSlice[Input, static_spec=D.static_spec], delta_bias: ManagedTensorSlice[Input, static_spec=delta_bias.static_spec], query_start_loc: ManagedTensorSlice[Input, static_spec=query_start_loc.static_spec], cache_indices: ManagedTensorSlice[Input, static_spec=cache_indices.static_spec], has_initial_state: ManagedTensorSlice[Input, static_spec=has_initial_state.static_spec]) -> IndexList[2]

Returns:

IndexList[2]