Mojo function
causal_conv1d_update_gpu
causal_conv1d_update_gpu[x_dtype: DType, conv_state_dtype: DType, weight_dtype: DType, output_dtype: DType, bias_dtype: DType, kNThreads: Int, x_LT: TensorLayout, conv_state_LT: TensorLayout, weight_LT: TensorLayout, output_LT: TensorLayout, bias_LT: TensorLayout](batch: Int, dim: Int, seqlen: Int, width: Int, state_len: Int, x: TileTensor[x_dtype, x_LT, MutExternalOrigin], conv_state: TileTensor[conv_state_dtype, conv_state_LT, MutExternalOrigin], weight: TileTensor[weight_dtype, weight_LT, MutExternalOrigin], output: TileTensor[output_dtype, output_LT, MutExternalOrigin], bias: TileTensor[bias_dtype, bias_LT, MutExternalOrigin], x_batch_stride: UInt32, x_c_stride: UInt32, x_l_stride: UInt32, conv_state_batch_stride: UInt32, conv_state_c_stride: UInt32, conv_state_l_stride: UInt32, weight_c_stride: UInt32, weight_width_stride: UInt32, out_batch_stride: UInt32, out_c_stride: UInt32, out_l_stride: UInt32, silu_activation: Int8)
GPU kernel for the causal conv1d update operation, used during autoregressive decoding.
This kernel incrementally maintains the convolution state so that autoregressive token generation stays efficient. It processes a new input sequence, writes the result to output, and updates the convolution state held in conv_state.
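The update described above can be sketched on the CPU in plain Python. This is a minimal reference under stated assumptions, not the kernel's implementation: it assumes the semantics commonly used for a causal conv1d update (concatenate the stored state with the new input, convolve each channel independently, keep the last state_len samples as the new state) and assumes state_len >= width - 1. The names `silu` and `causal_conv1d_update_ref` are hypothetical and exist only for illustration.

```python
import math


def silu(v):
    # SiLU(v) = v * sigmoid(v)
    return v / (1.0 + math.exp(-v))


def causal_conv1d_update_ref(x, conv_state, weight, bias, silu_activation):
    """Hypothetical CPU reference, using nested lists.

    x:          [B][C][L]          new input samples
    conv_state: [B][C][state_len]  history, overwritten with the new state
    weight:     [C][W]             one filter per channel (depthwise)
    bias:       [C]
    """
    out = []
    for b in range(len(x)):
        out_b = []
        for c in range(len(x[b])):
            seqlen = len(x[b][c])
            state_len = len(conv_state[b][c])
            width = len(weight[c])
            # Assumption: the state holds enough history for one full window.
            assert state_len > 0 and state_len >= width - 1
            # History first, then the new samples.
            window = conv_state[b][c] + x[b][c]
            row = []
            for t in range(seqlen):
                end = state_len + t  # position of x[b][c][t] inside window
                acc = bias[c]
                for k in range(width):
                    acc += weight[c][k] * window[end - (width - 1) + k]
                row.append(silu(acc) if silu_activation else acc)
            out_b.append(row)
            # Assumed state update: keep the most recent state_len samples.
            conv_state[b][c][:] = window[-state_len:]
        out.append(out_b)
    return out
```

With a state of [1, 2, 3], one new sample 4, weights [1, 0, -1] and bias 0.5, the sketch produces 0.5 + 2 - 4 = -1.5 and leaves the state as [2, 3, 4].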
Grid: (batch, ceildiv(dim, kNThreads))
Block: kNThreads
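To make the launch geometry concrete, the following sketch computes the grid and block sizes for a given problem. The helper names (`ceildiv`, `launch_shape`, `channel_for_thread`) are hypothetical, and the thread-to-channel mapping is an assumption inferred from the grid shape (one thread per channel, with a bounds check in the last block); the source does not state it.

```python
def ceildiv(a, b):
    # Ceiling division for positive integers.
    return -(-a // b)


def launch_shape(batch, dim, k_n_threads):
    # Grid: (batch, ceildiv(dim, kNThreads)), Block: kNThreads
    grid = (batch, ceildiv(dim, k_n_threads))
    block = k_n_threads
    return grid, block


def channel_for_thread(block_idx_y, thread_idx, k_n_threads, dim):
    # Assumed mapping: each thread handles one channel.
    # Threads past the last channel do no work (returns None).
    c = block_idx_y * k_n_threads + thread_idx
    return c if c < dim else None
```

For example, batch = 4, dim = 1000 and kNThreads = 128 would give a grid of (4, 8) with 128 threads per block, so the final block along the channel axis has 24 idle threads under this mapping.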
Args:
- batch (Int): Batch size.
- dim (Int): Number of channels.
- seqlen (Int): Sequence length of the new input.
- width (Int): Kernel width.
- state_len (Int): Length of the convolution state buffer.
- x (TileTensor[x_dtype, x_LT, MutExternalOrigin]): Input tensor of shape (B, C, L).
- conv_state (TileTensor[conv_state_dtype, conv_state_LT, MutExternalOrigin]): Convolution state tensor of shape (B, C, state_len).
- weight (TileTensor[weight_dtype, weight_LT, MutExternalOrigin]): Weight tensor of shape (C, W).
- output (TileTensor[output_dtype, output_LT, MutExternalOrigin]): Output tensor of shape (B, C, L).
- bias (TileTensor[bias_dtype, bias_LT, MutExternalOrigin]): Bias tensor of shape (C,).
- x_batch_stride (UInt32): Stride for the batch dimension of the input tensor.
- x_c_stride (UInt32): Stride for the channel dimension of the input tensor.
- x_l_stride (UInt32): Stride for the sequence length dimension of the input tensor.
- conv_state_batch_stride (UInt32): Stride for the batch dimension of the conv state tensor.
- conv_state_c_stride (UInt32): Stride for the channel dimension of the conv state tensor.
- conv_state_l_stride (UInt32): Stride for the sequence length dimension of the conv state tensor.
- weight_c_stride (UInt32): Stride for the channel dimension of the weight tensor.
- weight_width_stride (UInt32): Stride for the width dimension of the weight tensor.
- out_batch_stride (UInt32): Stride for the batch dimension of the output tensor.
- out_c_stride (UInt32): Stride for the channel dimension of the output tensor.
- out_l_stride (UInt32): Stride for the sequence length dimension of the output tensor.
- silu_activation (Int8): Whether to apply SiLU activation (0 or 1).
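The stride arguments above describe how a logical index (b, c, l) maps to a position in memory. The sketch below shows the usual arithmetic, assuming strides are counted in elements rather than bytes (the source does not say which). The helpers `contiguous_strides` and `element_offset` are hypothetical names used only for illustration.

```python
def contiguous_strides(shape):
    # Row-major strides, in elements, for a densely packed tensor.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)


def element_offset(b, c, l, batch_stride, c_stride, l_stride):
    # Linear offset of element (b, c, l) given one stride per dimension.
    return b * batch_stride + c * c_stride + l * l_stride


# A contiguous (B, C, L) = (2, 3, 4) tensor.
x_batch_stride, x_c_stride, x_l_stride = contiguous_strides((2, 3, 4))
```

Passing strides explicitly lets the same kernel read tensors that are not stored contiguously in (B, C, L) order. For instance, data stored channel-last as (B, L, C) with B = 2, L = 4, C = 3 could be described with a batch stride of 12, a channel stride of 1 and a sequence stride of 3.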