Mojo function
compute_mla_dispatch_scalars
compute_mla_dispatch_scalars[num_heads: Int, _is_cache_length_accurate: Bool = False, is_fp8_kv: Bool = False](batch_size: Int, max_cache_valid_length: Int, q_max_seq_len: Int, sm_count: Int) -> Tuple[Int, Int, Int]
Pure (side-effect-free) computation of the packed three-value MLA (multi-head latent attention) dispatch metadata.
Returns:
A `Tuple[Int, Int, Int]` holding `(batch_size, q_max_seq_len, num_partitions)`.
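This page does not say how `num_partitions` is derived. The sketch below only illustrates what the returned triple could mean. `dispatch_scalars_sketch` is a hypothetical stand-in that copies the documented signature (minus the `Bool` parameters) and the return order. It picks `num_partitions` with an assumed split-K heuristic: it splits each KV cache until the `(batch, head)` work roughly covers `sm_count` SMs, capped by an assumed 64-token tile. The function name, the heuristic, and the tile size are all assumptions, not the library's logic.

```mojo
from math import ceildiv


# HYPOTHETICAL stand-in, not the library implementation. It mirrors the
# documented signature (without the Bool parameters) and return order, and
# chooses num_partitions with an ASSUMED split-K heuristic.
fn dispatch_scalars_sketch[num_heads: Int](
    batch_size: Int,
    max_cache_valid_length: Int,
    q_max_seq_len: Int,
    sm_count: Int,
) -> Tuple[Int, Int, Int]:
    # Assumed number of KV tokens handled per partition (illustrative only).
    var kv_tile = 64
    # Work items before any split: one per (batch, head) pair.
    var base_items = max(1, batch_size * num_heads)
    # Split until the work roughly covers every SM, using at least one partition.
    var wanted = max(1, ceildiv(sm_count, base_items))
    # Never create more partitions than there are KV tiles to split.
    var max_parts = max(1, ceildiv(max_cache_valid_length, kv_tile))
    var num_partitions = min(wanted, max_parts)
    return (batch_size, q_max_seq_len, num_partitions)


def main():
    # Decode-style call: one query token per sequence, 16 heads, 132 SMs.
    var scalars = dispatch_scalars_sketch[16](2, 4096, 1, 132)
    print(scalars[0], scalars[1], scalars[2])
```

With these inputs there are 2 × 16 = 32 work items. Covering 132 SMs takes a 5-way split, so the program prints `2 1 5`. With a 100-token cache, the same call would cap the split at 2 partitions, because there are only two 64-token tiles.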