Mojo module
mla_decode_dispatch_scalars
Device-dispatched MLA decode dispatch-metadata scalars.
This module hosts a single device-generic entry point,
mla_decode_dispatch_scalars, that computes the packed 3-int MLA decode
dispatch metadata (batch_size, q_max_seq_len, num_partitions). It mirrors
the MHA pattern (mha_decoding_num_partitions), which hides the HIP-vs-SM100
device dispatch behind one function so that callers (in particular the
Mojo->Python binding mla_dispatch_args_scalar) stay device-agnostic and
carry no if ctx.api() branch.
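The single-entry-point pattern described above can be sketched in a few lines. This is an illustrative Python sketch only, not the Mojo implementation: the Api enum, the max_cache_len parameter, the dispatch_scalars name, and both partition heuristics are hypothetical placeholders chosen to make the example runnable. Only the three output fields (batch_size, q_max_seq_len, num_partitions) come from the module description.

```python
from dataclasses import dataclass
from enum import Enum


class Api(Enum):
    """Hypothetical device tags standing in for what a device context reports."""
    HIP = "hip"
    CUDA = "cuda"


@dataclass(frozen=True)
class DispatchScalars:
    """The packed 3-int dispatch metadata named in the module description."""
    batch_size: int
    q_max_seq_len: int
    num_partitions: int


def _amd_num_partitions(batch_size: int, max_cache_len: int) -> int:
    # Placeholder heuristic (assumption): one partition per 512 cached
    # tokens, clamped to the range [1, 8]. Batch size is ignored here.
    return max(1, min(8, max_cache_len // 512))


def _sm100_num_partitions(batch_size: int, max_cache_len: int) -> int:
    # Placeholder heuristic (assumption): one partition per 256 cached
    # tokens (at most 16), divided across the batch, never below 1.
    return max(1, min(16, max_cache_len // 256) // max(1, batch_size))


def dispatch_scalars(
    api: Api, batch_size: int, q_max_seq_len: int, max_cache_len: int
) -> DispatchScalars:
    """Single entry point: the only place that branches on the device API."""
    if api is Api.HIP:
        num_partitions = _amd_num_partitions(batch_size, max_cache_len)
    elif api is Api.CUDA:
        num_partitions = _sm100_num_partitions(batch_size, max_cache_len)
    else:
        raise ValueError(f"unsupported device API: {api}")
    return DispatchScalars(batch_size, q_max_seq_len, num_partitions)


# A caller passes the device tag through and never inspects it itself.
scalars = dispatch_scalars(Api.HIP, batch_size=4, q_max_seq_len=1, max_cache_len=2048)
```

Because the branch lives in one function, adding a third device in this sketch would mean one new heuristic and one new branch, with no change to any caller.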
The file sits above both per-device heuristics in the dependency graph and
imports from both: the device-generic AMD/MHA heuristic
(mha_decode_partition_heuristic) and the SM100/NVIDIA runtime heuristic
(nvidia/sm100/mla_decode_dispatch). Co-locating the unified function in
either heuristic file would invert the dependency direction (an SM100 import
in a generic heuristic, or an AMD import in an SM100 file), so it lives in a
device-generic location instead.
Functions
- mla_decode_dispatch_scalars: Compute the packed MLA decode dispatch metadata for the active device.