
Mojo struct

SelectiveScanFwd

struct SelectiveScanFwd[delta_softplus: Bool = False]

Selective scan forward pass operation for Mamba SSM.

Performs the selective scan computation used in Mamba state space models. This is the core operation that processes sequences through the SSM.

Tensor Shapes:

  • output: (batch, dim, seqlen) - Output tensor
  • x: (batch, dim, num_chunks, 2*dstate) - Checkpoint tensor for chunking
  • out_z: (batch, dim, seqlen) - Gated output (if z is provided)
  • u: (batch, dim, seqlen) - Input tensor
  • delta: (batch, dim, seqlen) - Time step tensor
  • A: (dim, dstate) - State transition matrix
  • B: (batch, n_groups, dstate, seqlen) - Input projection
  • C: (batch, n_groups, dstate, seqlen) - Output projection
  • D: (dim,) - Skip connection (optional, can be empty)
  • z: (batch, dim, seqlen) - Gating tensor (optional, can be empty)
  • delta_bias: (dim,) - Delta bias (optional, can be empty)
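To make the tensor layout concrete, here is a minimal pure-Python reference sketch of a selective scan over nested lists. It is not the kernel behind this struct. It assumes the recurrence commonly used in Mamba reference implementations (h = exp(delta * A) * h + delta * B * u, y = C · h + D * u, out_z = y * silu(z)), assumes that channel d reads group d // (dim / n_groups) of B and C, and expects delta to already include any bias and softplus. The function name `selective_scan_ref` is hypothetical, and the chunk checkpoint tensor x is not modeled.

```python
import math


def selective_scan_ref(u, delta, A, B, C, D=None, z=None):
    """Reference selective scan over nested Python lists (illustrative only).

    u, delta, z: [batch][dim][seqlen]
    A:           [dim][dstate]
    B, C:        [batch][n_groups][dstate][seqlen]
    D:           [dim] or None
    Returns (output, out_z); out_z is None when z is None.
    """
    batch, dim, seqlen = len(u), len(u[0]), len(u[0][0])
    n_groups, dstate = len(B[0]), len(A[0])
    channels_per_group = dim // n_groups  # assumed channel-to-group mapping

    output = [[[0.0] * seqlen for _ in range(dim)] for _ in range(batch)]
    out_z = None
    if z is not None:
        out_z = [[[0.0] * seqlen for _ in range(dim)] for _ in range(batch)]

    for b in range(batch):
        for d in range(dim):
            g = d // channels_per_group
            h = [0.0] * dstate  # hidden state, reset for every (batch, channel)
            for t in range(seqlen):
                dt = delta[b][d][t]
                y = 0.0
                for n in range(dstate):
                    # Assumed discretized update: h = exp(dt * A) * h + dt * B * u
                    h[n] = math.exp(dt * A[d][n]) * h[n] + dt * B[b][g][n][t] * u[b][d][t]
                    y += C[b][g][n][t] * h[n]
                if D is not None:
                    y += D[d] * u[b][d][t]  # skip connection
                output[b][d][t] = y
                if z is not None:
                    zv = z[b][d][t]
                    out_z[b][d][t] = y * zv / (1.0 + math.exp(-zv))  # y * silu(z)
    return output, out_z
```

A sketch like this is mainly useful as a slow oracle for checking a fast kernel on small inputs. With A set to zero, delta, B and C set to one, the output is simply the running sum of u.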

Parameters

  • delta_softplus (Bool): If True, applies softplus activation to delta values.
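As an illustration of what this flag controls, the sketch below applies softplus, log(1 + exp(x)), to a single delta value in a numerically stable form. The helper name `prepare_delta` is hypothetical. Adding delta_bias before the activation is an assumption taken from common Mamba reference implementations, not something this page states.

```python
import math


def prepare_delta(delta, bias=0.0, delta_softplus=False):
    """Add the per-channel bias, then optionally apply softplus (assumed order)."""
    x = delta + bias
    if not delta_softplus:
        return x
    # Stable softplus: log(1 + exp(x)) == max(x, 0) + log1p(exp(-|x|)),
    # which avoids overflow in exp() for large positive x.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))
```

Softplus keeps the time step positive for moderate inputs and approaches the identity for large ones, so `prepare_delta(1000.0, delta_softplus=True)` returns 1000.0 rather than overflowing.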

Implemented traits

AnyType, ImplicitlyDestructible

Methods

execute

static execute[dtype: DType, target: StringSlice[StaticConstantOrigin]](output: ManagedTensorSlice[Output, static_spec=output.static_spec], x: ManagedTensorSlice[Output, static_spec=x.static_spec], out_z: ManagedTensorSlice[Output, static_spec=out_z.static_spec], u: ManagedTensorSlice[Input, static_spec=u.static_spec], delta: ManagedTensorSlice[Input, static_spec=delta.static_spec], A: ManagedTensorSlice[Input, static_spec=A.static_spec], B: ManagedTensorSlice[Input, static_spec=B.static_spec], C: ManagedTensorSlice[Input, static_spec=C.static_spec], D: ManagedTensorSlice[Input, static_spec=D.static_spec], z: ManagedTensorSlice[Input, static_spec=z.static_spec], delta_bias: ManagedTensorSlice[Input, static_spec=delta_bias.static_spec], ctx: DeviceContext)

shape

static shape[dtype: DType](u: ManagedTensorSlice[Input, static_spec=u.static_spec], delta: ManagedTensorSlice[Input, static_spec=delta.static_spec], A: ManagedTensorSlice[Input, static_spec=A.static_spec], B: ManagedTensorSlice[Input, static_spec=B.static_spec], C: ManagedTensorSlice[Input, static_spec=C.static_spec], D: ManagedTensorSlice[Input, static_spec=D.static_spec], z: ManagedTensorSlice[Input, static_spec=z.static_spec], delta_bias: ManagedTensorSlice[Input, static_spec=delta_bias.static_spec]) -> IndexList[3]

Returns:

IndexList[3]
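The page does not say what the three returned dimensions are. A plausible reading, given the tensor shapes listed for this struct, is that they match the output tensor, (batch, dim, seqlen). The Python sketch below shows that shape inference together with consistency checks on the inputs. The function name `infer_output_shape`, the checks, and the divisibility requirement on n_groups are assumptions for illustration, not the behavior of the Mojo method.

```python
def infer_output_shape(u_shape, delta_shape, A_shape, B_shape, C_shape):
    """Validate input shapes and return the assumed output shape (batch, dim, seqlen)."""
    batch, dim, seqlen = u_shape
    if tuple(delta_shape) != (batch, dim, seqlen):
        raise ValueError("delta must have the same shape as u")
    if len(A_shape) != 2 or A_shape[0] != dim:
        raise ValueError("A must have shape (dim, dstate)")
    dstate = A_shape[1]
    if len(B_shape) != 4:
        raise ValueError("B must have shape (batch, n_groups, dstate, seqlen)")
    n_groups = B_shape[1]
    if tuple(B_shape) != (batch, n_groups, dstate, seqlen):
        raise ValueError("B must have shape (batch, n_groups, dstate, seqlen)")
    if tuple(C_shape) != tuple(B_shape):
        raise ValueError("C must have the same shape as B")
    if dim % n_groups != 0:
        # Assumption: channels are split evenly across groups.
        raise ValueError("dim must be divisible by n_groups")
    return (batch, dim, seqlen)
```

For example, `infer_output_shape((2, 4, 8), (2, 4, 8), (4, 16), (2, 1, 16, 8), (2, 1, 16, 8))` returns `(2, 4, 8)`.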