
Mojo struct

FlashAttentionGPU

struct FlashAttentionGPU

Implemented traits

AnyType, ImplicitlyDestructible

Methods

execute

static def execute[rank: Int, //, target: StringSlice[StaticConstantOrigin], mask_str: StringSlice[StaticConstantOrigin], local_window_size: Int = -1](output: ManagedTensorSlice[Output, static_spec=output.static_spec], q: ManagedTensorSlice[Input, static_spec=q.static_spec], k: ManagedTensorSlice[Input, static_spec=k.static_spec], v: ManagedTensorSlice[Input, static_spec=v.static_spec], scale: Float32, ctx: DeviceContext)

mo.mha.no_cache is a hand-fused operator that computes the equivalent of the following sequence of operations.

**Step 0:**

Transpose

query_processed = transpose(query) # BSHD --> BHSD

key_processed = transpose(key) # BSHD --> BHDS

value_processed = transpose(value) # BSHD --> BHSD

**Step 1:**

attentionMatrix = query_processed @ key_processed

**Step 2:**

norm = broadcast_to(normScalar, shape_of(attentionMatrix))

**Step 3:**

Normalize and apply masking

attentionMatrixNormMasked = mask_functor(attentionMatrix * scale)

**Step 4:**

Apply softmax and reproject result

attentionMatrixSoftMax = softmax(attentionMatrixNormMasked)

answer = attentionMatrixSoftMax @ value_processed

answer = transpose(answer) # BHSD --> BSHD
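Steps 0 through 4 can be sketched as a slow, unfused reference in plain Python. This is only an illustration of the arithmetic, not the kernel's code. It rests on several assumptions: `mha_reference` and `causal_mask` are hypothetical names, the mask functor is modeled as a per-element function of the score and its query/key positions, and `normScalar` from Step 2 is taken to be the same value as `scale`. Inputs are nested lists in BSHD layout with equal head counts for query, key, and value.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def causal_mask(score, q_idx, k_idx):
    """Hypothetical mask functor: a query may not attend to later keys."""
    return score if k_idx <= q_idx else float("-inf")


def mha_reference(q, k, v, scale, mask_functor=causal_mask):
    """Unfused attention over nested lists in BSHD layout; returns BSHD."""
    B, S, H, D = len(q), len(q[0]), len(q[0][0]), len(q[0][0][0])
    S_kv = len(k[0])
    out = [[[[0.0] * D for _ in range(H)] for _ in range(S)] for _ in range(B)]
    for b in range(B):
        for h in range(H):
            for i in range(S):
                # Steps 0-1: indexing [b][s][h][d] directly stands in for the
                # transposes, then one row of query @ key is computed.
                scores = [
                    sum(q[b][i][h][d] * k[b][j][h][d] for d in range(D))
                    for j in range(S_kv)
                ]
                # Steps 2-3: scale every score, then apply the mask.
                masked = [mask_functor(s * scale, i, j) for j, s in enumerate(scores)]
                # Step 4: softmax, then weight the value rows; writing to
                # out[b][i][h] plays the role of the final transpose.
                probs = softmax(masked)
                for d in range(D):
                    out[b][i][h][d] = sum(probs[j] * v[b][j][h][d] for j in range(S_kv))
    return out


# Tiny usage example: batch 1, sequence 2, one head, depth 1.
q = [[[[1.0]], [[2.0]]]]
result = mha_reference(q, q, q, scale=1.0)
```

Because the sketch reads and writes BSHD directly, no separate transposed copies are ever built. That is one way to picture transposes being folded into a kernel.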

Compared to the CPU patterns, the notable differences are:

  1. The transposes are part of the kernel itself

Finally, this pattern supports grouped attention. That is, if we have G groups, let h = H / G. In these scenarios key and value are allowed to be BShD. If one of them is BShD, the other must be as well. When they are, the equivalent of the following runs before Step 0:

**Step -1:**

key = concat(key, ...) # concat BShD --> BSHD

value = concat(value, ...) # concat BShD --> BSHD
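As a rough illustration of Step -1, the sketch below expands a BShD nested list to BSHD by concatenating G copies along the head axis, so that H = h * G. The helper name `concat_kv_heads` is hypothetical. The text above does not say how query heads are paired with key/value heads, so the head ordering here (a literal reading of `concat`) is an assumption. A real kernel may use a different pairing and would typically avoid materializing the copies at all.

```python
def concat_kv_heads(x, groups):
    """Expand BShD nested lists to BSHD by concatenating `groups` copies
    of the h heads along the head axis (assumed ordering, see lead-in)."""
    return [
        [
            # Copy each head so the result shares no inner lists with x.
            [list(head) for _ in range(groups) for head in token]
            for token in seq
        ]
        for seq in x
    ]


# Usage: B=1, S=1, h=2, D=1 expanded with G=2 gives H=4 heads.
key = [[[[1.0], [2.0]]]]
key_expanded = concat_kv_heads(key, groups=2)
```

Key and value would both go through the same expansion, matching the requirement that both are BShD if one is.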

The underlying fusion follows ideas taken from the 2022 FlashAttention paper by Tri Dao et al.
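A central idea in that line of work is that softmax followed by the value product can be computed block by block, keeping only a running maximum, a running normalizer, and a running weighted sum, so the full attention matrix never has to be stored. The sketch below shows this for a single query row with scalar values. It is a simplified illustration under stated assumptions, not the kernel's implementation: the function names and the `block` parameter are hypothetical, and all scores are assumed finite.

```python
import math


def attention_row_full(scores, values):
    """Reference: softmax(scores) dotted with values, computed in one pass."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def attention_row_tiled(scores, values, block=2):
    """Same result, computed one block of keys at a time."""
    running_max = float("-inf")
    running_sum = 0.0  # softmax normalizer accumulated so far
    acc = 0.0          # unnormalized weighted sum of values so far
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        new_max = max(running_max, max(s_blk))
        # Rescale earlier partial results to the new maximum.
        # On the first block this factor is exp(-inf) == 0.0.
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc *= correction
        for s, v in zip(s_blk, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc += w * v
        running_max = new_max
    return acc / running_sum


scores = [0.5, -1.0, 3.0, 2.0, 0.0]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
tiled = attention_row_tiled(scores, values, block=2)
full = attention_row_full(scores, values)
```

The two functions agree up to floating-point rounding for any block size. That equivalence is what lets a fused kernel process keys and values in tiles.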