Mojo struct
RaggedFlashAttentionGPU
struct RaggedFlashAttentionGPU
Implemented traits
AnyType,
ImplicitlyDestructible
Methods
execute
static def execute[rank: Int, //, target: StringSlice[StaticConstantOrigin], mask_str: StringSlice[StaticConstantOrigin], local_window_size: Int = -1](output: ManagedTensorSlice[Output, static_spec=output.static_spec], q: ManagedTensorSlice[Input, static_spec=q.static_spec], k: ManagedTensorSlice[Input, static_spec=k.static_spec], v: ManagedTensorSlice[Input, static_spec=v.static_spec], input_row_offsets: ManagedTensorSlice[Input, static_spec=input_row_offsets.static_spec], q_max_seq_len: ManagedTensorSlice[Input, static_spec=q_max_seq_len.static_spec], scale: Float32, ctx: DeviceContext)
mo.mha.ragged.no_cache computes flash attention over ragged inputs without a KV cache.
The inputs q, k, and v use the ragged format, with shape [total_seq_len, num_heads, head_dim]: the sequences in the batch are packed back to back along the first dimension. input_row_offsets marks where each sequence starts and ends in these ragged tensors.
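To make the ragged layout concrete, here is a minimal pure-Python sketch of what the operation computes. It is not the GPU kernel and does not use any MAX API. The helper names row_offsets and ragged_attention are hypothetical. The sketch makes four assumptions: input_row_offsets is the exclusive prefix sum of the per-sequence lengths (batch_size + 1 entries, with sequence b occupying rows offsets[b] up to offsets[b+1] - 1); q_max_seq_len is the longest sequence; a causal mask stands in for one possible mask_str; and local_window_size is left unmodeled.

```python
import math


def row_offsets(seq_lens):
    # Hypothetical helper: exclusive prefix sum of sequence lengths.
    # Sequence b occupies ragged rows offsets[b] .. offsets[b + 1] - 1.
    offsets = [0]
    for n in seq_lens:
        offsets.append(offsets[-1] + n)
    return offsets


def ragged_attention(q, k, v, offsets, scale, causal=True):
    # Hypothetical reference: q, k, v are nested lists shaped
    # [total_seq_len][num_heads][head_dim]. Attention never crosses
    # sequence boundaries given by offsets. causal=True assumes a causal
    # mask; the real kernel selects its mask through mask_str.
    total = len(q)
    num_heads = len(q[0]) if total else 0
    head_dim = len(q[0][0]) if total else 0
    out = [[[0.0] * head_dim for _ in range(num_heads)] for _ in range(total)]
    for b in range(len(offsets) - 1):
        start, end = offsets[b], offsets[b + 1]
        for i in range(start, end):
            # Without a KV cache, queries and keys share positions, so a
            # causal query at row i sees keys start..i of its own sequence.
            last = i if causal else end - 1
            for h in range(num_heads):
                scores = [
                    scale * sum(q[i][h][d] * k[j][h][d] for d in range(head_dim))
                    for j in range(start, last + 1)
                ]
                # Numerically stable softmax over this sequence's keys only.
                m = max(scores)
                w = [math.exp(s - m) for s in scores]
                z = sum(w)
                for d in range(head_dim):
                    out[i][h][d] = sum(
                        w[t] * v[start + t][h][d] for t in range(len(w))
                    ) / z
    return out


seq_lens = [2, 1]
offsets = row_offsets(seq_lens)   # [0, 2, 3]
q_max_seq_len = max(seq_lens)     # 2
```

The sketch materializes the full score row for each query. A flash attention kernel instead processes keys in tiles with an online softmax, which gives the same result without storing the whole score matrix. A typical choice for scale is 1 / sqrt(head_dim), but the operation takes it as an explicit argument.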