Mojo struct
VarlenSelectiveScanFwd
struct VarlenSelectiveScanFwd[delta_softplus: Bool = False]
Variable-length selective scan forward pass.
Performs the selective scan computation for variable-length sequences that are concatenated along the length dimension. Uses cumulative sequence lengths (query_start_loc) to identify sequence boundaries.
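To illustrate how cumulative sequence lengths delimit the packed sequences, here is a small Python sketch (not the Mojo API). It turns a query_start_loc-style offset array into per-sequence half-open [start, end) ranges over the concatenated length axis. It assumes the usual convention that query_start_loc[0] == 0 and query_start_loc[batch] == total_length; the helper name sequence_ranges is hypothetical.

```python
def sequence_ranges(query_start_loc):
    """Convert cumulative offsets (length batch + 1) into half-open
    (start, end) ranges, one per packed sequence.

    Hypothetical helper: assumes offsets start at 0 and never decrease.
    """
    if not query_start_loc or query_start_loc[0] != 0:
        raise ValueError("query_start_loc must start with 0")
    ranges = []
    for b in range(len(query_start_loc) - 1):
        start, end = query_start_loc[b], query_start_loc[b + 1]
        if end < start:
            raise ValueError("query_start_loc must be non-decreasing")
        ranges.append((start, end))
    return ranges


# Three sequences of lengths 3, 1 and 4 packed into total_length = 8.
print(sequence_ranges([0, 3, 4, 8]))  # [(0, 3), (3, 4), (4, 8)]
```

The length of sequence b is simply query_start_loc[b + 1] - query_start_loc[b], so no per-sequence padding is needed.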
Tensor Shapes:
- output: (dim, total_length) - Output tensor (or written to z if present)
- ssm_states: (batch, dim, dstate) - SSM states (in/out)
- u: (dim, total_length) - Input tensor
- delta: (dim, total_length) - Time step tensor
- A: (dim, dstate) - State transition matrix
- B: (ngroups, dstate, total_length) - Input projection
- C: (ngroups, dstate, total_length) - Output projection
- D: (dim,) - Skip connection (optional, can be empty)
- z: (dim, total_length) - Gating tensor (optional, can be empty)
- delta_bias: (dim,) - Delta bias (optional, can be empty)
- query_start_loc: (batch + 1,) - Cumulative sequence lengths
- cache_indices: (batch,) - Indices into ssm_states (optional)
- has_initial_state: (batch,) - Whether to use initial state (optional)
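The shapes above can be read against a plain-Python reference of the scan over a single sequence. This is a sketch under stated assumptions, not the Mojo kernel: it assumes the standard Mamba selective-scan update (state x = exp(delta * A) * x + delta * B * u, output y = sum over dstate of C * x, plus D * u, optionally multiplied by SiLU(z)), that channel d uses group d // (dim // ngroups) of B and C, and that an empty optional tensor means "absent". It returns y instead of reproducing the in-place write to z. The function name selective_scan_sequence is hypothetical.

```python
import math


def _softplus(x):
    # Overflow-safe softplus: log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def selective_scan_sequence(u, delta, A, B, C, D=None, z=None,
                            delta_bias=None, delta_softplus=False,
                            initial_state=None):
    """Hypothetical reference scan over ONE sequence, using nested lists.

    u, delta, z: [dim][L]        A: [dim][dstate]
    B, C: [ngroups][dstate][L]   D, delta_bias: [dim]
    initial_state: [dim][dstate], or None for a zero state.
    Returns (y [dim][L], final_state [dim][dstate]).
    """
    dim = len(u)
    length = len(u[0]) if dim else 0
    dstate = len(A[0]) if dim else 0
    channels_per_group = dim // len(B)  # assumes dim % ngroups == 0
    if initial_state is None:
        state = [[0.0] * dstate for _ in range(dim)]
    else:
        state = [list(row) for row in initial_state]  # copy, don't alias
    y = [[0.0] * length for _ in range(dim)]
    for d in range(dim):
        g = d // channels_per_group  # group of B/C used by channel d
        for t in range(length):
            dt = delta[d][t] + (delta_bias[d] if delta_bias else 0.0)
            if delta_softplus:
                dt = _softplus(dt)
            acc = 0.0
            for n in range(dstate):
                # State update: x = exp(dt * A) * x + dt * B * u
                state[d][n] = (math.exp(dt * A[d][n]) * state[d][n]
                               + dt * B[g][n][t] * u[d][t])
                acc += C[g][n][t] * state[d][n]  # y = sum_n C * x
            if D:
                acc += D[d] * u[d][t]  # skip connection
            if z:
                zt = z[d][t]
                acc *= zt * 0.5 * (1.0 + math.tanh(0.5 * zt))  # SiLU(z) gate
            y[d][t] = acc
    return y, state


y, final = selective_scan_sequence(
    u=[[1.0, 2.0]], delta=[[1.0, 1.0]], A=[[0.0]],
    B=[[[1.0, 1.0]]], C=[[[1.0, 1.0]]])
print(y, final)  # [[1.0, 3.0]] [[3.0]]
```

With A = 0 the decay factor exp(delta * A) is 1, so the state simply accumulates delta * B * u; a negative A makes older inputs decay geometrically.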
Parameters
- delta_softplus (Bool): If True, applies softplus activation to delta values.
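As a hedged illustration of what delta_softplus controls, the Python sketch below models the delta preprocessing step. It assumes, as in the reference Mamba selective scan, that delta_bias is added before the softplus; the helper name preprocess_delta is hypothetical. It uses the overflow-safe identity softplus(x) = max(x, 0) + log1p(exp(-|x|)) rather than the naive log(1 + exp(x)), which overflows for large x.

```python
import math


def preprocess_delta(delta_value, bias=0.0, delta_softplus=False):
    """Hypothetical helper: add the per-channel delta bias, then
    optionally apply softplus(x) = log(1 + exp(x))."""
    x = delta_value + bias
    if not delta_softplus:
        return x
    # Stable form: exp() only ever sees a non-positive argument.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


print(preprocess_delta(-1.0, bias=1.0))               # 0.0 (no softplus)
print(preprocess_delta(0.0, delta_softplus=True))     # log(2), about 0.6931
print(preprocess_delta(1000.0, delta_softplus=True))  # 1000.0, no overflow
```

Softplus keeps every time step strictly positive, so exp(delta * A) stays a decay factor when A is negative.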
Implemented traits
AnyType, ImplicitlyDestructible
Methods
execute
static execute[dtype: DType, target: StringSlice[StaticConstantOrigin]](output: ManagedTensorSlice[Output, static_spec=output.static_spec], ssm_states: ManagedTensorSlice[Output, static_spec=ssm_states.static_spec], z: ManagedTensorSlice[Output, static_spec=z.static_spec], u: ManagedTensorSlice[Input, static_spec=u.static_spec], delta: ManagedTensorSlice[Input, static_spec=delta.static_spec], A: ManagedTensorSlice[Input, static_spec=A.static_spec], B: ManagedTensorSlice[Input, static_spec=B.static_spec], C: ManagedTensorSlice[Input, static_spec=C.static_spec], D: ManagedTensorSlice[Input, static_spec=D.static_spec], delta_bias: ManagedTensorSlice[Input, static_spec=delta_bias.static_spec], query_start_loc: ManagedTensorSlice[Input, static_spec=query_start_loc.static_spec], cache_indices: ManagedTensorSlice[Input, static_spec=cache_indices.static_spec], has_initial_state: ManagedTensorSlice[Input, static_spec=has_initial_state.static_spec], ctx: DeviceContext)
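To make the roles of query_start_loc, cache_indices and has_initial_state in execute concrete, here is a Python sketch of the batch loop around a per-sequence scan. It is an assumption-laden model, not the kernel: it assumes sequence b reads its initial state from slot cache_indices[b] of ssm_states only when has_initial_state[b] is true (otherwise it starts from zeros), writes its final state back to that same slot, and uses slot b when cache_indices is absent. varlen_scan and the scan_one callback are hypothetical names.

```python
def varlen_scan(query_start_loc, ssm_states, scan_one,
                cache_indices=None, has_initial_state=None):
    """Hypothetical driver: run scan_one over each packed sequence.

    query_start_loc: cumulative offsets of length batch + 1.
    ssm_states: list of per-slot states, read and updated in place.
    scan_one(start, end, initial_state) -> (outputs, final_state) is a
    stand-in for the per-sequence selective scan.
    """
    outputs = []
    for b in range(len(query_start_loc) - 1):
        start, end = query_start_loc[b], query_start_loc[b + 1]
        # Assumption: without cache_indices, sequence b owns slot b.
        slot = cache_indices[b] if cache_indices else b
        # Assumption: without has_initial_state, every sequence starts at zero.
        load = bool(has_initial_state[b]) if has_initial_state else False
        init = ssm_states[slot] if load else None  # None means zero state
        out, final_state = scan_one(start, end, init)
        ssm_states[slot] = final_state  # write the final state back
        outputs.extend(out)
    return outputs


# Toy stand-in for the scan: the "state" is a running sum of tokens.
tokens = [1.0, 2.0, 3.0, 10.0, 20.0]


def running_sum(start, end, init):
    state = 0.0 if init is None else init
    out = []
    for t in range(start, end):
        state += tokens[t]
        out.append(state)
    return out, state


states = [100.0, 0.0, 0.0]
y = varlen_scan([0, 3, 5], states, running_sum,
                cache_indices=[2, 0], has_initial_state=[False, True])
print(y)       # [1.0, 3.0, 6.0, 110.0, 130.0]
print(states)  # [130.0, 0.0, 6.0]
```

In the usage example the first sequence ignores the stale value in slot 2 and starts from zero, while the second continues from the 100.0 already cached in slot 0.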
shape
static shape[dtype: DType](u: ManagedTensorSlice[Input, static_spec=u.static_spec], delta: ManagedTensorSlice[Input, static_spec=delta.static_spec], A: ManagedTensorSlice[Input, static_spec=A.static_spec], B: ManagedTensorSlice[Input, static_spec=B.static_spec], C: ManagedTensorSlice[Input, static_spec=C.static_spec], D: ManagedTensorSlice[Input, static_spec=D.static_spec], delta_bias: ManagedTensorSlice[Input, static_spec=delta_bias.static_spec], query_start_loc: ManagedTensorSlice[Input, static_spec=query_start_loc.static_spec], cache_indices: ManagedTensorSlice[Input, static_spec=cache_indices.static_spec], has_initial_state: ManagedTensorSlice[Input, static_spec=has_initial_state.static_spec]) -> IndexList[2]
Returns:
IndexList[2]
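A hedged sketch of what shape plausibly computes. Because output is documented as (dim, total_length), the Python function below returns u's two dimensions after checking them against the other documented shapes. The consistency checks and the function name varlen_output_shape are assumptions, not behavior confirmed for the Mojo implementation.

```python
def varlen_output_shape(u_shape, delta_shape, A_shape, B_shape, C_shape,
                        query_start_loc_shape):
    """Hypothetical shape rule: return (dim, total_length) taken from u,
    after checking the other documented shapes for consistency."""
    dim, total_length = u_shape
    if tuple(delta_shape) != (dim, total_length):
        raise ValueError("delta must be (dim, total_length)")
    if len(A_shape) != 2 or A_shape[0] != dim:
        raise ValueError("A must be (dim, dstate)")
    dstate = A_shape[1]
    for name, shape in (("B", B_shape), ("C", C_shape)):
        # B and C share the (ngroups, dstate, total_length) layout.
        if len(shape) != 3 or tuple(shape[1:]) != (dstate, total_length):
            raise ValueError(f"{name} must be (ngroups, dstate, total_length)")
    if len(query_start_loc_shape) != 1 or query_start_loc_shape[0] < 1:
        raise ValueError("query_start_loc must be (batch + 1,)")
    return (dim, total_length)


print(varlen_output_shape((4, 8), (4, 8), (4, 16), (1, 16, 8), (1, 16, 8), (3,)))
# (4, 8)
```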