Mojo struct

MLADispatchScalarArgs

struct MLADispatchScalarArgs[num_heads: Int, _is_cache_length_accurate: Bool = False, is_fp8_kv: Bool = False]

Pre-computed scalar dispatch args for MLA decode legacy (non-capturable) path.

Holds a GPU buffer containing [batch_size, q_max_seq_len, num_partitions, max_cache_valid_length] and stores the scalar values as plain Int fields for host-side dispatch.

Usage:

var args = MLADispatchScalarArgs[num_heads=128](
    batch_size, max_cache_len, q_max_seq_len, ctx,
)
var gpu_lt = args.gpu_layout_tensor()
mla_decode_sm100_dispatch[...](
    ..., gpu_lt,
    args.batch_size, args.q_max_seq_len, args.max_cache_valid_length,
    ctx,
)
_ = args  # keepalive

Fields

  • gpu_buf (DeviceBuffer[DType.int64]): Device buffer holding the four int64 scalars [batch_size, q_max_seq_len, num_partitions, max_cache_valid_length].
  • batch_size (Int): Host-side copy of the batch size, used for dispatch.
  • q_max_seq_len (Int): Host-side copy of the maximum query sequence length.
  • max_cache_valid_length (Int): Host-side copy of the maximum valid cache length.

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

__del__is_trivial

comptime __del__is_trivial = False

MLAScalarArgsLT

comptime MLAScalarArgsLT = LayoutTensor[DType.int64, Layout.row_major(4), MutAnyOrigin]

Methods

__init__

__init__(out self, batch_size: Int, max_cache_len: Int, q_max_seq_len: Int, ctx: DeviceContext)

gpu_layout_tensor

gpu_layout_tensor(self) -> MLADispatchScalarArgs[num_heads, _is_cache_length_accurate, is_fp8_kv].MLAScalarArgsLT

Returns:

A MLAScalarArgsLT (LayoutTensor[DType.int64, Layout.row_major(4), MutAnyOrigin]) viewing the pre-computed dispatch args stored in gpu_buf. The tensor aliases the buffer rather than owning it, so the MLADispatchScalarArgs instance must be kept alive while the tensor is in use.
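A minimal sketch of the intended lifetime pattern, mirroring the usage snippet above (the `num_heads=128` parameter and the elided dispatch arguments are placeholders, not prescribed values):

```mojo
# Sketch only: build the scalar args once per decode step.
var args = MLADispatchScalarArgs[num_heads=128](
    batch_size, max_cache_len, q_max_seq_len, ctx,
)
# gpu_lt aliases args.gpu_buf; it does not own the device memory.
var gpu_lt = args.gpu_layout_tensor()
# ... launch the decode kernel with gpu_lt and the host-side scalars ...
# Keep args (and therefore gpu_buf) alive past the asynchronous launch.
_ = args
```

The trailing `_ = args` prevents the buffer from being destroyed before the kernel, which may still be reading it, has been enqueued.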
