Mojo function
mla_prefill_plan
mla_prefill_plan[cache_t: KVCacheT](buffer_row_offsets: NDBuffer[uint32, 2, origin, shape, strides], cache_offsets: NDBuffer[uint32, 2, origin, shape, strides], buffer_lengths: NDBuffer[int32, 1, origin, shape, strides], input_row_offsets: NDBuffer[uint32, 1, origin, shape, strides], k_cache: cache_t, buffer_token_size: SIMD[uint32, 1], ctx: DeviceContext)
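The batch layout comes in through `input_row_offsets`, a 1-D `uint32` buffer. This sketch assumes the usual ragged-batch convention, which the page does not state. Under that convention the buffer is an exclusive prefix sum with one more entry than there are sequences, so the new tokens of sequence `i` occupy rows `input_row_offsets[i]` up to, but not including, `input_row_offsets[i + 1]`. The Python sketch below builds such an array on the host. The helper name `make_input_row_offsets` is hypothetical and is not part of the Mojo API.

```python
from itertools import accumulate


def make_input_row_offsets(new_token_counts):
    """Hypothetical helper: exclusive prefix sum of per-sequence new-token counts.

    Assumes the ragged layout where entry i is the first row of sequence i
    and the final entry is the total number of new tokens in the batch.
    """
    # Prepend 0 so that the offsets have batch_size + 1 entries.
    return [0] + list(accumulate(new_token_counts))


offsets = make_input_row_offsets([3, 5, 2])
print(offsets)  # [0, 3, 8, 10]
```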
This calls a GPU kernel that plans how to process a batch of sequences with varying lengths using a fixed-size buffer.
Each sequence in the batch has some existing cached tokens and some new input tokens. The kernel divides the combined (cached plus new) tokens of the batch into chunks of buffer_token_size tokens each.
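The number of processing iterations follows from the chunking described above. This sketch assumes that chunks are formed over the total token count of the batch and that the last chunk may be only partly filled. Under those assumptions, the iteration count is a ceiling division. The function `count_chunks` is a hypothetical host-side illustration, not the kernel.

```python
def count_chunks(cache_lengths, new_token_counts, buffer_token_size):
    """Hypothetical: chunks of buffer_token_size needed for cached + new tokens."""
    if len(cache_lengths) != len(new_token_counts):
        raise ValueError("one cache length per sequence is required")
    # Each sequence contributes its existing cached tokens plus its new input tokens.
    total_tokens = sum(cache_lengths) + sum(new_token_counts)
    # Ceiling division: a partially filled final chunk still needs its own iteration.
    return -(-total_tokens // buffer_token_size)


# Three sequences with 4+3, 0+5 and 6+2 tokens, i.e. 20 tokens in total.
print(count_chunks([4, 0, 6], [3, 5, 2], 8))  # 3
```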
For each chunk (iteration), it calculates:
1. Buffer offsets for each sequence in each chunk
2. Cache offsets for each sequence in each chunk
3. The total buffer length for each processing iteration
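The three quantities listed above can be modeled on the host with the Python sketch below, which is not the Mojo kernel. The sketch relies on several assumptions that the page does not confirm:

* Each sequence's cached tokens precede its new tokens.
* Sequences are laid end to end before being cut into chunks, so a sequence may span several chunks.
* Each row of buffer offsets is a prefix sum with batch_size + 1 entries.
* A cache offset is the first position of that sequence falling in the chunk, clamped when the chunk misses the sequence.
* A buffer length is the number of tokens in the chunk.

The name `plan_chunks` is hypothetical.

```python
def plan_chunks(cache_lengths, new_token_counts, buffer_token_size):
    """Hypothetical host-side model of a chunked prefill plan (not the Mojo kernel)."""
    if len(cache_lengths) != len(new_token_counts):
        raise ValueError("one cache length per sequence is required")
    if buffer_token_size <= 0:
        raise ValueError("buffer_token_size must be positive")

    # Assumed layout: sequence b owns global tokens [starts[b], starts[b] + lengths[b]),
    # with its cached tokens first and its new tokens after them.
    lengths = [c + n for c, n in zip(cache_lengths, new_token_counts)]
    starts = [0]
    for length in lengths:
        starts.append(starts[-1] + length)
    total = starts[-1]

    buffer_row_offsets, cache_offsets, buffer_lengths = [], [], []
    for chunk_start in range(0, total, buffer_token_size):
        chunk_end = min(chunk_start + buffer_token_size, total)
        rows = [0]      # row in this chunk's buffer where each sequence's slice begins
        cache_pos = []  # first cache position of each sequence inside this chunk (clamped)
        for b, length in enumerate(lengths):
            lo = max(chunk_start, starts[b])
            hi = min(chunk_end, starts[b] + length)
            rows.append(rows[-1] + max(0, hi - lo))
            # Clamp to [0, length] so sequences outside the chunk still get a valid position.
            cache_pos.append(min(max(chunk_start - starts[b], 0), length))
        buffer_row_offsets.append(rows)
        cache_offsets.append(cache_pos)
        buffer_lengths.append(rows[-1])  # total tokens processed in this iteration
    return buffer_row_offsets, cache_offsets, buffer_lengths


rows, cache, lens = plan_chunks([4, 0, 6], [3, 5, 2], 8)
print(rows)   # [[0, 7, 8, 8], [0, 0, 4, 8], [0, 0, 0, 4]]
print(cache)  # [[0, 0, 0], [7, 1, 0], [7, 5, 4]]
print(lens)   # [8, 8, 4]
```

Sequence 1 in this example starts in the first chunk and finishes in the second. The second chunk therefore reads it from cache position 1, which is what the sketch's cache offsets capture.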