Mojo module
gated_delta_conv1d
Causal depthwise conv1d for the Gated DeltaNet two-pass prefill.
This is Pass 1 of the two-pass gated delta rule prefill path. It computes the causal 1-D convolution over a ragged (variable-length) batch of sequences and updates the per-sequence sliding-window conv state.
Unlike the existing causal_conv1d_varlen_fwd (which uses [dim, total_seqlen] layout for Mamba compatibility), this kernel uses [total_seqlen, conv_dim] layout to match the gated_deltanet.py convention where all per-token tensors are seqlen-first.
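The Python sketch below illustrates the layout difference. It computes row-major flat offsets for both layouts and converts a seqlen-first buffer to channel-first order. The helper names (`seqlen_first_index`, `channel_first_index`, `to_channel_first`) are hypothetical and assume dense row-major storage. They are not part of this module.

```python
def seqlen_first_index(t: int, c: int, conv_dim: int) -> int:
    """Flat offset of (token t, channel c) in a row-major [total_seqlen, conv_dim] buffer.

    All channels of one token are contiguous.
    """
    return t * conv_dim + c


def channel_first_index(c: int, t: int, total_seqlen: int) -> int:
    """Flat offset of (channel c, token t) in a row-major [dim, total_seqlen] buffer.

    All tokens of one channel are contiguous.
    """
    return c * total_seqlen + t


def to_channel_first(flat, total_seqlen: int, conv_dim: int):
    """Copy a seqlen-first buffer into channel-first order (a transpose)."""
    out = [0.0] * (total_seqlen * conv_dim)
    for t in range(total_seqlen):
        for c in range(conv_dim):
            out[channel_first_index(c, t, total_seqlen)] = flat[
                seqlen_first_index(t, c, conv_dim)
            ]
    return out


# 3 tokens x 2 channels, seqlen-first: token rows [0, 1], [2, 3], [4, 5]
print(to_channel_first([0, 1, 2, 3, 4, 5], total_seqlen=3, conv_dim=2))
# [0, 2, 4, 1, 3, 5]
```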
Tensor shapes
Inputs:
qkv_input_ragged : [total_seq_len, conv_dim] float32
    Flat projected QKV input, all sequences concatenated.
conv_weight : [conv_dim, kernel_size] float32
    Depthwise conv weights (one weight per channel per time offset).
conv_state : [max_slots, conv_dim, kernel_size-1]
    Mutable sliding-window conv state pool. The kernel reads and writes
    slot slot_idx[batch_item] in place; all other slots are untouched.
    Within a single pool entry, window slots are ordered oldest-to-newest:
    window slot 0 holds the token at position -(K-1) relative to the
    current sequence start, where K = kernel_size. The pool dtype is
    independent of the working dtype.
slot_idx : [batch_size] uint32
    Pool slot index for each batch item.
input_row_offsets : [batch_size + 1] uint32
    Exclusive prefix sums of sequence lengths. Sequence b spans
    token indices [input_row_offsets[b], input_row_offsets[b+1]).
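Here is a minimal Python sketch of how `input_row_offsets` relates to per-sequence lengths, assuming the lengths are known on the host. `row_offsets` and `sequence_tokens` are illustrative names, not functions from this module.

```python
from itertools import accumulate


def row_offsets(seq_lens):
    """Build input_row_offsets: [0, l0, l0+l1, ...], length batch_size + 1."""
    return [0] + list(accumulate(seq_lens))


def sequence_tokens(offsets, b):
    """Token indices of sequence b: the half-open range [offsets[b], offsets[b+1])."""
    return range(offsets[b], offsets[b + 1])


offsets = row_offsets([3, 1, 4])
print(offsets)                            # [0, 3, 4, 8]
print(list(sequence_tokens(offsets, 1)))  # [3]
print(offsets[-1])                        # 8, i.e. total_seq_len
```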
Outputs:
conv_output_ragged : [total_seq_len, conv_dim] float32
    Causal conv output in the same ragged layout as the input.
conv_state is mutated in place; there is no separate state-out tensor. After the call, window slot j holds the raw input at position seq_len - (kernel_size-1) + j within the sequence. When seq_len < kernel_size-1, some of these positions are negative, and those slots carry forward values from the old window.
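The following pure-Python reference sketch makes the Pass 1 semantics concrete. It assumes the common causal-conv convention, used for example by Mamba-style causal conv1d, in which weight index k = kernel_size-1 multiplies the current token. It also assumes that no bias or activation is applied. This page states neither detail. Nested lists stand in for tensors, and `gated_delta_conv1d_ref` is a hypothetical name.

```python
def gated_delta_conv1d_ref(x, weight, conv_state, slot_idx, offsets):
    """Reference causal depthwise conv1d over a ragged batch (sketch).

    x:          [total_seq_len][conv_dim]
    weight:     [conv_dim][K]              (K = kernel_size)
    conv_state: [max_slots][conv_dim][K-1], updated in place
    slot_idx:   [batch_size]
    offsets:    [batch_size + 1]           (input_row_offsets)
    Returns conv_output_ragged as [total_seq_len][conv_dim].
    """
    conv_dim = len(weight)
    K = len(weight[0])
    out = [[0.0] * conv_dim for _ in range(offsets[-1])]
    for b, slot in enumerate(slot_idx):
        start, end = offsets[b], offsets[b + 1]
        seq_len = end - start
        for c in range(conv_dim):  # the GPU kernel uses one thread per (b, c)
            # ext[i] is the token at position i - (K-1) relative to the sequence
            # start: old window first (oldest-to-newest), then the new tokens.
            ext = list(conv_state[slot][c]) + [x[start + t][c] for t in range(seq_len)]
            for t in range(seq_len):
                acc = 0.0
                for k in range(K):
                    # Assumed convention: weight[c][K-1] multiplies the current token.
                    acc += weight[c][k] * ext[t + k]
                out[start + t][c] = acc
            # Window slot j <- position seq_len - (K-1) + j. When seq_len < K-1,
            # the leading slots are carried forward from the old window.
            conv_state[slot][c] = ext[len(ext) - (K - 1):]
    return out


x = [[5.0], [6.0], [8.0]]                           # two sequences: lengths 2 and 1
weight = [[1.0, 10.0, 100.0]]                       # conv_dim = 1, kernel_size = 3
state = [[[1.0, 2.0]], [[7.0, 7.0]], [[3.0, 4.0]]]  # max_slots = 3
out = gated_delta_conv1d_ref(x, weight, state, slot_idx=[2, 0], offsets=[0, 2, 3])
print(out)    # [[543.0], [654.0], [821.0]]
print(state)  # [[[2.0, 8.0]], [[7.0, 7.0]], [[5.0, 6.0]]]
```

The second sequence has length 1, which is shorter than kernel_size-1 = 2, so it shows the carry-forward: window slot 0 keeps 2.0 from the old window, and window slot 1 receives the new token 8.0. Pool slot 1 is not referenced by slot_idx and stays untouched.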
Thread mapping (GPU)
Grid : (batch_size, ceildiv(conv_dim, CONV1D_BLOCK_DIM))
Block : (CONV1D_BLOCK_DIM,)
One thread per (batch_item, conv_channel). Each thread processes its channel's full sequence sequentially. For the look-back that extends before the current sequence start, it reads from conv_state at slot slot_idx[batch_item].
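You can emulate the launch geometry on the host to check that every (batch_item, conv_channel) pair is covered exactly once. This Python sketch assumes that block index x selects the batch item, that block index y selects a chunk of CONV1D_BLOCK_DIM channels, and that out-of-range channels in the last chunk are skipped. This indexing, the helper names, and the `block_dim` value used below are all assumptions, because the page does not give the value of CONV1D_BLOCK_DIM.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division for non-negative a and positive b."""
    return (a + b - 1) // b


def launch_shape(batch_size: int, conv_dim: int, block_dim: int):
    """Grid and block shapes as described above."""
    grid = (batch_size, ceildiv(conv_dim, block_dim))
    block = (block_dim,)
    return grid, block


def thread_assignments(batch_size: int, conv_dim: int, block_dim: int):
    """Yield the (batch_item, conv_channel) pair owned by each active thread."""
    (grid_x, grid_y), (threads,) = launch_shape(batch_size, conv_dim, block_dim)
    for block_x in range(grid_x):          # assumed: batch item
        for block_y in range(grid_y):      # assumed: channel chunk
            for thread_x in range(threads):
                channel = block_y * block_dim + thread_x
                if channel < conv_dim:     # guard the partial last chunk
                    yield block_x, channel


print(launch_shape(batch_size=2, conv_dim=10, block_dim=4))  # ((2, 3), (4,))
pairs = list(thread_assignments(2, 10, 4))
print(len(pairs))  # 20
```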
Functions
- gated_delta_conv1d_fwd_gpu: GPU kernel for slot-indexed causal depthwise conv1d over a ragged batch.