IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

RaggedFlashAttentionGPU

struct RaggedFlashAttentionGPU

Implemented traits

AnyType, ImplicitlyDestructible

Methods

execute

static def execute[rank: Int, //, target: StringSlice[StaticConstantOrigin], mask_str: StringSlice[StaticConstantOrigin], local_window_size: Int = -1](output: ManagedTensorSlice[Output, static_spec=output.static_spec], q: ManagedTensorSlice[Input, static_spec=q.static_spec], k: ManagedTensorSlice[Input, static_spec=k.static_spec], v: ManagedTensorSlice[Input, static_spec=v.static_spec], input_row_offsets: ManagedTensorSlice[Input, static_spec=input_row_offsets.static_spec], q_max_seq_len: ManagedTensorSlice[Input, static_spec=q_max_seq_len.static_spec], scale: Float32, ctx: DeviceContext)

mo.mha.ragged.no_cache computes flash attention over ragged inputs without a KV cache.
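
The "flash" part of flash attention can be illustrated with the pure-Python sketch below. It is not the MAX GPU kernel, and the function names and the block_size parameter are hypothetical. The sketch processes keys in blocks and keeps a running maximum, a running softmax denominator, and an unnormalized output, so the full row of attention scores is never held at once. The mask and local window are not modeled. Up to floating-point rounding, the result should match a naive softmax attention row.

```python
import math

def naive_attention_row(q, keys, values, scale):
    """Reference: softmax(scale * q . k_j) weighted sum of v_j, for one query."""
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]

def online_softmax_attention_row(q, keys, values, scale, block_size=2):
    """Flash-attention-style pass (hypothetical helper): visit keys block by
    block, rescaling earlier partial sums whenever the running max grows."""
    dim = len(values[0])
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * dim  # unnormalized weighted sum of values
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale previous partial results to the new maximum
        # (exp(-inf) == 0.0 on the first block).
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            for d in range(dim):
                acc[d] += w * v[d]
        running_max = new_max
    # Normalize once at the end.
    return [a / running_sum for a in acc]
```

On the GPU, the same rescaling trick lets each block of keys and values be loaded into fast on-chip memory once. In this sketch, it only shows why the blockwise result equals the naive one.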

The inputs q, k, and v are in ragged format with shape [total_seq_len, num_heads, head_dim], where total_seq_len is the combined length of all sequences in the batch. input_row_offsets indicates where each sequence starts and ends in the ragged tensors.
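
The ragged semantics can be sketched as a pure-Python reference. This is not the MAX implementation, and the helper names are hypothetical. The sketch makes several assumptions:

- input_row_offsets holds batch_size + 1 ascending offsets, starting at 0 and ending at total_seq_len.
- q, k, and v share the same offsets and the same number of heads.
- scale multiplies the raw q·k scores.
- No mask or local window is applied.

```python
import math

def attention_one_head(qs, ks, vs, scale):
    """Unmasked softmax(scale * Q K^T) V for one head of one sequence."""
    out = []
    for q in qs:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in ks]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wi * v[d] for wi, v in zip(w, vs)) / z
                    for d in range(len(vs[0]))])
    return out

def ragged_attention_reference(q, k, v, input_row_offsets, scale):
    """Hypothetical reference. q, k, v are nested lists shaped
    [total_seq_len][num_heads][head_dim]; sequence i is assumed to own
    rows input_row_offsets[i] up to (not including) input_row_offsets[i + 1]."""
    total_seq_len, num_heads, head_dim = len(q), len(q[0]), len(q[0][0])
    output = [[[0.0] * head_dim for _ in range(num_heads)]
              for _ in range(total_seq_len)]
    for i in range(len(input_row_offsets) - 1):
        start, end = input_row_offsets[i], input_row_offsets[i + 1]
        for h in range(num_heads):
            # Slice out this sequence's rows only: no cross-sequence attention.
            qs = [q[t][h] for t in range(start, end)]
            ks = [k[t][h] for t in range(start, end)]
            vs = [v[t][h] for t in range(start, end)]
            for t, row in zip(range(start, end), attention_one_head(qs, ks, vs, scale)):
                output[t][h] = row
    return output
```

For example, two sequences of lengths 2 and 1 packed into three rows would use input_row_offsets = [0, 2, 3]. Because each sequence only attends to rows inside its own offset range, changing the last row's values leaves the outputs of the first sequence untouched.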