
Mojo function

gated_delta_conv1d_fwd_gpu

gated_delta_conv1d_fwd_gpu[work_dtype: DType, state_dtype: DType, KERNEL_SIZE: Int, CONV1D_BLOCK_DIM: Int, qkv_input_ragged_LT: TensorLayout, conv_weight_LT: TensorLayout, conv_state_LT: TensorLayout, slot_idx_LT: TensorLayout, input_row_offsets_LT: TensorLayout, conv_output_ragged_LT: TensorLayout](batch_size: Int, total_seq_len: Int, conv_dim: Int, qkv_input_ragged: TileTensor[work_dtype, qkv_input_ragged_LT, MutExternalOrigin], conv_weight: TileTensor[work_dtype, conv_weight_LT, MutExternalOrigin], conv_state: TileTensor[state_dtype, conv_state_LT, MutExternalOrigin], slot_idx: TileTensor[DType.uint32, slot_idx_LT, MutExternalOrigin], input_row_offsets: TileTensor[DType.uint32, input_row_offsets_LT, MutExternalOrigin], conv_output_ragged: TileTensor[work_dtype, conv_output_ragged_LT, MutExternalOrigin], qkv_input_seqlen_stride: UInt32, qkv_input_channel_stride: UInt32, conv_weight_channel_stride: UInt32, conv_weight_offset_stride: UInt32, conv_state_pool_stride: UInt32, conv_state_channel_stride: UInt32, conv_state_window_stride: UInt32, conv_output_seqlen_stride: UInt32, conv_output_channel_stride: UInt32)

GPU kernel: slot-indexed causal depthwise conv1d over a ragged batch.

The conv state lives in a single mutable pool of shape [max_slots, conv_dim, K-1], where K is KERNEL_SIZE. For batch item batch_item_idx, the kernel reads and writes pool slot slot_idx[batch_item_idx]. During look-back, every read of the old window happens before any write of the new window, so the state can safely be updated in place. Each thread handles one (batch_item, conv_channel) pair for that item's entire sequence and keeps the channel's K weights in registers across the token loop.
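To make the slot-indexed, ragged semantics concrete, here is a minimal pure-Python reference model that can be useful for checking outputs on the CPU. It is a sketch under stated assumptions, not the Mojo kernel itself. The assumptions are as follows. The input is logically [total_seq_len, conv_dim] and the weight is [conv_dim, K]. Each conv_state window is stored oldest-first. input_row_offsets holds batch_size + 1 prefix offsets. The last weight tap multiplies the current token, following the common causal-conv1d convention. No activation is applied, because whether the kernel fuses one is not stated here. The function name `causal_depthwise_conv1d_ref` is hypothetical.

```python
def causal_depthwise_conv1d_ref(qkv, weight, conv_state, slot_idx, row_offsets):
    """Hypothetical reference model of a slot-indexed causal depthwise conv1d.

    Assumed layouts (not confirmed by the kernel docs):
      qkv[t][c]            total_seq_len x conv_dim (ragged: all sequences concatenated)
      weight[c][k]         conv_dim x K; weight[c][K-1] multiplies the current token
      conv_state[s][c][j]  max_slots x conv_dim x (K-1), oldest input first; mutated in place
      slot_idx[b]          pool slot owned by batch item b
      row_offsets[b]       sequence b occupies rows [row_offsets[b], row_offsets[b+1])
    """
    conv_dim = len(weight)
    K = len(weight[0])
    out = [[0.0] * conv_dim for _ in range(len(qkv))]
    batch_size = len(row_offsets) - 1
    for b in range(batch_size):
        slot = slot_idx[b]
        start, end = row_offsets[b], row_offsets[b + 1]
        # In the GPU kernel, each (b, c) pair below is handled by one thread.
        for c in range(conv_dim):
            w = weight[c]  # the channel's K weights (kept in registers on the GPU)
            # Read the whole old window before anything is written back.
            old_window = list(conv_state[slot][c])
            history = old_window + [qkv[t][c] for t in range(start, end)]
            for i, t in enumerate(range(start, end)):
                # Token t sees the K-1 inputs before it plus itself.
                taps = history[i:i + K]
                out[t][c] = sum(wk * xk for wk, xk in zip(w, taps))
            # Write back the newest K-1 inputs as the slot's new window.
            # (Slicing from len - (K-1) also handles K == 1, where the window is empty.)
            conv_state[slot][c] = history[len(history) - (K - 1):]
    return out


# Example: K = 3, one channel, two ragged sequences (3 tokens, then 1 token).
qkv = [[1.0], [2.0], [3.0], [10.0]]
weight = [[1.0, 10.0, 100.0]]
conv_state = [[[0.0, 0.0]], [[5.0, 6.0]]]  # pool of 2 slots
slot_idx = [1, 0]                            # batch item 0 -> slot 1, item 1 -> slot 0
row_offsets = [0, 3, 4]
out = causal_depthwise_conv1d_ref(qkv, weight, conv_state, slot_idx, row_offsets)
print(out)         # [[165.0], [216.0], [321.0], [1000.0]]
print(conv_state)  # [[[0.0, 10.0]], [[2.0, 3.0]]]
```

Copying the old window before writing the new one mirrors the read-before-write ordering that makes in-place mutation of the pool safe. Because each batch item maps to its own slot, items never touch each other's state. That independence is what lets the kernel run all (batch_item, conv_channel) pairs in parallel.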