IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo function

gated_delta_recurrence_fwd_gpu

gated_delta_recurrence_fwd_gpu[work_dtype: DType, state_dtype: DType, KEY_HEAD_DIM: Int, VALUE_HEAD_DIM: Int, RECURRENCE_BLOCK_SIZE: Int, recurrence_output_LT: TensorLayout, qkv_conv_output_LT: TensorLayout, decay_per_token_LT: TensorLayout, beta_per_token_LT: TensorLayout, recurrent_state_LT: TensorLayout, slot_idx_LT: TensorLayout, input_row_offsets_LT: TensorLayout](total_threads: Int, batch_size: Int, total_seq_len: Int, num_value_heads: Int, num_key_heads: Int, key_dim: Int, value_dim: Int, conv_dim: Int, recurrence_output: TileTensor[work_dtype, recurrence_output_LT, MutExternalOrigin], recurrent_state: TileTensor[state_dtype, recurrent_state_LT, MutExternalOrigin], slot_idx: TileTensor[DType.uint32, slot_idx_LT, MutExternalOrigin], qkv_conv_output: TileTensor[work_dtype, qkv_conv_output_LT, MutExternalOrigin], decay_per_token: TileTensor[work_dtype, decay_per_token_LT, MutExternalOrigin], beta_per_token: TileTensor[work_dtype, beta_per_token_LT, MutExternalOrigin], input_row_offsets: TileTensor[DType.uint32, input_row_offsets_LT, MutExternalOrigin], qkv_conv_output_seqlen_stride: UInt32, qkv_conv_output_channel_stride: UInt32, per_token_seqlen_stride: UInt32, per_token_head_stride: UInt32, recurrent_state_slot_stride: UInt32, recurrent_state_value_head_stride: UInt32, recurrent_state_key_dim_stride: UInt32, recurrent_state_value_dim_stride: UInt32, recurrence_output_seqlen_stride: UInt32, recurrence_output_valuedim_stride: UInt32)
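The signature passes every stride as an explicit UInt32 argument instead of deriving it from the layout parameters. The docs do not say whether these strides count elements or bytes. Assuming they count elements, the flat offset of a state or input element would be formed as in this Python sketch. The helper names state_offset and qkv_offset are hypothetical and are not part of the Mojo API.

```python
# Hypothetical helpers, not the kernel's code. They assume every stride
# argument is measured in elements (not bytes).

def state_offset(slot, value_head, kd, vd,
                 slot_stride, value_head_stride, key_dim_stride, value_dim_stride):
    """Flat element offset of state[slot, value_head, kd, vd] in the pool."""
    return (slot * slot_stride
            + value_head * value_head_stride
            + kd * key_dim_stride
            + vd * value_dim_stride)


def qkv_offset(token, channel, seqlen_stride, channel_stride):
    """Flat element offset of qkv_conv_output[token, channel]."""
    return token * seqlen_stride + channel * channel_stride


# Contiguous (row-major) strides for a pool of shape [max_slots, nv, KD, VD].
nv, KD, VD = 4, 8, 16
strides = (nv * KD * VD, KD * VD, VD, 1)
off = state_offset(2, 1, 3, 5, *strides)  # 2*512 + 1*128 + 3*16 + 5
```

With contiguous strides this reduces to ordinary row-major indexing. Passing the strides separately presumably lets callers hand the kernel non-contiguous views without a copy.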

GPU kernel: slot-indexed gated delta rule recurrence.
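For reference, the gated delta rule is commonly written as below. Here S_t is the KD × VD state of one (slot, value head), α_t is the decay gate, β_t is the write strength, and s_t^{(j)} is column j of S_t. Mapping α_t to decay_per_token and β_t to beta_per_token is an assumption, as is this exact form. Implementations differ, for example in whether the gate is stored in log space and in how q and k are normalized or scaled.

```latex
S_t = \alpha_t \left(I - \beta_t\, k_t k_t^{\top}\right) S_{t-1} + \beta_t\, k_t v_t^{\top},
\qquad o_t = S_t^{\top} q_t

s_t^{(j)} = \alpha_t\, s_{t-1}^{(j)}
          + \beta_t\, k_t \left(v_{t,j} - \alpha_t\, k_t^{\top} s_{t-1}^{(j)}\right),
\qquad o_{t,j} = q_t^{\top} s_t^{(j)}
```

The per-column form shows why the work can be split by value column. Column j depends only on v_{t,j} and on q_t, k_t, α_t, and β_t, which every column of the head shares, so no state has to be exchanged between columns.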

The recurrent state lives in a single mutable pool of shape [max_slots, nv, KD, VD], where nv is num_value_heads, KD is KEY_HEAD_DIM, and VD is VALUE_HEAD_DIM. For batch item batch_item_idx, this kernel reads and writes pool slot slot_idx[batch_item_idx] in place. This avoids the gather/scatter copies that the host-side state cache previously performed. The kernel runs one thread per (batch_item, value_head, vd_element) triple, and each thread keeps its KD-element state column entirely in registers.
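To make the thread mapping and slot indexing concrete, here is a hypothetical pure-Python reference of what one thread might compute. It is a sketch under several assumptions, not the kernel itself:

- qkv_conv_output rows are laid out as [q | k | v].
- Value head h reads key head h // (num_value_heads // num_key_heads).
- vd_element is the fastest-varying part of the thread id.
- decay_per_token holds a multiplicative gate rather than a log-gate.
- No q/k normalization or scaling is applied.

```python
# Hypothetical CPU reference for the per-thread work, not the Mojo kernel.
# Assumptions: qkv rows are [q | k | v] with q, k of width num_key_heads*KD;
# value head h uses key head h // (num_value_heads // num_key_heads);
# decay holds the multiplicative gate alpha; no q/k normalization/scaling.

def decompose_thread(tid, num_value_heads, VD):
    """Map a flat thread id to (batch_item, value_head, vd_element),
    assuming vd_element varies fastest."""
    vd = tid % VD
    head = (tid // VD) % num_value_heads
    batch_item = tid // (VD * num_value_heads)
    return batch_item, head, vd


def run_thread(tid, *, pool, slot_idx, row_offsets, qkv, decay, beta, out,
               num_value_heads, num_key_heads, KD, VD):
    b, h, j = decompose_thread(tid, num_value_heads, VD)
    if b >= len(slot_idx):
        return  # thread past the end of the batch
    slot = slot_idx[b]
    kh = h // (num_value_heads // num_key_heads)
    key_dim = num_key_heads * KD
    # Load this thread's KD-element state column (the kernel keeps it in registers).
    s = [pool[slot][h][i][j] for i in range(KD)]
    # Tokens of batch item b occupy rows [row_offsets[b], row_offsets[b + 1]).
    for t in range(row_offsets[b], row_offsets[b + 1]):
        row = qkv[t]
        q = row[kh * KD:(kh + 1) * KD]
        k = row[key_dim + kh * KD:key_dim + (kh + 1) * KD]
        v_j = row[2 * key_dim + h * VD + j]
        alpha, bt = decay[t][h], beta[t][h]
        s = [alpha * x for x in s]                          # apply the decay gate
        err = v_j - sum(ki * si for ki, si in zip(k, s))    # delta-rule error
        s = [si + bt * err * ki for si, ki in zip(s, k)]    # rank-1 correction
        out[t][h * VD + j] = sum(qi * si for qi, si in zip(q, s))
    # Write the column back to the same pool slot: no gather/scatter copy.
    for i in range(KD):
        pool[slot][h][i][j] = s[i]


# Usage: one batch item mapped to pool slot 1, two tokens, KD = VD = 1.
KD = VD = 1
pool = [[[[0.0]]], [[[0.0]]]]                 # [max_slots=2, nv=1, KD, VD]
qkv = [[1.0, 1.0, 2.0], [1.0, 1.0, 4.0]]      # rows are [q | k | v]
decay = [[1.0], [0.5]]                        # [total_seq_len, nv]
beta = [[1.0], [0.5]]
out = [[0.0], [0.0]]                          # [total_seq_len, nv * VD]
for tid in range(1 * 1 * VD):                 # batch_size * nv * VD threads
    run_thread(tid, pool=pool, slot_idx=[1], row_offsets=[0, 2], qkv=qkv,
               decay=decay, beta=beta, out=out,
               num_value_heads=1, num_key_heads=1, KD=KD, VD=VD)
```

Each thread touches only column j of its own slot, so threads never communicate. The column is loaded once and stored once, and everything in between runs on thread-local values, which is what makes keeping it in registers practical.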