
Mojo function

compute_mla_dispatch_scalar_args

compute_mla_dispatch_scalar_args[num_heads: Int, _is_cache_length_accurate: Bool = False, is_fp8_kv: Bool = False](output_ptr: UnsafePointer[Int64, MutAnyOrigin], batch_size: Int, max_cache_valid_length: Int, q_max_seq_len: Int, ctx: DeviceContext)

Computes the four scalar dispatch arguments and writes them to the device buffer at output_ptr.

The output buffer layout is:
[0] batch_size
[1] q_max_seq_len
[2] num_partitions
[3] max_cache_valid_length
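The layout can be illustrated from the host side. The sketch below packs and unpacks the four slots as signed 64-bit integers with Python's struct module. It assumes the buffer has already been copied back to host memory as raw bytes and that the byte order is little-endian. DISPATCH_ARG_FIELDS, pack_dispatch_args, and unpack_dispatch_args are hypothetical helpers, not part of any Modular API. num_partitions is passed in directly because this page does not say how the function derives it. The example values are purely illustrative.

```python
import struct

# Slot order as documented for the output buffer (one Int64 per slot).
DISPATCH_ARG_FIELDS = (
    "batch_size",              # [0]
    "q_max_seq_len",           # [1]
    "num_partitions",          # [2]
    "max_cache_valid_length",  # [3]
)

# Four signed 64-bit integers; little-endian is an assumption about byte order.
_LAYOUT = struct.Struct("<4q")


def pack_dispatch_args(batch_size, q_max_seq_len, num_partitions, max_cache_valid_length):
    """Serialize the four scalars in the documented slot order (32 bytes)."""
    return _LAYOUT.pack(batch_size, q_max_seq_len, num_partitions, max_cache_valid_length)


def unpack_dispatch_args(raw):
    """Decode a 32-byte buffer into a dict keyed by field name."""
    return dict(zip(DISPATCH_ARG_FIELDS, _LAYOUT.unpack(raw)))


# Illustrative values only; num_partitions=4 is not computed here.
raw = pack_dispatch_args(batch_size=8, q_max_seq_len=1, num_partitions=4, max_cache_valid_length=2048)
print(len(raw))  # 32
print(unpack_dispatch_args(raw))
```

Keeping the slot order in a single tuple means the index of each field (for example, num_partitions at index 2) is defined in one place rather than repeated as magic numbers.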

The mo.mla.compute_dispatch_args.paged MOGG op calls this function once per device, before the layer loop.
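Because the call happens before the layer loop, the scalars are computed once and then shared by every layer on that device. The following minimal Python sketch models only that control flow. It assumes every layer on a device uses the same batch and sequence-length values. plan_launches, fake_compute, and the device labels are hypothetical stand-ins, not MAX or MOGG APIs.

```python
def plan_launches(devices, num_layers, compute_args):
    """Return (device, layer, args) tuples, computing args once per device."""
    launches = []
    for device in devices:
        # Hoisted out of the layer loop: one computation per device.
        args = compute_args(device)
        for layer in range(num_layers):
            # Every layer on this device reuses the same dispatch args.
            launches.append((device, layer, args))
    return launches


calls = []


def fake_compute(device):
    # Hypothetical stand-in for compute_mla_dispatch_scalar_args.
    calls.append(device)
    return (8, 1, 4, 2048)  # illustrative values in the documented slot order


launches = plan_launches(["gpu:0", "gpu:1"], num_layers=3, compute_args=fake_compute)
print(len(calls), len(launches))  # 2 6
```

In this sketch the argument computation runs once per device (2 calls), while the per-layer work (6 launches) only reads the result.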
