Mojo struct
SelectiveScanFwd
struct SelectiveScanFwd[delta_softplus: Bool = False]
Selective scan forward pass operation for Mamba SSM.
Performs the selective scan computation used in Mamba state space models (SSMs). This is the core operation of a Mamba layer: it runs each input sequence through a state-space recurrence whose parameters (delta, B, C) vary per time step.
Tensor shapes:
- output: (batch, dim, seqlen). Output tensor.
- x: (batch, dim, num_chunks, 2*dstate). Checkpoint tensor for chunked processing.
- out_z: (batch, dim, seqlen). Gated output (if z is provided).
- u: (batch, dim, seqlen). Input tensor.
- delta: (batch, dim, seqlen). Time-step tensor.
- A: (dim, dstate). State transition matrix.
- B: (batch, n_groups, dstate, seqlen). Input projection.
- C: (batch, n_groups, dstate, seqlen). Output projection.
- D: (dim,). Skip connection (optional; may be empty).
- z: (batch, dim, seqlen). Gating tensor (optional; may be empty).
- delta_bias: (dim,). Delta bias (optional; may be empty).
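This page does not spell out the recurrence itself. The plain-Python sketch below shows the computation implied by the shapes above, following the reference formulation of the Mamba selective scan: discretize with exp(delta * A), update each channel's hidden state, read out with C, add the D skip term, and (assumed, as in Mamba) gate with SiLU(z) to produce out_z. It is a naive sequential reference for readability, not this kernel's implementation. The function names, the nested-list layout, using None for an empty optional tensor, and the assumption that output is the ungated result while out_z is the gated one are all illustrative choices.

```python
import math

def softplus(v):
    # log(1 + exp(v)), rearranged for v > 0 so exp() cannot overflow.
    return v + math.log1p(math.exp(-v)) if v > 0 else math.log1p(math.exp(v))

def silu(v):
    # SiLU gate, v * sigmoid(v); assumed here to match Mamba's gating.
    return v / (1.0 + math.exp(-v))

def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None,
                       delta_softplus=False):
    """Hypothetical sequential reference over nested Python lists.

    u, delta, z: [batch][dim][seqlen]; A: [dim][dstate];
    B, C: [batch][n_groups][dstate][seqlen]; D, delta_bias: [dim].
    Returns (out, out_z); out_z is None when z is None.
    """
    batch, dim, seqlen = len(u), len(u[0]), len(u[0][0])
    dstate = len(A[0])
    dims_per_group = dim // len(B[0])
    out = [[[0.0] * seqlen for _ in range(dim)] for _ in range(batch)]
    out_z = None if z is None else [[[0.0] * seqlen for _ in range(dim)] for _ in range(batch)]
    for b in range(batch):
        for d in range(dim):
            g = d // dims_per_group  # group whose B and C this channel reads
            h = [0.0] * dstate       # hidden state starts at zero
            for t in range(seqlen):
                dt = delta[b][d][t]
                if delta_bias is not None:
                    dt += delta_bias[d]
                if delta_softplus:
                    dt = softplus(dt)
                y = 0.0
                for n in range(dstate):
                    # h_n <- exp(dt * A[d][n]) * h_n + dt * B * u, then read out with C.
                    h[n] = math.exp(dt * A[d][n]) * h[n] + dt * B[b][g][n][t] * u[b][d][t]
                    y += C[b][g][n][t] * h[n]
                if D is not None:
                    y += D[d] * u[b][d][t]  # skip connection
                out[b][d][t] = y
                if z is not None:
                    out_z[b][d][t] = y * silu(z[b][d][t])
    return out, out_z
```

With A = 0 the decay factor is exp(0) = 1, so the state simply accumulates delta * B * u; a negative A makes older inputs fade.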
Parameters
- delta_softplus (Bool): If True, applies softplus activation to delta values.
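With delta_softplus enabled, each delta value passes through softplus(x) = log(1 + exp(x)), which maps any real number to a positive step size. The sketch below assumes, as in the Mamba reference, that delta_bias is added before the activation. The threshold shortcut mirrors the threshold argument of PyTorch's torch.nn.functional.softplus and is only an assumption about how an implementation might avoid overflow; preprocess_delta is a hypothetical helper.

```python
import math

def softplus(x, threshold=20.0):
    # Above the threshold, log(1 + exp(x)) equals x to float precision,
    # so return x directly and avoid overflow in exp().
    if x > threshold:
        return x
    return math.log1p(math.exp(x))

def preprocess_delta(delta_row, bias=0.0, delta_softplus=False):
    # Hypothetical helper: add the per-channel bias, then optionally apply softplus.
    out = [v + bias for v in delta_row]
    if delta_softplus:
        out = [softplus(v) for v in out]
    return out
```

Even a negative delta, such as -5.0, becomes a small positive step after softplus, which keeps exp(delta * A) a valid decay factor when A is negative.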
Implemented traits
AnyType,
ImplicitlyDestructible
Methods
execute
static execute[dtype: DType, target: StringSlice[StaticConstantOrigin]](output: ManagedTensorSlice[Output, static_spec=output.static_spec], x: ManagedTensorSlice[Output, static_spec=x.static_spec], out_z: ManagedTensorSlice[Output, static_spec=out_z.static_spec], u: ManagedTensorSlice[Input, static_spec=u.static_spec], delta: ManagedTensorSlice[Input, static_spec=delta.static_spec], A: ManagedTensorSlice[Input, static_spec=A.static_spec], B: ManagedTensorSlice[Input, static_spec=B.static_spec], C: ManagedTensorSlice[Input, static_spec=C.static_spec], D: ManagedTensorSlice[Input, static_spec=D.static_spec], z: ManagedTensorSlice[Input, static_spec=z.static_spec], delta_bias: ManagedTensorSlice[Input, static_spec=delta_bias.static_spec], ctx: DeviceContext)
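execute writes three outputs: output, out_z, and the checkpoint tensor x with shape (batch, dim, num_chunks, 2*dstate). Chunking is possible because each state's recurrence h_t = a_t * h_{t-1} + b_t (with a_t = exp(delta_t * A) and b_t = delta_t * B_t * u_t) is a chain of affine maps, and affine maps compose associatively. The sketch below, for one channel and one state index, scans chunk by chunk and saves one (cumulative decay, state) pair per chunk. Reading the trailing 2*dstate axis as one such pair per state index is an assumption; this page does not document the layout of x. All function names are hypothetical.

```python
def combine(left, right):
    # Compose two affine steps h -> a*h + b: applying `left` then `right`
    # gives h -> (a_r * a_l) * h + (a_r * b_l + b_r).
    a_l, b_l = left
    a_r, b_r = right
    return (a_r * a_l, a_r * b_l + b_r)

def sequential_scan(a, b):
    # Reference: h_t = a_t * h_{t-1} + b_t, starting from h = 0.
    h, states = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        states.append(h)
    return states

def chunked_scan(a, b, chunk_size):
    # Scan each chunk locally, then apply the state carried from earlier chunks.
    states, checkpoints = [], []
    prefix = (1.0, 0.0)  # composition of all steps so far; starts as the identity
    for start in range(0, len(a), chunk_size):
        local = (1.0, 0.0)
        for t in range(start, min(start + chunk_size, len(a))):
            local = combine(local, (a[t], b[t]))
            # prefix[1] is the state at the end of the previous chunk.
            states.append(local[0] * prefix[1] + local[1])
        prefix = combine(prefix, local)
        checkpoints.append(prefix)  # one (cumulative decay, state) pair per chunk
    return states, checkpoints
```

Because combine is associative, chunks can be scanned independently and stitched together afterward, and the saved per-chunk pairs let a later pass resume from any chunk boundary without rescanning from the start.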
shape
static shape[dtype: DType](u: ManagedTensorSlice[Input, static_spec=u.static_spec], delta: ManagedTensorSlice[Input, static_spec=delta.static_spec], A: ManagedTensorSlice[Input, static_spec=A.static_spec], B: ManagedTensorSlice[Input, static_spec=B.static_spec], C: ManagedTensorSlice[Input, static_spec=C.static_spec], D: ManagedTensorSlice[Input, static_spec=D.static_spec], z: ManagedTensorSlice[Input, static_spec=z.static_spec], delta_bias: ManagedTensorSlice[Input, static_spec=delta_bias.static_spec]) -> IndexList[3]
Returns:
IndexList[3]: the shape of the output tensor.
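shape computes the output shape from the input tensors. Assuming the output shape is simply u's (batch, dim, seqlen), as the tensor-shape list suggests, a plain-Python sketch of that inference could also check the other inputs against the list. The function name, the tuple-based shapes, treating a zero-element shape as an empty optional tensor, and the requirement that dim be divisible by n_groups are assumptions made for illustration.

```python
def _is_empty(shape):
    # A shape with no elements stands in for an omitted optional tensor.
    return len(shape) == 0 or 0 in shape

def selective_scan_output_shape(u, delta, A, B, C, D=(0,), z=(0,), delta_bias=(0,)):
    """Hypothetical shape check; returns the (batch, dim, seqlen) output shape."""
    u, delta, A, B, C = (tuple(s) for s in (u, delta, A, B, C))
    if len(u) != 3:
        raise ValueError(f"u must be (batch, dim, seqlen), got {u}")
    batch, dim, seqlen = u
    if delta != u:
        raise ValueError(f"delta shape {delta} must match u shape {u}")
    if len(A) != 2 or A[0] != dim:
        raise ValueError(f"A must be (dim, dstate), got {A}")
    dstate = A[1]
    for name, shape in (("B", B), ("C", C)):
        if len(shape) != 4 or shape[0] != batch or shape[2:] != (dstate, seqlen):
            raise ValueError(f"{name} must be (batch, n_groups, dstate, seqlen), got {shape}")
        if shape[1] == 0 or dim % shape[1] != 0:
            raise ValueError(f"dim={dim} must be divisible by n_groups={shape[1]}")
    for name, shape, expected in (("D", D, (dim,)), ("z", z, u),
                                  ("delta_bias", delta_bias, (dim,))):
        if not _is_empty(shape) and tuple(shape) != expected:
            raise ValueError(f"{name} must be empty or {expected}, got {tuple(shape)}")
    return (batch, dim, seqlen)
```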