Mojo module
mla_decode_sm100_dispatch
comptime values
DEFAULT_NUM_PARTITIONS
comptime DEFAULT_NUM_PARTITIONS = NUM_PARTITIONS.get((len[Dict[Int, Int]](NUM_PARTITIONS) - 1), 0)
logger
comptime logger = Logger(stdout, "", False)
MAX_NUM_SPLITS
comptime MAX_NUM_SPLITS = 74
NUM_PARTITIONS
comptime NUM_PARTITIONS = Dict(List(VariadicList(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), Tuple()), List(VariadicList(1, 2, 4, 8, 16, 32, 37, 64, 72, 74), Tuple()), Tuple())
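The three constants above are related: `DEFAULT_NUM_PARTITIONS` looks up the entry whose key is `len(NUM_PARTITIONS) - 1`, falling back to 0 if that key is absent. The Python sketch below mirrors that lookup. It assumes the rendered `Dict(...)` value pairs the first list (keys 0 through 9) position by position with the second list (values); that reading is an interpretation of the rendered signature, not something this page states.

```python
# Assumption: keys 0..9 map position-by-position onto the listed values.
NUM_PARTITIONS = dict(zip(range(10), [1, 2, 4, 8, 16, 32, 37, 64, 72, 74]))
MAX_NUM_SPLITS = 74

# Same expression shape as the Mojo definition: get(len - 1, default 0).
DEFAULT_NUM_PARTITIONS = NUM_PARTITIONS.get(len(NUM_PARTITIONS) - 1, 0)

print(DEFAULT_NUM_PARTITIONS)  # 74
```

Under that assumption the default resolves to the last table entry, which matches `MAX_NUM_SPLITS`.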
Structs
- MLADispatchScalarArgs: Pre-computed MLA decode args for the legacy (non-capturable) path.
Functions
- compute_mla_dispatch_scalars: Pure computation of the packed 3-value MLA dispatch metadata.
- launch_mla_sm100_decode_enqueue_kernel
- launch_mla_sm100_decode_fp8_per_token_scale_rope_aware: Launch the FP8 per-token-scale rope-aware MLA decode kernel with split content/rope TMAs.
- launch_mla_sm100_decode_native_fp8: Launch the native FP8 MLA decode kernel with FP8 Q TMA.
- mla_decode_sm100_dispatch
- mla_decode_sm100_sink_split_k