Mojo function
mamba_split_conv1d_scan_combined_cpu
mamba_split_conv1d_scan_combined_cpu[kernel_dtype: DType, DSTATE: Int, zxbcdt_layout: Layout, conv_weight_layout: Layout, conv_bias_layout: Layout, output_layout: Layout, x_layout: Layout, out_z_layout: Layout, dt_layout: Layout, A_layout: Layout, B_layout: Layout, C_layout: Layout, D_layout: Layout, z_layout: Layout, delta_bias_layout: Layout, rmsnorm_weight_layout: Layout, outproj_weight_layout: Layout, outproj_bias_layout: Layout](batch: Int, seqlen: Int, dim: Int, nheads: Int, headdim: Int, ngroups: Int, width: Int, chunk_size: Int, delta_softplus: Int8, norm_before_gate: Int8, has_rmsnorm: Int8, has_outproj: Int8, zxbcdt: LayoutTensor[kernel_dtype, zxbcdt_layout, MutAnyOrigin], conv_weight: LayoutTensor[kernel_dtype, conv_weight_layout, MutAnyOrigin], conv_bias: LayoutTensor[kernel_dtype, conv_bias_layout, MutAnyOrigin], dt_bias: LayoutTensor[kernel_dtype, delta_bias_layout, MutAnyOrigin], A: LayoutTensor[kernel_dtype, A_layout, MutAnyOrigin], D: LayoutTensor[kernel_dtype, D_layout, MutAnyOrigin], x: LayoutTensor[kernel_dtype, x_layout, MutAnyOrigin], out_z: LayoutTensor[kernel_dtype, out_z_layout, MutAnyOrigin], dt: LayoutTensor[kernel_dtype, dt_layout, MutAnyOrigin], B: LayoutTensor[kernel_dtype, B_layout, MutAnyOrigin], C: LayoutTensor[kernel_dtype, C_layout, MutAnyOrigin], z: LayoutTensor[kernel_dtype, z_layout, MutAnyOrigin], rmsnorm_weight: LayoutTensor[kernel_dtype, rmsnorm_weight_layout, MutAnyOrigin], outproj_weight: LayoutTensor[kernel_dtype, outproj_weight_layout, MutAnyOrigin], outproj_bias: LayoutTensor[kernel_dtype, outproj_bias_layout, MutAnyOrigin], output: LayoutTensor[kernel_dtype, output_layout, MutAnyOrigin], epsilon: Scalar[kernel_dtype], ctx: Optional[DeviceContext] = None)
CPU kernel for the mamba_split_conv1d_scan_combined operation.
Input zxbcdt structure:
- Channels 0 to dim-1: z (gating values)
- Channels dim to 2*dim + 2*ngroups*dstate - 1: xBC (x, B, C before conv)
- Channels 2*dim + 2*ngroups*dstate to end: dt (time step values)
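The channel layout above can be sketched as plain index arithmetic. This is an illustrative Python snippet, not part of the Mojo kernel: the helper name `zxbcdt_channel_ranges` and the `total_channels` argument are hypothetical. It assumes the xBC block is dim + 2*ngroups*dstate channels wide (x plus B plus C), so that dt begins at 2*dim + 2*ngroups*dstate. The width of dt is taken as whatever remains, since the page only says "to end".

```python
def zxbcdt_channel_ranges(dim, ngroups, dstate, total_channels):
    """Return half-open (start, stop) channel ranges for z, xBC and dt.

    Hypothetical helper for illustration only; it mirrors the layout
    described in the text and is not an API of the Mojo kernel.
    """
    # Assumption: xBC holds x (dim channels) plus B and C
    # (ngroups * dstate channels each).
    xbc_width = dim + 2 * ngroups * dstate

    z = (0, dim)
    xbc = (dim, dim + xbc_width)
    dt = (xbc[1], total_channels)  # dt runs from the end of xBC "to end"

    if dt[0] >= total_channels:
        raise ValueError("total_channels leaves no room for dt")
    return {"z": z, "xBC": xbc, "dt": dt}


# Example sizes chosen arbitrarily: 8 + 16 + 2 = 26 channels in total.
ranges = zxbcdt_channel_ranges(dim=8, ngroups=1, dstate=4, total_channels=26)
print(ranges)  # {'z': (0, 8), 'xBC': (8, 24), 'dt': (24, 26)}
```

The ranges are half-open, so the inclusive end index quoted in the list is `stop - 1`.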
After conv on xBC:
- Channels 0 to dim-1: x (input to scan)
- Channels dim to dim + ngroups*dstate - 1: B
- Channels dim + ngroups*dstate to dim + 2*ngroups*dstate - 1: C
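As a rough illustration of the post-convolution layout, the following Python sketch slices a single channel vector into x, B and C. The function name `split_xbc` is hypothetical, and a plain list stands in for the channel axis of a tensor. It shows only the index arithmetic, not how the kernel itself accesses memory.

```python
def split_xbc(xbc, dim, ngroups, dstate):
    """Split a post-conv xBC channel vector into (x, B, C).

    Hypothetical helper for illustration; `xbc` is any sequence whose
    length is dim + 2 * ngroups * dstate.
    """
    group_width = ngroups * dstate
    expected = dim + 2 * group_width
    if len(xbc) != expected:
        raise ValueError(f"expected {expected} channels, got {len(xbc)}")

    x = xbc[:dim]                                # channels 0 .. dim-1
    b = xbc[dim:dim + group_width]               # next ngroups*dstate channels
    c = xbc[dim + group_width:dim + 2 * group_width]  # final ngroups*dstate
    return x, b, c


# Use the channel indices themselves as data so the split is easy to read.
channels = list(range(4 + 2 * 1 * 3))  # dim=4, ngroups=1, dstate=3
x, b, c = split_xbc(channels, dim=4, ngroups=1, dstate=3)
print(x, b, c)  # [0, 1, 2, 3] [4, 5, 6] [7, 8, 9]
```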