Mojo struct
MLADispatchScalarArgs
struct MLADispatchScalarArgs[num_heads: Int, _is_cache_length_accurate: Bool = False, is_fp8_kv: Bool = False]
Pre-computed MLA decode args for the legacy (non-capturable) path.
Owns a GPU buffer containing [batch_size, q_max_seq_len, num_partitions]
and caches the host-side batch_size/q_max_seq_len pair needed by
mla_decode_sm100_dispatch.
Usage::
var args = MLADispatchScalarArgs[num_heads=128](
batch_size, max_cache_len, q_max_seq_len, ctx,
)
var gpu_lt = args.gpu_layout_tensor()
mla_decode_sm100_dispatch[...](
..., gpu_lt,
args.batch_size, args.q_max_seq_len, max_cache_len,
ctx,
)
    _ = args  # keepalive

Fields

- gpu_buf (DeviceBuffer[DType.int64]): The owned 3-element GPU buffer holding [batch_size, q_max_seq_len, num_partitions].
- batch_size (Int): Cached host-side batch size.
- q_max_seq_len (Int): Cached host-side maximum query sequence length.
Implemented traits

AnyType, ImplicitlyDestructible
comptime members

MLAScalarArgsLT
comptime MLAScalarArgsLT = LayoutTensor[DType.int64, Layout.row_major(3), MutAnyOrigin]
Methods

__init__
__init__(out self, batch_size: Int, max_cache_len: Int, q_max_seq_len: Int, ctx: DeviceContext)
gpu_layout_tensor
gpu_layout_tensor(self) -> MLADispatchScalarArgs[num_heads, _is_cache_length_accurate, is_fp8_kv].MLAScalarArgsLT
Returns:
MLAScalarArgsLT: a LayoutTensor[DType.int64, Layout.row_major(3), MutAnyOrigin] view over the owned GPU buffer.
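The Usage snippet above can be expanded into a commented end-to-end sketch. This is illustrative only: the scalar values are placeholders, and the elided template/runtime parameters of mla_decode_sm100_dispatch are kernel-specific and left out, as in the original snippet.

```mojo
from gpu.host import DeviceContext

fn run_mla_decode(ctx: DeviceContext) raises:
    # Illustrative host-side sizes; real values come from the batch
    # being decoded.
    var batch_size = 8
    var max_cache_len = 4096
    var q_max_seq_len = 1

    # Allocates and fills the 3-element int64 GPU buffer
    # [batch_size, q_max_seq_len, num_partitions].
    var args = MLADispatchScalarArgs[num_heads=128](
        batch_size, max_cache_len, q_max_seq_len, ctx,
    )

    # LayoutTensor view over the owned GPU buffer, to be passed to
    # the dispatch along with the cached host-side scalars:
    var gpu_lt = args.gpu_layout_tensor()

    # mla_decode_sm100_dispatch[...](
    #     ..., gpu_lt,
    #     args.batch_size, args.q_max_seq_len, max_cache_len,
    #     ctx,
    # )

    # Keep `args` (and thus gpu_buf) alive until the kernel is done.
    _ = args
```

The keepalive at the end matters because gpu_lt is only a view: destroying args would free the underlying DeviceBuffer while the kernel may still read it.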