Mojo function

mla_prefill_plan

```mojo
mla_prefill_plan[cache_t: KVCacheT](
    buffer_row_offsets: TileTensor[DType.uint32, buffer_row_offsets.LayoutType, buffer_row_offsets.origin, address_space=buffer_row_offsets.address_space, linear_idx_type=buffer_row_offsets.linear_idx_type, element_size=buffer_row_offsets.element_size],
    cache_offsets: TileTensor[DType.uint32, cache_offsets.LayoutType, cache_offsets.origin, address_space=cache_offsets.address_space, linear_idx_type=cache_offsets.linear_idx_type, element_size=cache_offsets.element_size],
    buffer_lengths: TileTensor[DType.int32, buffer_lengths.LayoutType, buffer_lengths.origin, address_space=buffer_lengths.address_space, linear_idx_type=buffer_lengths.linear_idx_type, element_size=buffer_lengths.element_size],
    input_row_offsets: TileTensor[DType.uint32, input_row_offsets.LayoutType, input_row_offsets.origin, address_space=input_row_offsets.address_space, linear_idx_type=input_row_offsets.linear_idx_type, element_size=input_row_offsets.element_size],
    k_cache: cache_t,
    buffer_token_size: UInt32,
    ctx: DeviceContext,
)
```

Launches a GPU kernel that plans how to process a batch of sequences of varying lengths using a fixed-size staging buffer.

Each sequence in the batch has some existing cached tokens plus new input tokens. The kernel divides the total tokens into chunks of `buffer_token_size`.

For each chunk (iteration), it calculates:

1. Buffer offsets for each sequence in each chunk
2. Cache offsets for each sequence in each chunk
3. Total buffer lengths for each processing iteration
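The planning logic can be sketched on the host in plain Python. This is a simplified, hypothetical model, not the actual kernel: it assumes each sequence contributes its cached-token count plus its new-token count, that sequences fill each buffer greedily in batch order, and it ignores the `TileTensor` and `DeviceContext` machinery entirely:

```python
def mla_prefill_plan_sketch(cache_lengths, input_lengths, buffer_token_size):
    """Plan how to stage variable-length sequences into fixed-size buffers.

    Returns, per iteration: each sequence's start offset within the buffer,
    each sequence's offset into its own token stream (the "cache offset"),
    and the total number of tokens staged in that iteration's buffer.
    """
    # Assumption: total tokens per sequence = cached tokens + new input tokens.
    totals = [c + n for c, n in zip(cache_lengths, input_lengths)]
    num_iters = -(-sum(totals) // buffer_token_size)  # ceiling division

    buffer_row_offsets = []  # per iteration: start of each sequence in the buffer
    cache_offsets = []       # per iteration: tokens of each sequence already consumed
    buffer_lengths = []      # per iteration: total tokens staged in the buffer

    consumed = [0] * len(totals)
    for _ in range(num_iters):
        row_offs, c_offs, filled = [], [], 0
        for i, total in enumerate(totals):
            row_offs.append(filled)
            c_offs.append(consumed[i])
            # Take as many remaining tokens as fit in this buffer.
            take = min(total - consumed[i], buffer_token_size - filled)
            consumed[i] += take
            filled += take
        buffer_row_offsets.append(row_offs)
        cache_offsets.append(c_offs)
        buffer_lengths.append(filled)
    return buffer_row_offsets, cache_offsets, buffer_lengths
```

For example, two sequences of 6 total tokens each with `buffer_token_size=8` need two iterations: the first buffer holds all of sequence 0 plus 2 tokens of sequence 1, and the second holds sequence 1's remaining 4 tokens.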
