
Mojo function

k_grouped_matmul_ragged_paged

k_grouped_matmul_ragged_paged[dtype: DType, target: StringSlice[StaticConstantOrigin]](a: NDBuffer[dtype, 2, origin, shape], b: NDBuffer[dtype, 3, origin, shape], input_row_offsets: NDBuffer[uint32, 1, origin, shape, strides], ids: NDBuffer[uint32, 1, origin, shape, strides], max_seq_len: Int, active_ids: Int, kv_collection: PagedKVCacheCollection[dtype_, kv_params_, page_size], layer_idx: SIMD[uint32, 1], ctx: DeviceContextPtr)

Performs a grouped matmul over ragged input rows and writes the output into a mutable PagedKVCacheCollection object.

NOTE: This function adds its result to the existing contents of the KV cache. It is not a typical store operation.

Args:

  • a (NDBuffer[dtype, 2, origin, shape]): Input tensor with shape (sum(seq_lens), input_dim).
  • b (NDBuffer[dtype, 3, origin, shape]): Weight tensor with shape (num_experts, output_dim, input_dim).
  • input_row_offsets (NDBuffer[uint32, 1, origin, shape, strides]): Tensor with shape (batch_size + 1,) denoting the start of each sequence along the seq_len dimension.
  • ids (NDBuffer[uint32, 1, origin, shape, strides]): Expert IDs tensor.
  • max_seq_len (Int): Maximum sequence length per expert.
  • active_ids (Int): Number of active experts.
  • kv_collection (PagedKVCacheCollection[dtype_, kv_params_, page_size]): The historical KVCache for keys and values. The KVCache for this layer is retrieved via layer_idx.
  • layer_idx (SIMD[uint32, 1]): The index of the layer being executed. Used to retrieve the KVCache for the given layer from kv_collection.
  • ctx (DeviceContextPtr): The call context pointer, passed by the graph compiler.
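The following pure-Python sketch illustrates the semantics described above. It is not the kernel's implementation and rests on several assumptions:

- Group `i` covers rows `input_row_offsets[i]` through `input_row_offsets[i + 1]`.
- Group `i` uses the weight `b[ids[i]]`.
- Each output row is `a[row] @ b[expert]^T`, which is what the `(num_experts, output_dim, input_dim)` weight layout suggests.

A dense `cache` list stands in for the paged KV cache. The function name `grouped_matmul_ragged_accumulate` is hypothetical. Page lookup, `layer_idx`, `max_seq_len`, and device execution are all omitted.

```python
from typing import List

Matrix = List[List[float]]


def grouped_matmul_ragged_accumulate(
    a: Matrix,                     # (total_rows, input_dim)
    b: List[Matrix],               # (num_experts, output_dim, input_dim)
    input_row_offsets: List[int],  # assumed group boundaries along the row dimension
    ids: List[int],                # assumed expert index for each group
    active_ids: int,               # number of groups to process
    cache: Matrix,                 # (total_rows, output_dim) dense stand-in for the paged KV cache
) -> None:
    """Hypothetical reference: cache[rows of group i] += a[rows] @ b[ids[i]]^T."""
    for i in range(active_ids):
        start, end = input_row_offsets[i], input_row_offsets[i + 1]
        weight = b[ids[i]]  # (output_dim, input_dim)
        for row in range(start, end):
            for n, w_row in enumerate(weight):
                # Dot product of the input row with row n of the weight,
                # equivalent to multiplying by the transposed weight.
                acc = sum(x * w for x, w in zip(a[row], w_row))
                cache[row][n] += acc  # accumulate into the cache, do not overwrite
```

Because the update uses `+=`, the product is added on top of whatever the cache already holds. Calling the function twice with the same inputs therefore adds the product twice, which matches the additive behavior described in the NOTE.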
