Mojo function

lmcache_onload

lmcache_onload[dtype: DType, page_size: Int, num_kv_heads: Int, head_dim: Int, kv_dim: Int, target: StringSlice[StaticConstantOrigin] = "gpu"](paged_cache: LayoutTensor[dtype, Layout.row_major[6](), MutAnyOrigin], input: LayoutTensor[dtype, Layout.row_major[4](), MutAnyOrigin], slot_mapping: LayoutTensor[DType.int64, Layout.row_major[1](), MutAnyOrigin], start_token: Int, end_token: Int, ctx: DeviceContext)

Onload KV cache data from an external contiguous format into the MAX paged format. A usage sketch follows the argument list below.

Parameters:

  • dtype (DType): Data type of the cache.
  • page_size (Int): Number of tokens per page in the paged cache.
  • num_kv_heads (Int): Number of KV attention heads.
  • head_dim (Int): Dimension of each attention head.
  • kv_dim (Int): KV dimension (2 for standard K/V, 1 for MLA).
  • target (StringSlice): Target device ("gpu" or "cpu").

Args:

  • paged_cache (LayoutTensor): Destination tensor [total_num_blocks, kv_dim, num_layers, page_size, num_kv_heads, head_dim].
  • input (LayoutTensor): Source tensor [kv_dim, num_layers, num_tokens, hidden_dim], where hidden_dim is typically num_kv_heads * head_dim.
  • slot_mapping (LayoutTensor): Token to slot mapping [total_tokens].
  • start_token (Int): Starting token index in slot_mapping.
  • end_token (Int): Ending token index (exclusive) in slot_mapping.
  • ctx (DeviceContext): Device context for kernel launch.
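
The sketch below shows how a call might look for a bfloat16 cache with 128-token pages, 8 KV heads of head dimension 128, and a standard K/V layout, onloading the first 128 tokens of a request. It is a minimal illustration only: the import path for `lmcache_onload`, the `raises` annotation, and the concrete parameter values are assumptions not confirmed by this page, and tensor construction is elided by taking the tensors as function arguments.

```mojo
from gpu.host import DeviceContext
from layout import Layout, LayoutTensor

# `lmcache_onload` is assumed to be imported into scope; its module path
# is not shown on this page.

fn onload_prefix(
    paged_cache: LayoutTensor[DType.bfloat16, Layout.row_major[6](), MutAnyOrigin],
    input: LayoutTensor[DType.bfloat16, Layout.row_major[4](), MutAnyOrigin],
    slot_mapping: LayoutTensor[DType.int64, Layout.row_major[1](), MutAnyOrigin],
    ctx: DeviceContext,
) raises:
    # Copy tokens [0, 128) of the request from the contiguous source buffer
    # into the paged cache; slot_mapping routes each token to its page and
    # in-page offset.
    lmcache_onload[
        DType.bfloat16,  # dtype of both caches
        page_size=128,   # tokens per page in the paged cache
        num_kv_heads=8,
        head_dim=128,
        kv_dim=2,        # standard K/V; use 1 for MLA
    ](paged_cache, input, slot_mapping, 0, 128, ctx)
```

Since `target` defaults to "gpu", it is omitted here; pass `target="cpu"` in the parameter list to run on the host instead.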