Mojo function
gated_delta_conv1d_fwd_gpu
gated_delta_conv1d_fwd_gpu[
    work_dtype: DType,
    state_dtype: DType,
    KERNEL_SIZE: Int,
    CONV1D_BLOCK_DIM: Int,
    qkv_input_ragged_LT: TensorLayout,
    conv_weight_LT: TensorLayout,
    conv_state_LT: TensorLayout,
    slot_idx_LT: TensorLayout,
    input_row_offsets_LT: TensorLayout,
    conv_output_ragged_LT: TensorLayout
](
    batch_size: Int,
    total_seq_len: Int,
    conv_dim: Int,
    qkv_input_ragged: TileTensor[work_dtype, qkv_input_ragged_LT, MutExternalOrigin],
    conv_weight: TileTensor[work_dtype, conv_weight_LT, MutExternalOrigin],
    conv_state: TileTensor[state_dtype, conv_state_LT, MutExternalOrigin],
    slot_idx: TileTensor[DType.uint32, slot_idx_LT, MutExternalOrigin],
    input_row_offsets: TileTensor[DType.uint32, input_row_offsets_LT, MutExternalOrigin],
    conv_output_ragged: TileTensor[work_dtype, conv_output_ragged_LT, MutExternalOrigin],
    qkv_input_seqlen_stride: UInt32,
    qkv_input_channel_stride: UInt32,
    conv_weight_channel_stride: UInt32,
    conv_weight_offset_stride: UInt32,
    conv_state_pool_stride: UInt32,
    conv_state_channel_stride: UInt32,
    conv_state_window_stride: UInt32,
    conv_output_seqlen_stride: UInt32,
    conv_output_channel_stride: UInt32
)
GPU kernel: slot-indexed causal depthwise conv1d over a ragged batch.
The conv state lives in a single mutable pool of shape
[max_slots, conv_dim, K-1], where K is KERNEL_SIZE. For batch item
batch_item_idx, the kernel reads and writes slot slot_idx[batch_item_idx].
Reads of the old window during look-back precede all writes of the new
window, so in-place mutation is safe. One thread handles one (batch_item,
conv_channel) pair for the entire sequence, and the channel's K weights
stay in registers across the token loop.
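The semantics described above can be sketched as a plain-Python CPU reference. This is a hypothetical illustration, not the Modular kernel: the function name causal_conv1d_ragged_ref is made up, nested lists stand in for the strided TileTensor arguments, and the sketch assumes that input_row_offsets holds batch_size + 1 prefix offsets, that each channel's K weights are ordered so the last one multiplies the current token, that each state window stores the previous K-1 inputs oldest first, and that no bias or activation is applied. The outer loop over (batch item, channel) pairs mirrors the one-thread-per-pair mapping.

```python
def causal_conv1d_ragged_ref(x, weight, state_pool, slot_idx, row_offsets):
    """Hypothetical CPU reference for a slot-indexed causal depthwise conv1d.

    Assumed layouts (not taken from the real kernel):
      x:           total_seq_len rows, each a list of conv_dim values
      weight:      conv_dim rows of K taps; weight[c][K-1] hits the current token
      state_pool:  max_slots x conv_dim x (K-1) history, oldest first (mutated)
      slot_idx:    one slot id per batch item
      row_offsets: batch_size + 1 prefix offsets into the rows of x
    """
    conv_dim = len(weight)
    K = len(weight[0])
    out = [[0.0] * conv_dim for _ in x]
    for b in range(len(slot_idx)):
        start, end = row_offsets[b], row_offsets[b + 1]
        slot = slot_idx[b]
        for c in range(conv_dim):  # one GPU thread per (b, c) in the real kernel
            w = weight[c]  # the K taps the kernel keeps in registers
            # Look-back: read the old window before anything is written.
            window = list(state_pool[slot][c])  # last K-1 inputs, oldest first
            for t in range(start, end):
                window.append(x[t][c])  # window now holds K inputs
                out[t][c] = sum(w[k] * window[k] for k in range(K))
                window.pop(0)  # drop the oldest input, back to K-1
            # Write the new window only after the whole sequence is consumed.
            state_pool[slot][c][:] = window
    return out


# Hypothetical example: K = 3, conv_dim = 1, two batch items, two slots.
weight = [[0.25, 0.5, 1.0]]
state_pool = [[[0.0, 0.0]], [[1.0, 2.0]]]  # slot 0, slot 1
x = [[1.0], [2.0], [3.0], [10.0]]  # rows 0-2: item 0, row 3: item 1
out = causal_conv1d_ragged_ref(x, weight, state_pool, slot_idx=[1, 0], row_offsets=[0, 3, 4])
print(out)         # [[2.25], [3.0], [4.25], [10.0]]
print(state_pool)  # [[[0.0, 10.0]], [[2.0, 3.0]]]
```

In this sketch, splitting one sequence across two calls that share a slot produces the same outputs as a single call over the whole sequence, because the persisted K-1 window carries the look-back context between calls.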