Mojo function
kv_cache_2m_iadd_dispatch
kv_cache_2m_iadd_dispatch[dtype: DType, collection_t: KVCollectionT, //, target: StringSlice[StaticConstantOrigin]](kv: LayoutTensor[dtype, kv.layout, kv.origin, element_layout=kv.element_layout, layout_int_type=kv.layout_int_type, linear_idx_type=kv.linear_idx_type, masked=kv.masked, alignment=kv.alignment], cache: collection_t, input_row_offsets: LayoutTensor[DType.uint32, input_row_offsets.layout, input_row_offsets.origin, element_layout=input_row_offsets.element_layout, layout_int_type=input_row_offsets.layout_int_type, linear_idx_type=input_row_offsets.linear_idx_type, masked=input_row_offsets.masked, alignment=input_row_offsets.alignment], lora_end_idx: LayoutTensor[DType.int64, lora_end_idx.layout, lora_end_idx.origin, element_layout=lora_end_idx.element_layout, layout_int_type=lora_end_idx.layout_int_type, linear_idx_type=lora_end_idx.linear_idx_type, masked=lora_end_idx.masked, alignment=lora_end_idx.alignment], batch_seq_len: LayoutTensor[DType.int64, batch_seq_len.layout, batch_seq_len.origin, element_layout=batch_seq_len.element_layout, layout_int_type=batch_seq_len.layout_int_type, linear_idx_type=batch_seq_len.linear_idx_type, masked=batch_seq_len.masked, alignment=batch_seq_len.alignment], layer_idx: UInt32, ctx: Optional[DeviceContext])
In-place add to the paged KV cache with a concatenated K/V layout. This kernel is only used for LoRA.
Performs an in-place addition of new key/value projections to the paged KV cache.
The input tensor kv uses a "2m" layout where keys and values are concatenated:
rows [0, m) contain keys and rows [m, 2m) contain values, where m is the number
of LoRA tokens. Because LoRA applies to only a subset of the tokens in the batch,
m may be smaller than the total batch token count; lora_end_idx marks where that
subset ends and is used to index into the K or V portion of the tensor.
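The "2m" row split can be sketched in plain Python. This is a toy illustration only: m, head_dim, and the flat list caches are hypothetical stand-ins, not the kernel's actual paged cache layout or LayoutTensor types.

```python
# Hypothetical sketch of the "2m" concatenated K/V layout.
# m = number of LoRA tokens; head_dim is illustrative.
m, head_dim = 3, 4

# kv has 2*m rows: key rows first, then value rows.
kv = [[float(r * head_dim + c) for c in range(head_dim)] for r in range(2 * m)]

k_new = kv[:m]       # rows [0, m): key projections
v_new = kv[m:]       # rows [m, 2m): value projections

# Toy K/V caches, accumulated in place (the real kernel adds into paged cache blocks).
k_cache = [[0.0] * head_dim for _ in range(m)]
v_cache = [[0.0] * head_dim for _ in range(m)]
for i in range(m):
    for j in range(head_dim):
        k_cache[i][j] += k_new[i][j]
        v_cache[i][j] += v_new[i][j]
```

The split at row m is what lora_end_idx-style indexing selects between: offsets below m land in the key half, offsets at or above m land in the value half.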