Mojo function
lmcache_onload
lmcache_onload[dtype: DType, page_size: Int, num_kv_heads: Int, head_dim: Int, kv_dim: Int, target: StringSlice[StaticConstantOrigin] = "gpu"](paged_cache: LayoutTensor[dtype, Layout.row_major[6](), MutAnyOrigin], input: LayoutTensor[dtype, Layout.row_major[4](), MutAnyOrigin], slot_mapping: LayoutTensor[DType.int64, Layout.row_major[1](), MutAnyOrigin], start_token: Int, end_token: Int, ctx: DeviceContext)
Onloads KV cache data from an external contiguous layout into the MAX paged KV cache layout.
Parameters:
- dtype (DType): Data type of the cache.
- page_size (Int): Number of tokens per page in the paged cache.
- num_kv_heads (Int): Number of KV attention heads.
- head_dim (Int): Dimension of each attention head.
- kv_dim (Int): KV dimension (2 for standard K/V, 1 for MLA).
- target (StringSlice): Target device ("gpu" or "cpu").
Args:
- paged_cache (LayoutTensor): Destination tensor of shape [total_num_blocks, kv_dim, num_layers, page_size, num_kv_heads, head_dim].
- input (LayoutTensor): Source tensor of shape [kv_dim, num_layers, num_tokens, hidden_dim], where hidden_dim = num_kv_heads * head_dim.
- slot_mapping (LayoutTensor): Token-to-slot mapping of shape [total_tokens], stored as int64.
- start_token (Int): Starting token index in slot_mapping.
- end_token (Int): Ending token index (exclusive) in slot_mapping.
- ctx (DeviceContext): Device context for the kernel launch.
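The following Python sketch is a hypothetical CPU reference model of the copy these arguments describe. It is not MAX's kernel. It assumes the common paged-attention convention that a slot decomposes as `block = slot // page_size` and `offset = slot % page_size`, and that `input` holds exactly the tokens `start_token` through `end_token - 1`, in order. Tensors are modeled as flat Python lists in row-major order to mirror the `Layout.row_major` layouts in the signature. The names `onload_reference` and `row_major_offset` are illustrative only.

```python
def row_major_offset(shape, index):
    """Flat offset of `index` inside a row-major tensor of the given shape."""
    offset = 0
    for dim, i in zip(shape, index):
        assert 0 <= i < dim, "index out of bounds"
        offset = offset * dim + i
    return offset


def onload_reference(paged_cache, paged_shape, inp, inp_shape,
                     slot_mapping, start_token, end_token):
    """Hypothetical reference: scatter contiguous KV data into paged blocks."""
    _, kv_dim, num_layers, page_size, num_kv_heads, head_dim = paged_shape
    in_kv, in_layers, num_tokens, hidden_dim = inp_shape
    assert (in_kv, in_layers) == (kv_dim, num_layers)
    assert hidden_dim == num_kv_heads * head_dim
    # Assumption: input contains only the tokens in [start_token, end_token).
    assert end_token - start_token == num_tokens

    for i in range(start_token, end_token):
        # Assumed slot convention: slot = block * page_size + offset.
        block, offset = divmod(slot_mapping[i], page_size)
        t = i - start_token  # token position within `inp`
        for kv in range(kv_dim):
            for layer in range(num_layers):
                for h in range(num_kv_heads):
                    for d in range(head_dim):
                        src = row_major_offset(
                            inp_shape, (kv, layer, t, h * head_dim + d))
                        dst = row_major_offset(
                            paged_shape, (block, kv, layer, offset, h, d))
                        paged_cache[dst] = inp[src]


# Usage: 2 blocks of 4 tokens, K and V, 1 layer, 1 head of size 2.
paged = [0] * (2 * 2 * 1 * 4 * 1 * 2)
onload_reference(paged, (2, 2, 1, 4, 1, 2), list(range(8)), (2, 1, 2, 2),
                 slot_mapping=[5, 2], start_token=0, end_token=2)
# Token 0 -> slot 5 (block 1, offset 1); token 1 -> slot 2 (block 0, offset 2).
print(paged[18:20], paged[4:6])  # [0, 1] [2, 3]
```

Looping over tokens on the outside matches how the slot lookup is done once per token. A GPU kernel would more likely assign threads across the token, head, and dimension axes in parallel.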