Mojo function
causal_conv1d_channel_first_fwd_gpu
causal_conv1d_channel_first_fwd_gpu[x_dtype: DType, weight_dtype: DType, output_dtype: DType, kNThreads: Int, kWidth: Int, kNElts: Int, bias_dtype: DType, x_LT: TensorLayout, weight_LT: TensorLayout, output_LT: TensorLayout, bias_LT: TensorLayout](batch: Int, dim: Int, seqlen: Int, width: Int, x: TileTensor[x_dtype, x_LT, MutExternalOrigin], weight: TileTensor[weight_dtype, weight_LT, MutExternalOrigin], output: TileTensor[output_dtype, output_LT, MutExternalOrigin], bias: TileTensor[bias_dtype, bias_LT, MutExternalOrigin], x_batch_stride: UInt32, x_c_stride: UInt32, x_l_stride: UInt32, weight_c_stride: UInt32, weight_width_stride: UInt32, out_batch_stride: UInt32, out_c_stride: UInt32, out_l_stride: UInt32, bias_stride: UInt32, silu_activation: Int8)
Optimized GPU implementation of causal conv1d for channel-first layout with bias.
Key optimizations:
- SIMD vectorization for input/output operations (kNElts elements per thread).
- Efficient memory access patterns with coalesced loads.
- Vectorized weight loading and computation for width=2 and width=4.
- Optimized activation function with SIMD operations.
- Better thread utilization and memory bandwidth usage.
Launch configuration:
- Grid: (ceildiv(seqlen, kNThreads * kNElts), dim, batch)
- Block: kNThreads
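The launch configuration above can be worked through with a small host-side calculation. The following is a minimal Python sketch, not part of the Mojo API: the helper names `ceildiv` and `launch_dims` are hypothetical, and the example values for `kNThreads` and `kNElts` are chosen only for illustration.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division for positive b."""
    return -(-a // b)


def launch_dims(batch: int, dim: int, seqlen: int, k_n_threads: int, k_n_elts: int):
    """Return (grid, block) following the documented launch configuration.

    Hypothetical helper: each block of k_n_threads threads covers
    k_n_threads * k_n_elts sequence positions, so the x-dimension of the
    grid is the number of such chunks needed to cover seqlen.
    """
    grid = (ceildiv(seqlen, k_n_threads * k_n_elts), dim, batch)
    block = k_n_threads
    return grid, block


# Illustrative values only: 128 threads per block, 4 elements per thread.
grid, block = launch_dims(batch=2, dim=768, seqlen=1000, k_n_threads=128, k_n_elts=4)
print(grid, block)  # (2, 768, 2) 128
```

With these values one block covers 512 positions, so a sequence of length 1000 needs two blocks along the first grid dimension, and the last block is only partially filled.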
Args:
- batch (Int): Batch size.
- dim (Int): Number of channels.
- seqlen (Int): Sequence length.
- width (Int): Kernel width (must match kWidth compile-time parameter).
- x (TileTensor[x_dtype, x_LT, MutExternalOrigin]): Input tensor of shape (B, C, L).
- weight (TileTensor[weight_dtype, weight_LT, MutExternalOrigin]): Weight tensor of shape (C, W).
- output (TileTensor[output_dtype, output_LT, MutExternalOrigin]): Output tensor of shape (B, C, L).
- bias (TileTensor[bias_dtype, bias_LT, MutExternalOrigin]): Bias tensor of shape (C,).
- x_batch_stride (UInt32): Stride for the batch dimension of the input tensor.
- x_c_stride (UInt32): Stride for the channel dimension of the input tensor.
- x_l_stride (UInt32): Stride for the sequence length dimension of the input tensor.
- weight_c_stride (UInt32): Stride for the channel dimension of the weight tensor.
- weight_width_stride (UInt32): Stride for the width dimension of the weight tensor.
- out_batch_stride (UInt32): Stride for the batch dimension of the output tensor.
- out_c_stride (UInt32): Stride for the channel dimension of the output tensor.
- out_l_stride (UInt32): Stride for the sequence length dimension of the output tensor.
- bias_stride (UInt32): Stride for the channel dimension of the bias tensor.
- silu_activation (Int8): Whether to apply SiLU activation (Int8: 0 or 1).
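To make the roles of the shape, stride, and activation arguments concrete, here is a scalar reference sketch in plain Python. It is not the GPU kernel and not the Mojo API: the names `contiguous_strides`, `silu`, and `causal_conv1d_ref` are hypothetical, tensors are modeled as flat lists indexed through element strides, and the tap alignment (the last weight column multiplies the current position, with positions before the start of the sequence treated as zero) is an assumption based on the usual causal conv1d convention rather than something this page states.

```python
import math


def contiguous_strides(shape):
    """Row-major element strides for a dense tensor of the given shape."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)


def silu(v: float) -> float:
    """SiLU activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def causal_conv1d_ref(batch, dim, seqlen, width, x, weight, output, bias,
                      x_strides, weight_strides, out_strides, bias_stride,
                      silu_activation):
    """Scalar reference for a channel-first causal conv1d with bias.

    Assumed convention (not confirmed by the source):
      out[b, c, l] = bias[c] + sum_w weight[c, w] * x[b, c, l - (width - 1) + w]
    with out-of-range input positions contributing zero.
    Writes results into the flat `output` list in place.
    """
    xb, xc, xl = x_strides
    wc, ww = weight_strides
    ob, oc, ol = out_strides
    for b in range(batch):
        for c in range(dim):
            for l in range(seqlen):
                acc = float(bias[c * bias_stride])
                for w in range(width):
                    src = l - (width - 1) + w  # causal: never reads past l
                    if src >= 0:
                        acc += weight[c * wc + w * ww] * x[b * xb + c * xc + src * xl]
                if silu_activation:
                    acc = silu(acc)
                output[b * ob + c * oc + l * ol] = acc


# Tiny example: one batch, one channel, four positions, kernel width 2.
B, C, L, W = 1, 1, 4, 2
x = [1, 2, 3, 4]
weight = [10, 1]  # weight[c, 0] hits x[l - 1], weight[c, 1] hits x[l]
bias = [0.5]
out = [0.0] * (B * C * L)
causal_conv1d_ref(B, C, L, W, x, weight, out, bias,
                  contiguous_strides((B, C, L)), contiguous_strides((C, W)),
                  contiguous_strides((B, C, L)), 1, silu_activation=0)
print(out)  # [1.5, 12.5, 23.5, 34.5]
```

For a dense (B, C, L) tensor the strides reduce to (C * L, L, 1), which is what `contiguous_strides` returns. Passing strides explicitly, as the kernel does, lets the same indexing work for non-contiguous views.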