Mojo function
compute_mla_dispatch_scalar_args
compute_mla_dispatch_scalar_args[num_heads: Int, _is_cache_length_accurate: Bool = False, is_fp8_kv: Bool = False](output_ptr: UnsafePointer[Int64, MutAnyOrigin], batch_size: Int, max_cache_valid_length: Int, q_max_seq_len: Int, ctx: DeviceContext)
Compute the 4 scalar dispatch args and write them to the device buffer.
The output buffer layout is:

- [0] batch_size
- [1] q_max_seq_len
- [2] num_partitions
- [3] max_cache_valid_length
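To make the slot order concrete, here is a minimal host-side sketch in Python of how a consumer might name and unpack the four values. It is not part of the Mojo API: `DispatchArgIndex`, `DispatchArgs`, and `unpack_dispatch_args` are hypothetical names chosen for illustration, and the sketch assumes the four Int64 values have already been copied back from the device into an ordinary sequence. It does not show how `num_partitions` is computed, since the source does not describe that.

```python
from enum import IntEnum
from typing import NamedTuple, Sequence


class DispatchArgIndex(IntEnum):
    """Slot indices of the dispatch-args buffer, per the documented layout."""
    BATCH_SIZE = 0
    Q_MAX_SEQ_LEN = 1
    NUM_PARTITIONS = 2
    MAX_CACHE_VALID_LENGTH = 3


class DispatchArgs(NamedTuple):
    """Named view over the four scalar dispatch args (hypothetical helper)."""
    batch_size: int
    q_max_seq_len: int
    num_partitions: int
    max_cache_valid_length: int


def unpack_dispatch_args(buf: Sequence[int]) -> DispatchArgs:
    """Read the four scalars from a host-side copy of the buffer."""
    if len(buf) < len(DispatchArgIndex):
        raise ValueError("dispatch-args buffer must hold at least 4 values")
    # IntEnum members iterate in definition order, which matches slots 0..3.
    return DispatchArgs(*(buf[i] for i in DispatchArgIndex))


# Example values are made up for illustration only.
args = unpack_dispatch_args([8, 1, 4, 2048])
print(args.num_partitions)
```

Naming the indices in one place keeps readers of the buffer in sync with the writer if the layout ever changes.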
This function is called once per device, before the layer loop, by the mo.mla.compute_dispatch_args.paged MOGG op.