IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo module

gated_delta_conv1d

Causal depthwise conv1d for the Gated DeltaNet two-pass prefill.

This is Pass 1 of the two-pass gated delta rule prefill path. It computes the causal 1-D convolution over a ragged (variable-length) batch of sequences and updates the per-sequence sliding-window conv state.
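For intuition, the per-sequence arithmetic of Pass 1 can be written as a plain-Python reference. This is a sketch under stated assumptions, not the kernel itself. It assumes that tap `k` of `conv_weight[c]` multiplies the input at offset `k - (kernel_size-1)` from the current token, so the last tap hits the current token as in common causal conv1d implementations. It also assumes that no bias or activation is applied, since neither is mentioned here. The name `causal_conv1d_seq_ref` is hypothetical.

```python
from typing import List


def causal_conv1d_seq_ref(
    x: List[List[float]],        # [seq_len, conv_dim]: one sequence, seqlen-first
    weight: List[List[float]],   # [conv_dim, kernel_size]: depthwise weights
    window: List[List[float]],   # [conv_dim, kernel_size-1]: old conv state, oldest first
) -> List[List[float]]:
    """Hypothetical CPU reference for one sequence; returns [seq_len, conv_dim]."""
    conv_dim = len(weight)
    k_size = len(weight[0])
    out = [[0.0] * conv_dim for _ in x]
    for c in range(conv_dim):
        # Extended history for channel c: kernel_size-1 look-back values, then the sequence.
        ext = list(window[c]) + [row[c] for row in x]
        for t in range(len(x)):
            # Token t sits at ext index t + (kernel_size-1), so tap k reads ext[t + k].
            out[t][c] = sum(weight[c][k] * ext[t + k] for k in range(k_size))
    return out
```

With an all-zero window this reduces to an ordinary causal convolution with zero left padding. A nonzero window is what lets a prefill continue from earlier context without recomputing it.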

Unlike the existing causal_conv1d_varlen_fwd, which uses a [dim, total_seqlen] layout for Mamba compatibility, this kernel uses a [total_seqlen, conv_dim] layout to match the gated_deltanet.py convention, in which all per-token tensors are seqlen-first.
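The practical difference between the two layouts is which index is contiguous in memory. A minimal sketch, assuming dense row-major storage with no padding; the helper names are illustrative only.

```python
def offset_seq_first(t: int, c: int, conv_dim: int) -> int:
    """Flat offset of (token t, channel c) in a [total_seq_len, conv_dim] buffer."""
    return t * conv_dim + c


def offset_dim_first(t: int, c: int, total_seq_len: int) -> int:
    """Flat offset of (channel c, token t) in a [dim, total_seq_len] buffer."""
    return c * total_seq_len + t


conv_dim, total_seq_len = 4, 3
# Seqlen-first: all channels of token 1 are adjacent in memory.
row = [offset_seq_first(1, c, conv_dim) for c in range(conv_dim)]
# Dim-first: the same elements are total_seq_len apart.
col = [offset_dim_first(1, c, total_seq_len) for c in range(conv_dim)]
```

Here `row` is `[4, 5, 6, 7]` and `col` is `[1, 4, 7, 10]`. If neighbouring threads handle neighbouring channels, each step along the sequence then reads a contiguous run of the seqlen-first buffer.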

Tensor shapes

Inputs:
  qkv_input_ragged : [total_seq_len, conv_dim] float32
      Flat projected QKV input, with all sequences concatenated.
  conv_weight : [conv_dim, kernel_size] float32
      Depthwise conv weights (one weight per channel per time offset).
  conv_state : [max_slots, conv_dim, kernel_size-1]
      Mutable sliding-window conv state pool. The kernel reads and writes slot slot_idx[batch_item] in place; all other slots are untouched. Window positions within a single pool entry are ordered oldest to newest: window slot 0 holds the token at position -(kernel_size-1) relative to the current sequence start. The pool dtype is independent of the working dtype.
  slot_idx : [batch_size] uint32
      Pool slot index for each batch item.
  input_row_offsets : [batch_size + 1] uint32
      Exclusive prefix sums of sequence lengths. Sequence b spans token indices [input_row_offsets[b], input_row_offsets[b+1]).
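As an illustration of how input_row_offsets indexes the ragged input, the sketch below builds the offsets from per-sequence lengths and recovers each sequence's half-open token range. The helper names are hypothetical, and Python lists stand in for the uint32 tensors.

```python
from itertools import accumulate
from typing import List, Tuple


def row_offsets_from_lengths(lengths: List[int]) -> List[int]:
    """Build input_row_offsets ([batch_size + 1]) from per-sequence lengths."""
    return [0] + list(accumulate(lengths))


def sequence_span(offsets: List[int], b: int) -> Tuple[int, int]:
    """Half-open token range [start, end) of sequence b in the ragged tensor."""
    return offsets[b], offsets[b + 1]


lengths = [3, 1, 4]
offsets = row_offsets_from_lengths(lengths)  # [0, 3, 4, 8]
total_seq_len = offsets[-1]                  # 8 rows in qkv_input_ragged
```

The final entry doubles as total_seq_len, which is why the offsets tensor has batch_size + 1 elements.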

Outputs:
  conv_output_ragged : [total_seq_len, conv_dim] float32
      Causal conv output, in the same ragged layout as the input.

conv_state is mutated in place; there is no separate state-out tensor. After the kernel runs, window slot j holds the raw input at position seq_len - (kernel_size-1) + j within the sequence. When seq_len < kernel_size - 1, some of those positions are negative, and their values are carried forward from the old window.
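The state-update rule in the note above amounts to concatenating the old window with the new sequence and keeping the last kernel_size - 1 values. A minimal per-channel sketch; `update_window` is a hypothetical name, and the real kernel writes the result back into conv_state[slot_idx[b]] in place rather than returning a new list.

```python
from typing import List


def update_window(old_window: List[float], seq: List[float], kernel_size: int) -> List[float]:
    """New conv-state window for one channel after consuming seq (oldest first).

    Window slot j holds the input at position seq_len - (kernel_size-1) + j;
    positions before the sequence start come from old_window.
    """
    width = kernel_size - 1
    ext = list(old_window) + list(seq)  # position p lives at ext index p + width
    return ext[len(ext) - width:]       # last `width` entries (empty when width == 0)
```

For example, with kernel_size = 4, an old window of [1.0, 2.0, 3.0] and a one-token sequence [7.0], the new window is [2.0, 3.0, 7.0]: two values are carried forward from the old window because the sequence is shorter than kernel_size - 1.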

Thread mapping (GPU)

Grid  : (batch_size, ceildiv(conv_dim, CONV1D_BLOCK_DIM))
Block : (CONV1D_BLOCK_DIM,)

One thread per (batch_item, conv_channel). Each thread processes its channel's full sequence sequentially, reading from conv_state at slot slot_idx[batch_item] for the look-back that extends before the current sequence.
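The mapping can be emulated on the CPU to check that every (batch_item, channel) pair gets exactly one thread. This sketch assumes grid x indexes the batch item, grid y indexes the channel block, and channel = block_y * CONV1D_BLOCK_DIM + thread_x. The value 128 for CONV1D_BLOCK_DIM and the helper names are illustrative assumptions, not taken from the kernel.

```python
from typing import Optional, Tuple

CONV1D_BLOCK_DIM = 128  # hypothetical value; the real constant is defined in the module


def launch_dims(batch_size: int, conv_dim: int, block_dim: int = CONV1D_BLOCK_DIM):
    """Grid and block shapes: one thread per (batch_item, channel)."""
    grid = (batch_size, -(-conv_dim // block_dim))  # integer ceildiv
    block = (block_dim,)
    return grid, block


def thread_work(block_x: int, block_y: int, thread_x: int, conv_dim: int,
                block_dim: int = CONV1D_BLOCK_DIM) -> Optional[Tuple[int, int]]:
    """(batch_item, channel) handled by one thread, or None for idle tail threads."""
    channel = block_y * block_dim + thread_x
    if channel >= conv_dim:
        return None  # the last channel block may be only partially filled
    return block_x, channel
```

Because each thread walks its channel's sequence sequentially, the available parallelism is batch_size × conv_dim rather than the sequence length, and the bounds check covers conv_dim values that are not a multiple of the block size.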

Functions