Mojo function

compute_mla_dispatch_scalars

compute_mla_dispatch_scalars[num_heads: Int, _is_cache_length_accurate: Bool = False, is_fp8_kv: Bool = False](batch_size: Int, max_cache_valid_length: Int, q_max_seq_len: Int, sm_count: Int) -> Tuple[Int, Int, Int]

Pure computation of the packed 3-value MLA dispatch metadata.

Returns (batch_size, q_max_seq_len, num_partitions).

Returns:

Tuple[Int, Int, Int]
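The page does not say how num_partitions is chosen. The sketch below shows one common way a split-KV (flash-decoding style) launcher could derive a partition count from the same inputs. It is written in Python for readability and is a hypothetical illustration, not this function's actual formula. It assumes one work item per (batch, query position) pair and adds a made-up knob, min_tokens_per_partition. It also ignores num_heads and the _is_cache_length_accurate and is_fp8_kv parameters.

```python
import math


def sketch_mla_dispatch_scalars(
    batch_size: int,
    max_cache_valid_length: int,
    q_max_seq_len: int,
    sm_count: int,
    min_tokens_per_partition: int = 256,  # hypothetical tuning knob, not a real parameter
) -> tuple[int, int, int]:
    """Hypothetical split-KV heuristic; NOT the library's formula.

    Returns (batch_size, q_max_seq_len, num_partitions), matching the
    order documented for compute_mla_dispatch_scalars.
    """
    # Assumption: one work item per (batch, query position) before splitting the KV cache.
    base_blocks = max(1, batch_size * q_max_seq_len)
    # Partitions needed so the launch grid roughly covers every SM once.
    wanted = math.ceil(sm_count / base_blocks)
    # Never split the cache into chunks shorter than min_tokens_per_partition.
    cap = max(1, max_cache_valid_length // min_tokens_per_partition)
    num_partitions = max(1, min(wanted, cap))
    return (batch_size, q_max_seq_len, num_partitions)


# Small decode batch, long cache: splitting the KV cache helps fill the SMs.
print(sketch_mla_dispatch_scalars(2, 4096, 1, 132))
# Large batch: enough work items already, so only a few partitions are used.
print(sketch_mla_dispatch_scalars(64, 4096, 1, 132))
```

The design trade-off this illustrates: more partitions improve SM occupancy when batch_size * q_max_seq_len is small, but each extra partition adds a reduction step, so the count is capped by how much cache there is to split.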
