
Mojo function

mla_prefill_plan

def mla_prefill_plan[cache_t: KVCacheT](
    buffer_row_offsets: TileTensor[DType.uint32, address_space=buffer_row_offsets.address_space, linear_idx_type=buffer_row_offsets.linear_idx_type, element_size=buffer_row_offsets.element_size],
    cache_offsets: TileTensor[DType.uint32, address_space=cache_offsets.address_space, linear_idx_type=cache_offsets.linear_idx_type, element_size=cache_offsets.element_size],
    buffer_lengths: TileTensor[DType.int32, address_space=buffer_lengths.address_space, linear_idx_type=buffer_lengths.linear_idx_type, element_size=buffer_lengths.element_size],
    input_row_offsets: TileTensor[DType.uint32, address_space=input_row_offsets.address_space, linear_idx_type=input_row_offsets.linear_idx_type, element_size=input_row_offsets.element_size],
    k_cache: cache_t,
    buffer_token_size: UInt32,
    ctx: DeviceContext,
)

Launches a GPU kernel that plans how to process a batch of variable-length sequences through a fixed-size buffer.

Each sequence in the batch has some existing cached tokens and some new input tokens. The kernel divides the total tokens into chunks of buffer_token_size tokens.

For each chunk (iteration), it calculates:

1. Buffer offsets for each sequence in each chunk
2. Cache offsets for each sequence in each chunk
3. Total buffer lengths for each processing iteration
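The three outputs above can be illustrated with a small host-side reference model. This is only a sketch in plain Python, not the GPU kernel and not the library's API. The function name plan_chunks is hypothetical, Python lists stand in for the TileTensor arguments, and the output layout is an assumption: one row per chunk, buffer offsets stored as batch_size + 1 cumulative row offsets, and cache offsets counted from the start of each sequence. It also takes the per-sequence token counts as given, since this page does not specify how they are derived from the cached and new input tokens.

```python
def plan_chunks(seq_lens, buffer_token_size):
    """Plan fixed-size chunks over a batch of variable-length sequences.

    Assumed semantics (not confirmed by the documentation): the sequences
    are treated as one concatenated token stream that is cut into chunks
    of `buffer_token_size` tokens.

    Returns three lists with one entry per chunk:
      buffer_row_offsets[c][s] : where sequence s starts inside the buffer
                                 (cumulative, batch_size + 1 entries)
      cache_offsets[c][s]      : how many tokens of sequence s were already
                                 covered by earlier chunks
      buffer_lengths[c]        : number of tokens placed in the buffer
    """
    if buffer_token_size <= 0:
        raise ValueError("buffer_token_size must be positive")

    # Global start position of every sequence in the concatenated stream.
    starts = [0]
    for n in seq_lens:
        starts.append(starts[-1] + n)
    total = starts[-1]

    num_chunks = -(-total // buffer_token_size)  # ceiling division

    buffer_row_offsets, cache_offsets, buffer_lengths = [], [], []
    for c in range(num_chunks):
        lo = c * buffer_token_size                # first global token in chunk
        hi = min(lo + buffer_token_size, total)   # one past the last token

        rows, offs = [0], []
        for s, n in enumerate(seq_lens):
            # Tokens of sequence s that fall inside [lo, hi).
            take = max(0, min(starts[s + 1], hi) - max(starts[s], lo))
            rows.append(rows[-1] + take)
            # Tokens of sequence s consumed before this chunk, clamped to n.
            offs.append(min(max(lo - starts[s], 0), n))

        buffer_row_offsets.append(rows)
        cache_offsets.append(offs)
        buffer_lengths.append(hi - lo)

    return buffer_row_offsets, cache_offsets, buffer_lengths


# Three sequences of 5, 3 and 4 tokens with a 6-token buffer give two chunks.
rows, offs, lens = plan_chunks([5, 3, 4], 6)
print(rows)  # [[0, 5, 6, 6], [0, 0, 2, 6]]
print(offs)  # [[0, 0, 0], [5, 1, 0]]
print(lens)  # [6, 6]
```

In this model the second sequence is split across both chunks: one token lands in the first buffer and the remaining two in the second, which is why its cache offset for the second chunk is 1. The real kernel may lay out or clamp these values differently.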