
Mojo module

gated_delta

Gated DeltaNet recurrence kernel for Qwen3.5: Pass 2 of two-pass prefill.

Implements the gated delta rule recurrence over a ragged (variable-length) batch of sequences. This is Pass 2 of the prefill path; it consumes the conv1d output produced by Pass 1 (gated_delta_conv1d_fwd).

The five steps of the gated delta rule at each token t for value-dim element vd_i and value head h are:

  1. Apply per-head scalar decay to the entire state column: state_col[k] ← decay[t,h] * state_col[k] for k in [0, KD)

  2. Compute kv_memory by taking the dot product of the decayed state column with the L2-normalised key vector (summing over the key_dim axis): kv_memory_vd_i = Σ_k state_col[k] * key_normalised[t,h,k]

  3. Compute the delta correction using beta and the value residual: delta_correction_vd_i = beta[t,h] * (value[t,h,vd_i] - kv_memory_vd_i)

  4. Outer-product update of the state column with the key and delta: state_col[k] ← state_col[k] + key_normalised[t,h,k] * delta_correction_vd_i

  5. Read out the output by dotting the updated state with the scaled, L2-normalised query vector: output[t, h*VD + vd_i] = Σ_k state_col[k] * query_scaled[t,h,k]
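The five steps above can be sketched as a plain Python reference for one thread's state column. This is an illustrative sketch, not the Mojo kernel: the function name `gated_delta_step` and the use of Python lists are assumptions, and the key and query are taken to be already normalised and scaled.

```python
def gated_delta_step(state_col, key_normalised, query_scaled, value, decay, beta):
    """Apply one token of the gated delta rule to a single state column.

    state_col, key_normalised and query_scaled are lists of length KD.
    value, decay and beta are scalars for this (token, head, vd_i).
    Returns (updated_state_col, output_scalar).
    """
    # Step 1: per-head scalar decay of the whole column.
    decayed = [decay * s for s in state_col]
    # Step 2: kv_memory is the dot product of the decayed column with the key.
    kv_memory = sum(s * k for s, k in zip(decayed, key_normalised))
    # Step 3: delta correction from beta and the value residual.
    delta_correction = beta * (value - kv_memory)
    # Step 4: rank-1 update of the column along the key direction.
    updated = [s + k * delta_correction for s, k in zip(decayed, key_normalised)]
    # Step 5: read out by dotting the updated column with the scaled query.
    output = sum(s * q for s, q in zip(updated, query_scaled))
    return updated, output
```

With beta = 1 and a unit-norm key, the updated column maps that key exactly to the new value, which is the "delta rule" behaviour the steps describe.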

All of steps 1–5 run over the key-dimension loop (k = 0..KD-1), where KD is the per-head key dimension (key_head_dim in the tensor shapes) and is a compile-time constant. This allows the inner loop to be fully unrolled and the KD-element state column to live in GPU registers, eliminating shared-memory traffic.

L2 normalisation and Q scaling are fused into the kernel body. The raw Q/K/V vectors are read directly from the conv1d output (qkv_conv_output), with the channel layout: Q at [0..key_dim), K at [key_dim..2*key_dim), V at [2*key_dim..).
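A rough Python sketch of that fused load for one thread follows. Several details are assumptions not stated in this section: the function names, the head-major ordering of channels inside each of the Q, K and V ranges, the `eps` stabiliser inside the norm, and the query scale factor, which is therefore passed in as a parameter rather than hard-coded.

```python
import math


def l2_normalise(vec, eps):
    # eps is an assumed numerical stabiliser; the source does not give its value.
    norm = math.sqrt(sum(x * x for x in vec) + eps)
    return [x / norm for x in vec]


def load_qkv_for_thread(row, num_key_heads, key_head_dim, value_head_dim,
                        key_head_idx, value_head_idx, vd_element_idx,
                        query_scale, eps=1e-6):
    """Slice one conv1d output row into this thread's q, k and v.

    row is one token's qkv_conv_output row of length conv_dim.
    Assumes heads are stored contiguously inside each Q/K/V channel range.
    """
    key_dim = num_key_heads * key_head_dim
    q_start = key_head_idx * key_head_dim                  # Q: [0, key_dim)
    k_start = key_dim + key_head_idx * key_head_dim        # K: [key_dim, 2*key_dim)
    v_index = 2 * key_dim + value_head_idx * value_head_dim + vd_element_idx
    query = l2_normalise(row[q_start:q_start + key_head_dim], eps)
    key = l2_normalise(row[k_start:k_start + key_head_dim], eps)
    query_scaled = [query_scale * x for x in query]
    return query_scaled, key, row[v_index]
```

A common choice of scale in attention-style kernels is key_head_dim ** -0.5, but whether this kernel uses that value is not confirmed here.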

GQA (grouped query attention) is handled by computing the key head index as: key_head_idx = value_head_idx // heads_expansion_ratio

where heads_expansion_ratio = num_value_heads / num_key_heads is a runtime integer, so no compile-time specialisation per model is required.
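As a small illustration of the head mapping, the Python sketch below computes the key head for each value head. The function name and the divisibility check are assumptions added for the example.

```python
def key_head_for(value_head_idx, num_value_heads, num_key_heads):
    """Map a value head to the key head it shares under GQA."""
    if num_value_heads % num_key_heads != 0:
        # Assumed precondition: the ratio must be a whole number.
        raise ValueError("num_value_heads must be a multiple of num_key_heads")
    heads_expansion_ratio = num_value_heads // num_key_heads
    return value_head_idx // heads_expansion_ratio
```

For example, with 8 value heads and 2 key heads the ratio is 4, so value heads 0 to 3 read key head 0 and value heads 4 to 7 read key head 1.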

Tensor shapes

Inputs:
  qkv_conv_output : [total_seq_len, conv_dim] float32
    Conv1d output from Pass 1. Channel layout:
      Q: channels [0, key_dim)
      K: channels [key_dim, 2*key_dim)
      V: channels [2*key_dim, 2*key_dim + value_dim)
    where
      key_dim = num_key_heads * key_head_dim
      value_dim = num_value_heads * value_head_dim
      conv_dim = key_dim * 2 + value_dim
  decay_per_token : [total_seq_len, num_value_heads] float32
    Per-token, per-head scalar decay factor (exp(-softplus) pre-applied).
  beta_per_token : [total_seq_len, num_value_heads] float32
    Per-token, per-head beta gate (sigmoid pre-applied).
  recurrent_state : [max_slots, num_value_heads, key_head_dim, value_head_dim]
    Mutable recurrent-state pool. The kernel reads/writes slot slot_idx[batch_item] in place; all other slots are untouched. Pool dtype is independent of the working dtype, so the caller can keep per-token tensors at float32 while storing the pool at the model's native dtype (typically bfloat16).
  slot_idx : [batch_size] uint32
    Pool slot index for each batch item.
  input_row_offsets : [batch_size + 1] uint32
    Ragged offsets: sequence b spans flat indices [input_row_offsets[b], input_row_offsets[b+1]).
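The derived dimensions and the ragged offsets can be illustrated with a short Python sketch. The helper names `derived_dims` and `sequence_spans` are hypothetical and exist only for this example.

```python
def derived_dims(num_key_heads, key_head_dim, num_value_heads, value_head_dim):
    """Return (key_dim, value_dim, conv_dim) as defined in the input shapes."""
    key_dim = num_key_heads * key_head_dim
    value_dim = num_value_heads * value_head_dim
    conv_dim = key_dim * 2 + value_dim
    return key_dim, value_dim, conv_dim


def sequence_spans(input_row_offsets):
    """Return the half-open flat token range [start, end) of each sequence."""
    batch_size = len(input_row_offsets) - 1
    return [(input_row_offsets[b], input_row_offsets[b + 1])
            for b in range(batch_size)]
```

With offsets [0, 3, 3, 7] the batch holds three sequences of lengths 3, 0 and 4, and total_seq_len is the last offset, 7.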

Outputs:
  recurrence_output : [total_seq_len, value_dim] float32
    Flat output for all tokens. Indexed as output[flat_t, value_head_idx * value_head_dim + vd_element_idx].

(recurrent_state is mutated in place; there is no separate state-out tensor.)

Thread mapping (GPU)

total_threads = batch_size * num_value_heads * value_head_dim
Grid : 1-D, ceildiv(total_threads, RECURRENCE_BLOCK_SIZE) blocks
Block : RECURRENCE_BLOCK_SIZE threads

Thread decomposition:
  flat_thread_idx = block_idx * RECURRENCE_BLOCK_SIZE + thread_idx
  batch_item_idx = flat_thread_idx // (num_value_heads * value_head_dim)
  value_head_idx = (flat_thread_idx % (num_value_heads * value_head_dim)) // value_head_dim
  vd_element_idx = flat_thread_idx % value_head_dim
  key_head_idx = value_head_idx // heads_expansion_ratio
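The grid sizing and index decomposition can be checked with a small Python sketch. The function names are hypothetical, and the block size of 64 used in the usage note is an arbitrary example value, since the actual RECURRENCE_BLOCK_SIZE is not given in this section.

```python
def ceildiv(a, b):
    """Integer division rounded up."""
    return -(-a // b)


def grid_blocks(batch_size, num_value_heads, value_head_dim, block_size):
    """Number of 1-D blocks needed to cover every (batch, head, vd) thread."""
    total_threads = batch_size * num_value_heads * value_head_dim
    return ceildiv(total_threads, block_size)


def decompose_thread(flat_thread_idx, num_value_heads, value_head_dim,
                     heads_expansion_ratio):
    """Split a flat thread index into (batch, value head, vd element, key head)."""
    threads_per_item = num_value_heads * value_head_dim
    batch_item_idx = flat_thread_idx // threads_per_item
    value_head_idx = (flat_thread_idx % threads_per_item) // value_head_dim
    vd_element_idx = flat_thread_idx % value_head_dim
    key_head_idx = value_head_idx // heads_expansion_ratio
    return batch_item_idx, value_head_idx, vd_element_idx, key_head_idx
```

For instance, with 4 value heads, value_head_dim of 8 and an expansion ratio of 2, flat thread 45 maps to batch item 1, value head 1, value element 5 and key head 0. A batch of 3 gives 96 threads, which needs 2 blocks of 64. Threads whose flat index is at or beyond total_threads presumably exit early, though that guard is not described here.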

Each thread owns the KD-element column state_col[0..KD-1] = recurrent_state[slot_idx[batch_item], value_head, 0..KD-1, vd_element] in registers and iterates over its sequence sequentially.
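A minimal Python sketch of one thread's lifetime follows: load the column from the pool, walk the sequence token by token, then store the column back. The function name, the nested-list pool and the `tokens` tuple layout are assumptions made for the example. Writing the column back once after the last token is also an assumption; the source only says the slot is read and written in place.

```python
def run_thread(recurrent_state, slot, value_head, vd_element, tokens):
    """Run the recurrence for one (slot, value_head, vd_element) column.

    recurrent_state is a nested list indexed [slot][value_head][k][vd].
    tokens yields (key_normalised, query_scaled, value, decay, beta) in
    sequence order. Returns the per-token outputs for this column.
    """
    head_state = recurrent_state[slot][value_head]
    kd = len(head_state)
    # Load the KD-element column; on the GPU this would sit in registers.
    state_col = [head_state[k][vd_element] for k in range(kd)]
    outputs = []
    for key_n, query_s, value, decay, beta in tokens:
        state_col = [decay * s for s in state_col]                    # step 1
        kv_memory = sum(s * k for s, k in zip(state_col, key_n))      # step 2
        delta = beta * (value - kv_memory)                            # step 3
        state_col = [s + k * delta for s, k in zip(state_col, key_n)] # step 4
        outputs.append(sum(s * q for s, q in zip(state_col, query_s)))  # step 5
    # Store the final column back into the same pool slot (assumed timing).
    for k in range(kd):
        head_state[k][vd_element] = state_col[k]
    return outputs
```

Because each thread touches only its own column of one slot, the other slots and the other value elements of the same head are left unchanged.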

Functions