Mojo struct

MLADispatchScalarArgs

struct MLADispatchScalarArgs[num_heads: Int, _is_cache_length_accurate: Bool = False, is_fp8_kv: Bool = False]

Pre-computed MLA decode args for the legacy (non-capturable) path.

Owns a GPU buffer containing [batch_size, q_max_seq_len, num_partitions] and caches the host-side batch_size/q_max_seq_len pair needed by mla_decode_sm100_dispatch.

Usage::

var args = MLADispatchScalarArgs[num_heads=128](
    batch_size, max_cache_len, q_max_seq_len, ctx,
)
var gpu_lt = args.gpu_layout_tensor()
mla_decode_sm100_dispatch[...](
    ..., gpu_lt,
    args.batch_size, args.q_max_seq_len, max_cache_len,
    ctx,
)
_ = args  # keepalive: the buffer backing gpu_lt must outlive the dispatch

Fields

  • gpu_buf (DeviceBuffer[DType.int64]): Device buffer holding the [batch_size, q_max_seq_len, num_partitions] triple.
  • batch_size (Int): Cached host-side batch size.
  • q_max_seq_len (Int): Cached host-side maximum query sequence length.

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

MLAScalarArgsLT

comptime MLAScalarArgsLT = LayoutTensor[DType.int64, Layout.row_major(VariadicList(3)), MutAnyOrigin]

Methods

__init__

__init__(out self, batch_size: Int, max_cache_len: Int, q_max_seq_len: Int, ctx: DeviceContext)

gpu_layout_tensor

gpu_layout_tensor(self) -> MLADispatchScalarArgs[num_heads, _is_cache_length_accurate, is_fp8_kv].MLAScalarArgsLT

Returns:

A MLAScalarArgsLT layout tensor view over gpu_buf, suitable for passing to mla_decode_sm100_dispatch.
