
Mojo function

causal_conv1d_update_gpu_no_bias

causal_conv1d_update_gpu_no_bias[x_dtype: DType, conv_state_dtype: DType, weight_dtype: DType, output_dtype: DType, kNThreads: Int, x_LT: TensorLayout, conv_state_LT: TensorLayout, weight_LT: TensorLayout, output_LT: TensorLayout](batch: Int, dim: Int, seqlen: Int, width: Int, state_len: Int, x: TileTensor[x_dtype, x_LT, MutExternalOrigin], conv_state: TileTensor[conv_state_dtype, conv_state_LT, MutExternalOrigin], weight: TileTensor[weight_dtype, weight_LT, MutExternalOrigin], output: TileTensor[output_dtype, output_LT, MutExternalOrigin], x_batch_stride: UInt32, x_c_stride: UInt32, x_l_stride: UInt32, conv_state_batch_stride: UInt32, conv_state_c_stride: UInt32, conv_state_l_stride: UInt32, weight_c_stride: UInt32, weight_width_stride: UInt32, out_batch_stride: UInt32, out_c_stride: UInt32, out_l_stride: UInt32, silu_activation: Int8)

GPU kernel for the causal conv1d update operation without bias, used for autoregressive decoding.

This kernel incrementally updates the convolution state so that autoregressive token generation does not need to recompute the full convolution. It processes seqlen new input steps per sequence, writes the result to output, and updates conv_state.
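The update described above can be sketched in plain Python. This is a CPU reference under stated assumptions, not the kernel's implementation: it assumes the conventional causal conv1d update semantics (append the new inputs to the state, keep the last state_len samples, take a depthwise dot product over the trailing width samples) and that state_len >= width - 1. The function name causal_conv1d_update_ref and the nested-list tensor layout are hypothetical and chosen only for illustration.

```python
import math


def causal_conv1d_update_ref(x, conv_state, weight, silu_activation=False):
    """Hypothetical CPU reference for a bias-free causal conv1d update.

    x:          nested list [B][C][L] of new inputs
    conv_state: nested list [B][C][state_len], mutated in place
    weight:     nested list [C][W]
    Returns a nested list [B][C][L].
    """
    width = len(weight[0])
    output = []
    for b, x_b in enumerate(x):
        out_b = []
        for c, x_bc in enumerate(x_b):
            state_len = len(conv_state[b][c])
            # Assumption: the state holds at least width - 1 past samples.
            assert state_len >= width - 1
            # Past samples followed by the new samples.
            window = list(conv_state[b][c]) + list(x_bc)
            # Keep only the most recent state_len samples as the new state.
            conv_state[b][c][:] = window[len(window) - state_len:]
            out_bc = []
            for l in range(len(x_bc)):
                # First sample of the width-long window ending at new step l.
                start = state_len + l - (width - 1)
                acc = sum(weight[c][w] * window[start + w] for w in range(width))
                if silu_activation:
                    acc = acc / (1.0 + math.exp(-acc))  # SiLU: v * sigmoid(v)
                out_bc.append(acc)
            out_b.append(out_bc)
        output.append(out_b)
    return output


# One batch, one channel, one new token, kernel width 4, state length 3.
state = [[[1.0, 2.0, 3.0]]]
out = causal_conv1d_update_ref([[[4.0]]], state, [[1.0, 1.0, 1.0, 1.0]])
print(out, state)
```

With these inputs the single output value is the sum of the three stored samples and the new one, and the state shifts left by one position to make room for the new sample.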

Grid: (batch, ceildiv(dim, kNThreads))
Block: kNThreads
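The launch geometry can be illustrated with a short Python sketch. The grid and block sizes come from the line above. The thread-to-channel mapping (one thread per channel, with a bounds guard for the last block) is an assumption that this geometry suggests, not something the page states. The helper names launch_config and channel_for_thread are hypothetical.

```python
def ceildiv(a, b):
    """Integer division rounded up, for positive b."""
    return -(-a // b)


def launch_config(batch, dim, k_n_threads):
    """Return (grid, block) as documented: one grid row per batch element,
    and enough blocks along the second axis to cover all channels."""
    grid = (batch, ceildiv(dim, k_n_threads))
    block = k_n_threads
    return grid, block


def channel_for_thread(block_idx_y, thread_idx_x, k_n_threads, dim):
    """Assumed mapping: each thread handles one channel.

    Returns None for threads in the last block that fall past dim,
    which a kernel would typically guard with an early return.
    """
    channel = block_idx_y * k_n_threads + thread_idx_x
    return channel if channel < dim else None


grid, block = launch_config(batch=2, dim=300, k_n_threads=128)
print(grid, block)
```

When dim is not a multiple of kNThreads, the last block along the channel axis contains idle threads, which is why the sketch includes the bounds check.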

Args:

  • batch (Int): Batch size.
  • dim (Int): Number of channels.
  • seqlen (Int): Sequence length of the new input.
  • width (Int): Kernel width.
  • state_len (Int): Length of the convolution state buffer.
  • x (TileTensor[x_dtype, x_LT, MutExternalOrigin]): Input tensor of shape (B, C, L).
  • conv_state (TileTensor[conv_state_dtype, conv_state_LT, MutExternalOrigin]): Convolution state tensor of shape (B, C, state_len).
  • weight (TileTensor[weight_dtype, weight_LT, MutExternalOrigin]): Weight tensor of shape (C, W).
  • output (TileTensor[output_dtype, output_LT, MutExternalOrigin]): Output tensor of shape (B, C, L).
  • x_batch_stride (UInt32): Stride for the batch dimension of the input tensor.
  • x_c_stride (UInt32): Stride for the channel dimension of the input tensor.
  • x_l_stride (UInt32): Stride for the sequence length dimension of the input tensor.
  • conv_state_batch_stride (UInt32): Stride for the batch dimension of the conv state tensor.
  • conv_state_c_stride (UInt32): Stride for the channel dimension of the conv state tensor.
  • conv_state_l_stride (UInt32): Stride for the sequence length dimension of the conv state tensor.
  • weight_c_stride (UInt32): Stride for the channel dimension of the weight tensor.
  • weight_width_stride (UInt32): Stride for the width dimension of the weight tensor.
  • out_batch_stride (UInt32): Stride for the batch dimension of the output tensor.
  • out_c_stride (UInt32): Stride for the channel dimension of the output tensor.
  • out_l_stride (UInt32): Stride for the sequence length dimension of the output tensor.
  • silu_activation (Int8): Whether to apply SiLU activation to the output (0 to disable, 1 to enable).
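The stride arguments above describe how a logical index such as (b, c, l) maps to a position in the underlying buffer. The Python sketch below shows that arithmetic for a contiguous row-major layout. It assumes strides are counted in elements rather than bytes, which the page does not specify. The helpers contiguous_strides and flat_offset are hypothetical names used only for illustration.

```python
def contiguous_strides(shape):
    """Row-major (C-order) strides in elements for the given shape."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)


def flat_offset(index, strides):
    """Linear element offset of a multi-dimensional index."""
    return sum(i * s for i, s in zip(index, strides))


# x of shape (B, C, L) = (2, 3, 4): batch, channel, and length strides.
x_batch_stride, x_c_stride, x_l_stride = contiguous_strides((2, 3, 4))

# weight of shape (C, W) = (3, 4): channel and width strides.
weight_c_stride, weight_width_stride = contiguous_strides((3, 4))

# Offset of x[b=1, c=2, l=3] in the flat buffer.
offset = flat_offset((1, 2, 3), (x_batch_stride, x_c_stride, x_l_stride))
print(offset)
```

Passing strides explicitly lets the same kernel read non-contiguous views, for example a tensor stored as (B, L, C) but addressed as (B, C, L), by swapping the channel and length strides.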