Mojo function

causal_conv1d_channel_first_fwd_gpu

causal_conv1d_channel_first_fwd_gpu[
    x_dtype: DType, weight_dtype: DType, output_dtype: DType,
    kNThreads: Int, kWidth: Int, kNElts: Int,
    bias_dtype: DType,
    x_LT: TensorLayout, weight_LT: TensorLayout, output_LT: TensorLayout, bias_LT: TensorLayout
](
    batch: Int, dim: Int, seqlen: Int, width: Int,
    x: TileTensor[x_dtype, x_LT, MutExternalOrigin],
    weight: TileTensor[weight_dtype, weight_LT, MutExternalOrigin],
    output: TileTensor[output_dtype, output_LT, MutExternalOrigin],
    bias: TileTensor[bias_dtype, bias_LT, MutExternalOrigin],
    x_batch_stride: UInt32, x_c_stride: UInt32, x_l_stride: UInt32,
    weight_c_stride: UInt32, weight_width_stride: UInt32,
    out_batch_stride: UInt32, out_c_stride: UInt32, out_l_stride: UInt32,
    bias_stride: UInt32,
    silu_activation: Int8
)

Optimized GPU forward pass of causal 1D convolution (conv1d) for channel-first (B, C, L) tensors, with a per-channel bias and optional SiLU activation.

Key optimizations:

  1. SIMD vectorization for input/output operations (kNElts elements per thread).
  2. Efficient memory access patterns with coalesced loads.
  3. Vectorized weight loading and computation for width=2 and width=4.
  4. Optimized activation function with SIMD operations.
  5. Better thread utilization and memory bandwidth usage.

Grid: (ceildiv(seqlen, kNThreads * kNElts), dim, batch)
Block: kNThreads
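
The launch configuration can be sketched on the host side as follows. This is a minimal Python illustration of the arithmetic only, not the kernel's actual launch code. The helper names `ceildiv` and `launch_dims` and the example sizes are hypothetical. It assumes, as the grid formula suggests, that each block covers kNThreads * kNElts consecutive sequence positions for one (batch, channel) pair.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer division rounded up (b must be positive)."""
    return -(-a // b)


def launch_dims(batch: int, dim: int, seqlen: int, k_n_threads: int, k_n_elts: int):
    """Return (grid, block) following the documented Grid/Block formula."""
    # Sequence positions handled by one block: kNThreads threads, kNElts each.
    chunk = k_n_threads * k_n_elts
    grid = (ceildiv(seqlen, chunk), dim, batch)
    block = k_n_threads
    return grid, block


# Hypothetical sizes: 1000 positions need two 512-position blocks per channel.
grid, block = launch_dims(batch=2, dim=64, seqlen=1000, k_n_threads=128, k_n_elts=4)
print(grid, block)  # (2, 64, 2) 128
```

When seqlen is not a multiple of kNThreads * kNElts, the last block along the sequence axis is only partly filled, so a kernel launched this way presumably has to guard its tail loads and stores.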

Args:

  • batch (Int): Batch size.
  • dim (Int): Number of channels.
  • seqlen (Int): Sequence length.
  • width (Int): Kernel width (must match the kWidth compile-time parameter).
  • x (TileTensor[x_dtype, x_LT, MutExternalOrigin]): Input tensor of shape (B, C, L).
  • weight (TileTensor[weight_dtype, weight_LT, MutExternalOrigin]): Weight tensor of shape (C, W).
  • output (TileTensor[output_dtype, output_LT, MutExternalOrigin]): Output tensor of shape (B, C, L).
  • bias (TileTensor[bias_dtype, bias_LT, MutExternalOrigin]): Bias tensor of shape (C,).
  • x_batch_stride (UInt32): Stride for the batch dimension of the input tensor.
  • x_c_stride (UInt32): Stride for the channel dimension of the input tensor.
  • x_l_stride (UInt32): Stride for the sequence length dimension of the input tensor.
  • weight_c_stride (UInt32): Stride for the channel dimension of the weight tensor.
  • weight_width_stride (UInt32): Stride for the width dimension of the weight tensor.
  • out_batch_stride (UInt32): Stride for the batch dimension of the output tensor.
  • out_c_stride (UInt32): Stride for the channel dimension of the output tensor.
  • out_l_stride (UInt32): Stride for the sequence length dimension of the output tensor.
  • bias_stride (UInt32): Stride for the channel dimension of the bias tensor.
  • silu_activation (Int8): Flag selecting whether SiLU activation is applied to the output: 0 disables it, 1 enables it.
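
To make the roles of the shape, stride, and activation arguments concrete, here is a slow reference sketch in plain Python. It is not the kernel's implementation. The function name `causal_conv1d_ref` is hypothetical. It assumes that strides are counted in elements, that positions before the start of the sequence contribute zero, and that the taps follow the common causal-conv1d ordering, in which the last weight multiplies the current position. None of these details is confirmed by this page.

```python
import math


def silu(v: float) -> float:
    """SiLU: v * sigmoid(v), written to avoid overflow for large |v|."""
    if v >= 0.0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)


def causal_conv1d_ref(batch, dim, seqlen, width, x, weight, bias,
                      x_strides, w_strides, out_strides, bias_stride,
                      silu_activation):
    """Reference over flat buffers; all strides are assumed to be in elements."""
    xb, xc, xl = x_strides
    wc, ww = w_strides
    ob, oc, ol = out_strides
    # Size the flat output so the largest strided offset fits.
    out = [0.0] * ((batch - 1) * ob + (dim - 1) * oc + (seqlen - 1) * ol + 1)
    for b in range(batch):
        for c in range(dim):
            for l in range(seqlen):
                acc = bias[c * bias_stride]
                for w in range(width):
                    src = l - (width - 1) + w  # assumed tap order
                    if src >= 0:  # causal: nothing before position 0
                        acc += weight[c * wc + w * ww] * x[b * xb + c * xc + src * xl]
                if silu_activation:
                    acc = silu(acc)
                out[b * ob + c * oc + l * ol] = acc
    return out


# Hypothetical contiguous example: B=1, C=2, L=4, W=2.
x = [1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0]
weight = [1.0, 2.0, 0.5, -1.0]
bias = [0.0, 1.0]
out = causal_conv1d_ref(1, 2, 4, 2, x, weight, bias,
                        (8, 4, 1), (2, 1), (8, 4, 1), 1, 0)
print(out)  # [2.0, 5.0, 8.0, 11.0, -9.0, -14.0, -19.0, -24.0]
```

For a contiguous (B, C, L) tensor the strides are (C * L, L, 1), which is what the example passes. The sketch accumulates in Python floats, whereas the real kernel works in the dtypes given by its compile-time parameters, so results may differ slightly for low-precision types.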