Mojo module
gated_delta
Gated DeltaNet recurrence kernel for Qwen3.5 (Pass 2 of two-pass prefill).
Implements the gated delta rule recurrence over a ragged (variable-length) batch of sequences. This is Pass 2 of the prefill path; it consumes the conv1d output produced by Pass 1 (gated_delta_conv1d_fwd).
The five steps of the gated delta rule at each token t for value-dim element vd_i and value head h are:
1. Apply per-head scalar decay to the entire state column: state_col[k] ← decay[t,h] * state_col[k] for k in [0, KD)
2. Compute kv_memory by taking the dot product of the decayed state column with the L2-normalised key vector (summing over the key_dim axis): kv_memory_vd_i = Σ_k state_col[k] * key_normalised[t,h,k]
3. Compute the delta correction using beta and the value residual: delta_correction_vd_i = beta[t,h] * (value[t,h,vd_i] - kv_memory_vd_i)
4. Outer-product update of the state column with the key and delta: state_col[k] ← state_col[k] + key_normalised[t,h,k] * delta_correction_vd_i
5. Read out the output by dotting the updated state with the scaled, L2-normalised query vector: output[t, h*VD + vd_i] = Σ_k state_col[k] * query_scaled[t,h,k]
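The following minimal pure-Python sketch applies one token's update to a single state column, following the five steps above. It is a scalar reference for checking the math, not the Mojo kernel. The function name `gated_delta_step` and its argument order are hypothetical, and the sketch assumes the key and query have already been normalised and scaled.

```python
def gated_delta_step(state_col, decay, beta, value, key_norm, query_scaled):
    """One token of the gated delta rule for one (value head h, element vd_i) column.

    state_col    -- list of KD floats, updated in place
    decay, beta  -- scalars decay[t,h] and beta[t,h]
    value        -- scalar value[t,h,vd_i]
    key_norm     -- KD floats, L2-normalised key of the matching key head
    query_scaled -- KD floats, L2-normalised and scaled query
    Returns output[t, h*VD + vd_i].
    """
    kd = len(state_col)
    # Step 1: decay every entry of the column by the per-head scalar.
    for k in range(kd):
        state_col[k] *= decay
    # Step 2: kv_memory = dot(decayed column, normalised key).
    kv_memory = sum(state_col[k] * key_norm[k] for k in range(kd))
    # Step 3: scalar delta correction from beta and the value residual.
    delta_correction = beta * (value - kv_memory)
    # Step 4: add key * delta to the column (one column of the outer product).
    for k in range(kd):
        state_col[k] += key_norm[k] * delta_correction
    # Step 5: read out with the scaled query.
    return sum(state_col[k] * query_scaled[k] for k in range(kd))
```

With beta = 1 and a unit-norm key, the updated column's projection onto the key equals the value exactly, because kv_memory + |k|^2 * (value - kv_memory) = value. A beta below 1 moves the stored association only part of the way toward the new value.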
Steps 1, 2, 4 and 5 each loop over the key_dim axis (k = 0..KD-1), and KD is a compile-time constant. This allows these inner loops to be fully unrolled and the KD-element state column to live in GPU registers, eliminating shared-memory traffic.
L2 normalisation and Q scaling are fused into the kernel body. The raw Q/K/V vectors are read directly from the conv1d output (qkv_conv_output), with the channel layout: Q at [0..key_dim), K at [key_dim..2*key_dim), V at [2*key_dim..).
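The fused Q/K preprocessing can be sketched in Python for a single token row. The text above does not state two details, so this sketch assumes them. First, heads are packed head-major inside each Q/K/V region, so head i occupies [i*head_dim, (i+1)*head_dim). Second, the query scale is 1/sqrt(key_head_dim), with a small epsilon added inside the L2 norm. `split_qkv_row` is a hypothetical helper name.

```python
import math

def split_qkv_row(row, num_key_heads, key_head_dim, num_value_heads, value_head_dim,
                  key_head_idx, value_head_idx, eps=1e-6):
    """Return (query_scaled, key_normalised, value) for one token row of qkv_conv_output."""
    key_dim = num_key_heads * key_head_dim
    value_dim = num_value_heads * value_head_dim
    assert len(row) == 2 * key_dim + value_dim  # conv_dim

    # Assumption: heads are packed head-major inside each Q/K/V region.
    q_start = key_head_idx * key_head_dim
    k_start = key_dim + key_head_idx * key_head_dim
    v_start = 2 * key_dim + value_head_idx * value_head_dim
    q = row[q_start:q_start + key_head_dim]
    k = row[k_start:k_start + key_head_dim]
    v = row[v_start:v_start + value_head_dim]

    def l2_normalise(x):
        # Assumption: epsilon is added to the squared norm for numerical safety.
        norm = math.sqrt(sum(e * e for e in x) + eps)
        return [e / norm for e in x]

    scale = 1.0 / math.sqrt(key_head_dim)  # Assumption: attention-style query scale.
    query_scaled = [scale * e for e in l2_normalise(q)]
    return query_scaled, l2_normalise(k), v
```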
GQA (grouped query attention) is handled by computing the key head index as: key_head_idx = value_head_idx // heads_expansion_ratio
where heads_expansion_ratio = num_value_heads / num_key_heads is a runtime integer, so no compile-time specialisation per model is required.
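Here is a small illustration of the GQA mapping, using hypothetical head counts (8 value heads sharing 2 key heads). The divisibility check is an assumption of this sketch. Consecutive value heads map to the same key head.

```python
def key_head_for(value_head_idx, num_value_heads, num_key_heads):
    """GQA mapping from a value head to the key head whose Q/K it reads."""
    # Assumption of this sketch: value heads divide evenly among key heads.
    if num_value_heads % num_key_heads != 0:
        raise ValueError("num_value_heads must be a multiple of num_key_heads")
    heads_expansion_ratio = num_value_heads // num_key_heads  # runtime integer
    return value_head_idx // heads_expansion_ratio

# Hypothetical configuration: 8 value heads share 2 key heads.
mapping = [key_head_for(h, 8, 2) for h in range(8)]
```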
Tensor shapes
Inputs:
  qkv_conv_output : [total_seq_len, conv_dim] float32
      Conv1d output from Pass 1. Channel layout:
        Q: channels [0, key_dim)
        K: channels [key_dim, 2*key_dim)
        V: channels [2*key_dim, 2*key_dim + value_dim)
      where key_dim   = num_key_heads * key_head_dim
            value_dim = num_value_heads * value_head_dim
            conv_dim  = key_dim * 2 + value_dim
  decay_per_token : [total_seq_len, num_value_heads] float32
      Per-token, per-head scalar decay factor (exp(-softplus) pre-applied).
  beta_per_token : [total_seq_len, num_value_heads] float32
      Per-token, per-head beta gate (sigmoid pre-applied).
  recurrent_state : [max_slots, num_value_heads, key_head_dim, value_head_dim]
      Mutable recurrent-state pool. The kernel reads/writes slot
      slot_idx[batch_item] in place; all other slots are untouched.
      Pool dtype is independent of the working dtype, so the caller can
      keep per-token tensors at float32 while storing the pool at the
      model's native dtype (typically bfloat16).
  slot_idx : [batch_size] uint32
      Pool slot index for each batch item.
  input_row_offsets : [batch_size + 1] uint32
      Ragged offsets: sequence b spans flat indices
      [input_row_offsets[b], input_row_offsets[b+1]).
Outputs:
  recurrence_output : [total_seq_len, value_dim] float32
      Flat output for all tokens. Indexed as
      output[flat_t, value_head_idx * value_head_dim + vd_element_idx].
  (recurrent_state is mutated in place; there is no separate state-out tensor.)
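The shape relations and ragged indexing above can be exercised with a few small Python helpers. The names `conv_dims`, `iter_ragged` and `output_column` are illustrative only and are not part of the module.

```python
def conv_dims(num_key_heads, key_head_dim, num_value_heads, value_head_dim):
    """Return (key_dim, value_dim, conv_dim) as defined for qkv_conv_output."""
    key_dim = num_key_heads * key_head_dim
    value_dim = num_value_heads * value_head_dim
    return key_dim, value_dim, key_dim * 2 + value_dim

def iter_ragged(input_row_offsets):
    """Yield (batch_item, flat_t) for every token, sequence by sequence."""
    for b in range(len(input_row_offsets) - 1):
        for flat_t in range(input_row_offsets[b], input_row_offsets[b + 1]):
            yield b, flat_t

def output_column(value_head_idx, vd_element_idx, value_head_dim):
    """Column of recurrence_output written for (value head, value element)."""
    return value_head_idx * value_head_dim + vd_element_idx
```

For example, the offsets [0, 2, 5] describe two sequences of lengths 2 and 3. For these offsets, `iter_ragged` yields (0, 0), (0, 1), (1, 2), (1, 3), (1, 4).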
Thread mapping (GPU)
total_threads = batch_size * num_value_heads * value_head_dim
Grid  : ceildiv(total_threads, RECURRENCE_BLOCK_SIZE) blocks (1-D)
Block : RECURRENCE_BLOCK_SIZE threads (1-D)

Thread decomposition:
  flat_thread_idx = block_idx * RECURRENCE_BLOCK_SIZE + thread_idx
  batch_item_idx  = flat_thread_idx // (num_value_heads * value_head_dim)
  value_head_idx  = (flat_thread_idx % (num_value_heads * value_head_dim)) // value_head_dim
  vd_element_idx  = flat_thread_idx % value_head_dim
  key_head_idx    = value_head_idx // heads_expansion_ratio
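The grid sizing and thread decomposition can be checked in Python. The sketch below uses hypothetical helpers `ceildiv` and `decompose` to map a flat thread index to its (batch item, value head, value element, key head) coordinates, and checks that every real thread gets a distinct coordinate. When total_threads is not a multiple of the block size, the last block contains spare threads. This sketch assumes those threads do no work, which the text above does not state.

```python
def ceildiv(a, b):
    """Integer ceiling division, as used for the grid size."""
    return -(-a // b)

def decompose(flat_thread_idx, num_value_heads, value_head_dim, heads_expansion_ratio):
    """Map a flat thread index to (batch_item_idx, value_head_idx, vd_element_idx, key_head_idx)."""
    per_batch_item = num_value_heads * value_head_dim
    batch_item_idx = flat_thread_idx // per_batch_item
    value_head_idx = (flat_thread_idx % per_batch_item) // value_head_dim
    vd_element_idx = flat_thread_idx % value_head_dim
    key_head_idx = value_head_idx // heads_expansion_ratio
    return batch_item_idx, value_head_idx, vd_element_idx, key_head_idx

# Hypothetical sizes: 3 sequences, 4 value heads, value_head_dim 5, 2 key heads, block size 32.
batch_size, num_value_heads, value_head_dim, ratio, block = 3, 4, 5, 2, 32
total_threads = batch_size * num_value_heads * value_head_dim   # 60
num_blocks = ceildiv(total_threads, block)                      # 2 blocks = 64 threads
coords = [decompose(i, num_value_heads, value_head_dim, ratio) for i in range(total_threads)]
```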
Each thread keeps its KD-element column state_col[0..KD-1] = recurrent_state[slot_idx[batch_item_idx], value_head_idx, 0..KD-1, vd_element_idx] in registers and processes the tokens of its sequence sequentially.
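Putting the pieces together, a slow pure-Python reference of the whole Pass 2 recurrence might look like the following. Each iteration of the (b, h, vd) loops plays the role of one GPU thread, and the token loop runs sequentially, as described above. The sketch carries several assumptions that the text does not confirm: head-major packing, a query scale of 1/sqrt(key_head_dim), an epsilon inside the L2 norm, and a float state pool with no bfloat16 conversion. `gated_delta_reference` is a hypothetical name.

```python
import math

def gated_delta_reference(qkv, decay, beta, state, slot_idx, input_row_offsets,
                          num_key_heads, key_head_dim, num_value_heads, value_head_dim,
                          eps=1e-6):
    """Slow reference for Pass 2. Mutates state[slot][h][k][vd] in place and
    returns recurrence_output as rows of length num_value_heads * value_head_dim."""
    KD, VD = key_head_dim, value_head_dim
    key_dim = num_key_heads * KD
    ratio = num_value_heads // num_key_heads  # heads_expansion_ratio
    total_seq_len = input_row_offsets[-1]
    out = [[0.0] * (num_value_heads * VD) for _ in range(total_seq_len)]

    def l2_normalise(x):
        norm = math.sqrt(sum(e * e for e in x) + eps)  # assumed epsilon placement
        return [e / norm for e in x]

    # Each (b, h, vd) iteration corresponds to one GPU thread.
    for b in range(len(input_row_offsets) - 1):
        slot = slot_idx[b]
        for h in range(num_value_heads):
            kh = h // ratio
            for vd in range(VD):
                # Load this thread's KD-element state column.
                col = [state[slot][h][k][vd] for k in range(KD)]
                for t in range(input_row_offsets[b], input_row_offsets[b + 1]):
                    row = qkv[t]
                    q = l2_normalise(row[kh * KD:(kh + 1) * KD])
                    q = [e / math.sqrt(KD) for e in q]  # assumed query scale
                    key = l2_normalise(row[key_dim + kh * KD:key_dim + (kh + 1) * KD])
                    v = row[2 * key_dim + h * VD + vd]
                    col = [decay[t][h] * c for c in col]                 # step 1
                    kv_memory = sum(c * k for c, k in zip(col, key))     # step 2
                    delta = beta[t][h] * (v - kv_memory)                 # step 3
                    col = [c + k * delta for c, k in zip(col, key)]      # step 4
                    out[t][h * VD + vd] = sum(c * qq for c, qq in zip(col, q))  # step 5
                # Write the column back to the same slot; other slots are untouched.
                for k in range(KD):
                    state[slot][h][k][vd] = col[k]
    return out
```

The per-column split works because column vd_i of the state only ever reads value[t,h,vd_i] and its own kv_memory_vd_i, so no thread needs another thread's data. For simplicity, this reference recomputes the Q/K normalisation inside every column's loop, which keeps each column's computation fully independent.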
Functions
- gated_delta_recurrence_fwd_gpu: GPU kernel: slot-indexed gated delta rule recurrence.