Mojo function
causal_conv1d_update_cpu
causal_conv1d_update_cpu[x_dtype: DType, conv_state_dtype: DType, weight_dtype: DType, output_dtype: DType, bias_dtype: DType](batch: Int, dim: Int, seqlen: Int, width: Int, state_len: Int, x: TileTensor[x_dtype, address_space=x.address_space, linear_idx_type=x.linear_idx_type, element_size=x.element_size], conv_state: TileTensor[conv_state_dtype, address_space=conv_state.address_space, linear_idx_type=conv_state.linear_idx_type, element_size=conv_state.element_size], weight: TileTensor[weight_dtype, address_space=weight.address_space, linear_idx_type=weight.linear_idx_type, element_size=weight.element_size], output: TileTensor[output_dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], bias: TileTensor[bias_dtype, address_space=bias.address_space, linear_idx_type=bias.linear_idx_type, element_size=bias.element_size], x_batch_stride: UInt32, x_c_stride: UInt32, x_l_stride: UInt32, conv_state_batch_stride: UInt32, conv_state_c_stride: UInt32, conv_state_l_stride: UInt32, weight_c_stride: UInt32, weight_width_stride: UInt32, out_batch_stride: UInt32, out_c_stride: UInt32, out_l_stride: UInt32, silu_activation: Bool)
CPU implementation of causal conv1d update for incremental inference.
This kernel:
- Concatenates conv_state with x to form a sliding window
- Computes convolution output for the new positions
- Updates conv_state with the new values from x
Simple mode (no circular buffer):
- conv_state holds the last state_len input values
- New x values are appended, and the oldest values are shifted out
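The steps above can be sketched in plain Python. This is a minimal reference sketch, not the Mojo kernel itself: the function name `causal_conv1d_update_ref`, the nested-list layout, the optional `bias=None`, and the tap order (the last weight multiplies the newest sample) are all assumptions made for illustration.

```python
import math


def silu(v):
    # SiLU activation: v * sigmoid(v)
    return v / (1.0 + math.exp(-v))


def causal_conv1d_update_ref(x, conv_state, weight, bias=None, silu_activation=False):
    """Hypothetical reference for the update step (assumed layout, not the real API).

    x:          [batch][dim][seqlen] nested lists
    conv_state: [batch][dim][state_len] nested lists, updated in place
    weight:     [dim][width]
    bias:       [dim] or None
    Returns output as [batch][dim][seqlen].
    """
    output = []
    for b in range(len(x)):
        out_b = []
        for c in range(len(x[b])):
            state = conv_state[b][c]
            state_len = len(state)
            width = len(weight[c])
            assert state_len >= width - 1, "state_len must be >= width - 1"

            # Step 1: concatenate the stored history with the new inputs.
            window = state + x[b][c]

            # Step 2: compute one output per new position.
            row = []
            for t in range(len(x[b][c])):
                newest = state_len + t  # index of the new sample in window
                acc = bias[c] if bias is not None else 0.0
                for k in range(width):
                    # Assumed tap order: weight[c][width - 1] hits the newest sample.
                    acc += weight[c][k] * window[newest - (width - 1) + k]
                row.append(silu(acc) if silu_activation else acc)

            # Step 3: keep only the last state_len values as the new state.
            state[:] = window[len(window) - state_len:]
            out_b.append(row)
        output.append(out_b)
    return output


# One decoding step: batch=1, dim=1, width=3, state_len=2, seqlen=1.
state = [[[1.0, 2.0]]]
out = causal_conv1d_update_ref([[[3.0]]], state, weight=[[1.0, 2.0, 3.0]])
# out == [[[14.0]]] and state == [[[2.0, 3.0]]]
```

Because the state is shifted on every call rather than indexed as a ring buffer, the cost of the update grows with state_len, which matches the "no circular buffer" description of simple mode.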
Args:
- batch (Int): Batch size.
- dim (Int): Number of channels.
- seqlen (Int): Sequence length of input x (typically 1).
- width (Int): Kernel width.
- state_len (Int): Length of conv_state (>= width - 1).
- x (TileTensor[x_dtype, address_space=x.address_space, linear_idx_type=x.linear_idx_type, element_size=x.element_size]): Input tensor.
- conv_state (TileTensor[conv_state_dtype, address_space=conv_state.address_space, linear_idx_type=conv_state.linear_idx_type, element_size=conv_state.element_size]): Convolution state buffer (modified in-place).
- weight (TileTensor[weight_dtype, address_space=weight.address_space, linear_idx_type=weight.linear_idx_type, element_size=weight.element_size]): Convolution weights.
- output (TileTensor[output_dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size]): Output tensor.
- bias (TileTensor[bias_dtype, address_space=bias.address_space, linear_idx_type=bias.linear_idx_type, element_size=bias.element_size]): Bias tensor.
- x_batch_stride (UInt32): Stride for batch dimension in x.
- x_c_stride (UInt32): Stride for channel dimension in x.
- x_l_stride (UInt32): Stride for sequence length dimension in x.
- conv_state_batch_stride (UInt32): Stride for batch dimension in conv_state.
- conv_state_c_stride (UInt32): Stride for channel dimension in conv_state.
- conv_state_l_stride (UInt32): Stride for state length dimension in conv_state.
- weight_c_stride (UInt32): Stride for channel dimension in weight.
- weight_width_stride (UInt32): Stride for kernel width dimension in weight.
- out_batch_stride (UInt32): Stride for batch dimension in output.
- out_c_stride (UInt32): Stride for channel dimension in output.
- out_l_stride (UInt32): Stride for sequence length dimension in output.
- silu_activation (Bool): Whether to apply SiLU activation.
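The stride arguments can be illustrated with a short Python sketch. It assumes strides are counted in elements and that the tensors are laid out row-major as (batch, dim, seqlen); neither is stated on this page, and the helper names `contiguous_strides` and `flat_offset` are hypothetical.

```python
def contiguous_strides(shape):
    # Row-major strides in elements (assumed unit): the last dimension has stride 1.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)


def flat_offset(index, strides):
    # Linear offset of a multi-dimensional index given per-dimension strides.
    return sum(i * s for i, s in zip(index, strides))


# x with batch=2, dim=4, seqlen=3 under the assumed layout:
x_batch_stride, x_c_stride, x_l_stride = contiguous_strides((2, 4, 3))
# (x_batch_stride, x_c_stride, x_l_stride) == (12, 3, 1)
offset = flat_offset((1, 2, 1), (x_batch_stride, x_c_stride, x_l_stride))
# offset == 19
```

Passing strides explicitly, rather than deriving them from shapes, would let a caller describe non-contiguous views such as a transposed or sliced buffer.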