Mojo function
mamba_split_conv1d_scan_combined_cpu
def mamba_split_conv1d_scan_combined_cpu[kernel_dtype: DType, DSTATE: Int, zxbcdt_layout: TensorLayout, conv_weight_layout: TensorLayout, conv_bias_layout: TensorLayout, output_layout: TensorLayout, x_layout: TensorLayout, out_z_layout: TensorLayout, dt_layout: TensorLayout, A_layout: TensorLayout, B_layout: TensorLayout, C_layout: TensorLayout, D_layout: TensorLayout, z_layout: TensorLayout, delta_bias_layout: TensorLayout, rmsnorm_weight_layout: TensorLayout, outproj_weight_layout: TensorLayout, outproj_bias_layout: TensorLayout](batch: Int, seqlen: Int, dim: Int, nheads: Int, headdim: Int, ngroups: Int, width: Int, chunk_size: Int, delta_softplus: Int8, norm_before_gate: Int8, has_rmsnorm: Int8, has_outproj: Int8, zxbcdt: TileTensor[kernel_dtype, zxbcdt_layout, MutAnyOrigin], conv_weight: TileTensor[kernel_dtype, conv_weight_layout, MutAnyOrigin], conv_bias: TileTensor[kernel_dtype, conv_bias_layout, MutAnyOrigin], dt_bias: TileTensor[kernel_dtype, delta_bias_layout, MutAnyOrigin], A: TileTensor[kernel_dtype, A_layout, MutAnyOrigin], D: TileTensor[kernel_dtype, D_layout, MutAnyOrigin], x: TileTensor[kernel_dtype, x_layout, MutAnyOrigin], out_z: TileTensor[kernel_dtype, out_z_layout, MutAnyOrigin], dt: TileTensor[kernel_dtype, dt_layout, MutAnyOrigin], B: TileTensor[kernel_dtype, B_layout, MutAnyOrigin], C: TileTensor[kernel_dtype, C_layout, MutAnyOrigin], z: TileTensor[kernel_dtype, z_layout, MutAnyOrigin], rmsnorm_weight: TileTensor[kernel_dtype, rmsnorm_weight_layout, MutAnyOrigin], outproj_weight: TileTensor[kernel_dtype, outproj_weight_layout, MutAnyOrigin], outproj_bias: TileTensor[kernel_dtype, outproj_bias_layout, MutAnyOrigin], output: TileTensor[kernel_dtype, output_layout, MutAnyOrigin], epsilon: Scalar[kernel_dtype], ctx: Optional[DeviceContext] = None)
CPU kernel for the mamba_split_conv1d_scan_combined operation.
Channel layout of the input zxbcdt:
- Channels 0 to dim - 1: z (gating values)
- Channels dim to 2 * dim + 2 * ngroups * dstate - 1: xBC (x, B, and C before the convolution)
- Channels 2 * dim + 2 * ngroups * dstate to the end: dt (time-step values)
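The channel arithmetic above can be checked with a short sketch. It is plain Python rather than Mojo, used only to illustrate the index math, and the helper names zxbcdt_channel_ranges and split_channels are hypothetical, not part of the MAX API. It also assumes that dt holds one channel per head (nheads channels), as in the reference Mamba2 layout; the description above only says that dt runs to the end of the channel dimension.

```python
def zxbcdt_channel_ranges(dim: int, ngroups: int, dstate: int, nheads: int) -> dict:
    """Return the channel index range of each segment of zxbcdt.

    Assumption (hypothetical helper): dt holds one channel per head, so it
    spans nheads channels, as in the reference Mamba2 layout.
    """
    xbc_width = dim + 2 * ngroups * dstate  # x (dim) + B and C (ngroups * dstate each)
    z = range(0, dim)
    xbc = range(dim, dim + xbc_width)
    dt = range(dim + xbc_width, dim + xbc_width + nheads)
    return {"z": z, "xBC": xbc, "dt": dt}


def split_channels(row: list, ranges: dict) -> dict:
    """Slice one time step's channel vector into its named segments."""
    return {name: row[r.start:r.stop] for name, r in ranges.items()}


if __name__ == "__main__":
    # Example sizes (hypothetical): dim=8, ngroups=1, dstate=4, nheads=2.
    ranges = zxbcdt_channel_ranges(dim=8, ngroups=1, dstate=4, nheads=2)
    for name, r in ranges.items():
        print(name, r.start, "to", r.stop - 1)
```

Running it prints "z 0 to 7", "xBC 8 to 23", and "dt 24 to 25", so dt starts at 2 * dim + 2 * ngroups * dstate = 24, matching the layout above.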
After the convolution on xBC (channel indices relative to the start of xBC):
- Channels 0 to dim - 1: x (input to the scan)
- Channels dim to dim + ngroups * dstate - 1: B
- Channels dim + ngroups * dstate to dim + 2 * ngroups * dstate - 1: C
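The post-convolution split of xBC can be sketched the same way. This is again illustrative Python with a hypothetical helper, split_xbc; indices are relative to the start of xBC. Reshaping B and C into ngroups rows of dstate values assumes a group-major ordering, which the description above does not state.

```python
def split_xbc(xbc_row: list, dim: int, ngroups: int, dstate: int):
    """Split one time step of convolved xBC channels into x, B and C.

    Indices are relative to the start of xBC. Assumption: B and C are laid out
    group-major, so group g owns entries g*dstate .. (g+1)*dstate - 1 of each.
    """
    gs = ngroups * dstate
    if len(xbc_row) != dim + 2 * gs:
        raise ValueError(f"expected {dim + 2 * gs} channels, got {len(xbc_row)}")
    x = xbc_row[:dim]               # channels 0 .. dim - 1
    b_flat = xbc_row[dim:dim + gs]  # channels dim .. dim + ngroups*dstate - 1
    c_flat = xbc_row[dim + gs:]     # channels dim + ngroups*dstate .. dim + 2*ngroups*dstate - 1
    B = [b_flat[g * dstate:(g + 1) * dstate] for g in range(ngroups)]
    C = [c_flat[g * dstate:(g + 1) * dstate] for g in range(ngroups)]
    return x, B, C


# Example sizes (hypothetical): dim=4, ngroups=2, dstate=3 -> 4 + 2*6 = 16 channels.
x, B, C = split_xbc(list(range(16)), dim=4, ngroups=2, dstate=3)
print(x)  # [0, 1, 2, 3]
print(B)  # [[4, 5, 6], [7, 8, 9]]
print(C)  # [[10, 11, 12], [13, 14, 15]]
```

The length check guards the invariant implied by the layout: xBC always has dim + 2 * ngroups * dstate channels.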