
Mojo module

mla_decode_dispatch_scalars

Device-dispatched MLA decode dispatch-metadata scalars.

This module hosts a single device-generic entry point, mla_decode_dispatch_scalars, that computes the packed 3-int MLA decode dispatch metadata (batch_size, q_max_seq_len, num_partitions). It mirrors the MHA pattern (mha_decoding_num_partitions), which hides the HIP-vs-SM100 device dispatch behind one function so that callers, in particular the Mojo->Python binding mla_dispatch_args_scalar, stay device-agnostic and carry no if ctx.api() branch.

The file sits above both per-device heuristics in the dependency graph and imports from both: the device-generic AMD/MHA heuristic (mha_decode_partition_heuristic) and the SM100/NVIDIA runtime heuristic (nvidia/sm100/mla_decode_dispatch). Co-locating the unified function in either heuristic file would invert the dependency direction (an SM100 import in a generic heuristic, or an AMD import in an SM100 file), so it lives in a device-generic location instead.
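
The single-entry-point pattern described above can be sketched in Python as follows. This is an illustration only, not the module's Mojo API: the Device enum, the DispatchScalars type, the dispatch_scalars function, its max_cache_len parameter, and both partition rules are hypothetical placeholders, and the real heuristics are not reproduced here.

```python
from dataclasses import dataclass
from enum import Enum


class Device(Enum):
    """Hypothetical stand-in for what a device context would report."""
    HIP = "hip"
    SM100 = "sm100"


@dataclass(frozen=True)
class DispatchScalars:
    """The packed 3-int metadata named in the module description."""
    batch_size: int
    q_max_seq_len: int
    num_partitions: int


def _generic_num_partitions(batch_size: int, max_cache_len: int) -> int:
    # Placeholder rule, NOT the real generic heuristic:
    # one partition per 512 cached tokens (rounded up), at least 1.
    return max(1, -(-max_cache_len // 512))


def _sm100_num_partitions(batch_size: int, max_cache_len: int) -> int:
    # Placeholder rule, NOT the real SM100 heuristic:
    # the generic count, capped so batch_size * partitions stays <= 128.
    cap = max(1, 128 // max(1, batch_size))
    return min(cap, _generic_num_partitions(batch_size, max_cache_len))


def dispatch_scalars(
    device: Device, batch_size: int, q_max_seq_len: int, max_cache_len: int
) -> DispatchScalars:
    """Single entry point: the only place that branches on the device."""
    if device is Device.SM100:
        num_partitions = _sm100_num_partitions(batch_size, max_cache_len)
    else:
        num_partitions = _generic_num_partitions(batch_size, max_cache_len)
    return DispatchScalars(batch_size, q_max_seq_len, num_partitions)


# Callers pass the device through and never branch on it themselves.
print(dispatch_scalars(Device.HIP, 4, 1, 2048))
print(dispatch_scalars(Device.SM100, 64, 1, 2048))
```

Because the unified function imports both per-device helpers, it sits above them in this sketch as well. Neither helper needs to know about the other's device, which is the dependency direction the paragraph above argues for.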

Functions