Mojo struct
MLADispatchScalarArgs
struct MLADispatchScalarArgs[num_heads: Int, _is_cache_length_accurate: Bool = False, is_fp8_kv: Bool = False]
Pre-computed scalar dispatch args for MLA decode legacy (non-capturable) path.
Holds a GPU buffer containing `[batch_size, q_max_seq_len, num_partitions, max_cache_valid_length]`, and stores the scalar values as plain `Int` fields for host-side dispatch.
Usage:

    var args = MLADispatchScalarArgs[num_heads=128](
        batch_size, max_cache_len, q_max_seq_len, ctx,
    )
    var gpu_lt = args.gpu_layout_tensor()
    mla_decode_sm100_dispatch[...](
        ..., gpu_lt,
        args.batch_size, args.q_max_seq_len, args.max_cache_valid_length,
        ctx,
    )
    _ = args  # keepalive

Fields
- gpu_buf (DeviceBuffer[DType.int64]): GPU buffer holding the four dispatch scalars.
- batch_size (Int): Batch size, stored for host-side dispatch.
- q_max_seq_len (Int): Maximum query sequence length.
- max_cache_valid_length (Int): Maximum valid cache length.
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
__del__is_trivial
comptime __del__is_trivial = False
MLAScalarArgsLT
comptime MLAScalarArgsLT = LayoutTensor[DType.int64, Layout.row_major(4), MutAnyOrigin]
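As an illustrative sketch (the function name `read_dispatch_args` is hypothetical; the index-to-field mapping follows the documented buffer order), a kernel-side consumer of `MLAScalarArgsLT` would read the four scalars by position:

```mojo
fn read_dispatch_args(args: MLADispatchScalarArgs[num_heads=128].MLAScalarArgsLT):
    # Indices follow the documented buffer order:
    # [batch_size, q_max_seq_len, num_partitions, max_cache_valid_length].
    var batch_size = args[0]
    var q_max_seq_len = args[1]
    var num_partitions = args[2]
    var max_cache_valid_length = args[3]
```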
Methods
__init__
__init__(out self, batch_size: Int, max_cache_len: Int, q_max_seq_len: Int, ctx: DeviceContext)
gpu_layout_tensor
gpu_layout_tensor(self) -> MLADispatchScalarArgs[num_heads, _is_cache_length_accurate, is_fp8_kv].MLAScalarArgsLT
Returns:
A `MLAScalarArgsLT` layout tensor viewing the GPU buffer of dispatch scalars.