Mojo function
lmcache_onload
lmcache_onload[dtype: DType, page_size: Int, num_kv_heads: Int, head_dim: Int, kv_dim: Int, target: StringSlice[StaticConstantOrigin] = StringSlice("gpu")](paged_cache: LayoutTensor[dtype, Layout.row_major[6](), MutAnyOrigin], input: LayoutTensor[dtype, Layout.row_major[4](), MutAnyOrigin], slot_mapping: LayoutTensor[DType.int64, Layout.row_major[1](), MutAnyOrigin], start_token: Int, end_token: Int, ctx: DeviceContext)
Onloads KV cache data from an external contiguous format into the MAX paged format.
Parameters:
- dtype (DType): Data type of the cache.
- page_size (Int): Number of tokens per page in the paged cache.
- num_kv_heads (Int): Number of KV attention heads.
- head_dim (Int): Dimension of each attention head.
- kv_dim (Int): KV dimension (2 for standard K/V, 1 for MLA).
- target (StringSlice[StaticConstantOrigin]): Target device ("gpu" or "cpu").
Args:
- paged_cache (LayoutTensor[dtype, Layout.row_major[6](), MutAnyOrigin]): Destination tensor [total_num_blocks, kv_dim, num_layers, page_size, num_heads, head_dim].
- input (LayoutTensor[dtype, Layout.row_major[4](), MutAnyOrigin]): Source tensor [kv_dim, num_layers, num_tokens, hidden_dim].
- slot_mapping (LayoutTensor[DType.int64, Layout.row_major[1](), MutAnyOrigin]): Token to slot mapping [total_tokens].
- start_token (Int): Starting token index in slot_mapping.
- end_token (Int): Ending token index (exclusive) in slot_mapping.
- ctx (DeviceContext): Device context for kernel launch.
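The slot_mapping argument flattens each token's destination into a single index. A minimal Python sketch of the per-token decoding such a kernel would typically perform, under the assumption (not stated above) that a slot follows the common paged-KV convention slot = block_index * page_size + in_page_offset:

```python
def slot_to_block(slot: int, page_size: int) -> tuple[int, int]:
    """Decode a flat slot index into (block_index, in_page_offset).

    Assumes the conventional paged-KV layout where
    slot = block_index * page_size + in_page_offset.
    """
    return slot // page_size, slot % page_size


def tokens_to_copy(slot_mapping: list[int], start_token: int,
                   end_token: int, page_size: int) -> list[tuple[int, int]]:
    """For tokens in [start_token, end_token), list the (block, offset)
    destinations in the paged cache — mirroring how start_token/end_token
    select a contiguous range of slot_mapping entries."""
    return [slot_to_block(slot_mapping[i], page_size)
            for i in range(start_token, end_token)]


# Example: with 16-token pages, slot 35 lands in block 2 at offset 3.
print(slot_to_block(35, 16))
```

This is an illustrative sketch of the indexing arithmetic, not the kernel itself; the actual copy also fans each token out across kv_dim, num_layers, and the per-head dimensions of paged_cache.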