Mojo struct
MLADispatchScalarArgs
struct MLADispatchScalarArgs[num_heads: Int, _is_cache_length_accurate: Bool = False, is_fp8_kv: Bool = False]
Precomputed MLA decode arguments for the legacy (non-capturable) path.
Owns a GPU buffer containing [batch_size, q_max_seq_len, num_partitions]
and caches the host-side batch_size and q_max_seq_len values that
mla_decode_sm100_dispatch also takes as arguments.
Usage::

    var args = MLADispatchScalarArgs[num_heads=128](
        batch_size, max_cache_len, q_max_seq_len, ctx,
    )
    var gpu_lt = args.gpu_layout_tensor()
    mla_decode_sm100_dispatch[...](
        ..., gpu_lt,
        args.batch_size, args.q_max_seq_len, max_cache_len,
        ctx,
    )
    _ = args  # keepalive: gpu_buf must outlive the dispatch

Fields
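- gpu_buf (DeviceBuffer[DType.int64]): GPU buffer holding [batch_size, q_max_seq_len, num_partitions].
- batch_size (Int): host-side copy of the batch size.
- q_max_seq_len (Int): host-side copy of the maximum query sequence length.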
Implemented traits
AnyType, ImplicitlyDestructible
comptime members
MLAScalarArgsLT
comptime MLAScalarArgsLT = LayoutTensor[DType.int64, Layout.row_major(VariadicList(3)), MutAnyOrigin]
Methods
__init__
__init__(out self, batch_size: Int, max_cache_len: Int, q_max_seq_len: Int, ctx: DeviceContext)
gpu_layout_tensor
gpu_layout_tensor(self) -> MLADispatchScalarArgs[num_heads, _is_cache_length_accurate, is_fp8_kv].MLAScalarArgsLT
Returns:
MLAScalarArgsLT: a rank-1 LayoutTensor of three int64 values backed by gpu_buf.