
Mojo struct

MLADispatchScalarArgs

struct MLADispatchScalarArgs[num_heads: Int, _is_cache_length_accurate: Bool = False, is_fp8_kv: Bool = False]

Pre-computed MLA decode args for the legacy (non-capturable) path.

Owns a GPU buffer holding three int64 scalars, [batch_size, q_max_seq_len, num_partitions], and caches on the host the batch_size and q_max_seq_len values that mla_decode_sm100_dispatch needs.

Usage:

var args = MLADispatchScalarArgs[num_heads=128](
    batch_size, max_cache_len, q_max_seq_len, ctx,
)
var gpu_lt = args.gpu_layout_tensor()
mla_decode_sm100_dispatch[...](
    ..., gpu_lt,
    args.batch_size, args.q_max_seq_len, max_cache_len,
    ctx,
)
_ = args  # keepalive: args owns the GPU buffer, so it must outlive the dispatch call

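Fields

  • gpu_buf (DeviceBuffer[DType.int64]): GPU buffer holding [batch_size, q_max_seq_len, num_partitions].
  • batch_size (Int): Host-side copy of the batch size, passed to mla_decode_sm100_dispatch.
  • q_max_seq_len (Int): Host-side copy of the maximum query sequence length, passed to mla_decode_sm100_dispatch.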

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

MLAScalarArgsLT

comptime MLAScalarArgsLT = LayoutTensor[DType.int64, Layout.row_major(Int(3)), MutAnyOrigin]

Methods

__init__

def __init__(out self, batch_size: Int, max_cache_len: Int, q_max_seq_len: Int, ctx: DeviceContext)

gpu_layout_tensor

def gpu_layout_tensor(self) -> Self.MLAScalarArgsLT

Returns:

Self.MLAScalarArgsLT

gpu_tile_tensor

def gpu_tile_tensor(self) -> TileTensor[DType.int64, Layout[*?, *?], MutAnyOrigin]

Returns:

TileTensor[DType.int64, Layout[*?, *?], MutAnyOrigin]