Mojo struct
MaskedFlashAttentionGPU
struct MaskedFlashAttentionGPU
Implemented traits
AnyType,
ImplicitlyDestructible
Methods
execute
static def execute[target: StringSlice[StaticConstantOrigin], rank: Int](output: ManagedTensorSlice[Output, static_spec=output.static_spec], q: ManagedTensorSlice[Input, static_spec=q.static_spec], k: ManagedTensorSlice[Input, static_spec=k.static_spec], v: ManagedTensorSlice[Input, static_spec=v.static_spec], mask: ManagedTensorSlice[Input, static_spec=mask.static_spec], scale: Float32, ctx: DeviceContext)
masked_flash_attention_gpu is a hand-fused operator that performs the equivalent of the following sequence of unfused operations.
Step 0: Transpose
query_processed = transpose(query)  # BSHD --> BHSD
key_processed = transpose(key)      # BSHD --> BHDS
value_processed = transpose(value)  # BSHD --> BHSD
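The Step 0 layout changes are axis permutations. The sketch below uses plain nested lists in place of real tensors so that it runs without any library. The helpers `shape` and `permute` and the constants `TO_BHSD` and `TO_BHDS` are hypothetical illustrations, not part of the MAX API. The fused kernel performs these transposes internally and never builds permuted copies.

```python
# Step 0 sketch: axis permutations on plain nested lists (no NumPy needed).
# `shape` and `permute` are hypothetical helpers, not part of the MAX API.

def shape(t):
    """Return the shape of a rectangular nested list."""
    dims = []
    while isinstance(t, list):
        dims.append(len(t))
        t = t[0]
    return tuple(dims)


def permute(t, perm):
    """Permute the axes of a 4-D nested list: output axis k is input axis perm[k]."""
    out_shape = tuple(shape(t)[p] for p in perm)

    def build(prefix):
        if len(prefix) == 4:
            # Map the output index back to the matching input index.
            idx = [0] * 4
            for k, p in enumerate(perm):
                idx[p] = prefix[k]
            return t[idx[0]][idx[1]][idx[2]][idx[3]]
        return [build(prefix + [i]) for i in range(out_shape[len(prefix)])]

    return build([])


# BSHD axes are 0=B, 1=S, 2=H, 3=D.
TO_BHSD = (0, 2, 1, 3)  # query and value: BSHD --> BHSD
TO_BHDS = (0, 2, 3, 1)  # key:             BSHD --> BHDS

B, S, H, D = 1, 3, 2, 4
# Encode each index in the element value so the permutation is easy to check.
x = [[[[1000 * b + 100 * s + 10 * h + d for d in range(D)]
       for h in range(H)] for s in range(S)] for b in range(B)]

q_bhsd = permute(x, TO_BHSD)
k_bhds = permute(x, TO_BHDS)
print(shape(q_bhsd))  # (1, 2, 3, 4)
print(shape(k_bhds))  # (1, 2, 4, 3)
```

Element `x[b][s][h][d]` appears at `q_bhsd[b][h][s][d]` and at `k_bhds[b][h][d][s]`.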
Step 1:
attentionMatrix = query_processed @ key_processed  # BHSD @ BHDS --> BHSS
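Step 1 is a matrix multiply batched over the leading B and H axes. Entry `[b][h][i][j]` is the dot product of query row i with key row j for that batch entry and head. The sketch below is a minimal nested-list version. The helper name `bmm` is hypothetical.

```python
# Step 1 sketch: batched matmul over the leading B and H axes.
# `bmm` is a hypothetical helper written for illustration.

def bmm(a, b):
    """[B][H][M][K] @ [B][H][K][N] --> [B][H][M][N] on nested lists."""
    out = []
    for a_b, b_b in zip(a, b):              # batch axis
        heads = []
        for a_h, b_h in zip(a_b, b_b):      # head axis
            n_cols = len(b_h[0])
            heads.append([
                [sum(row[k] * b_h[k][j] for k in range(len(row))) for j in range(n_cols)]
                for row in a_h
            ])
        out.append(heads)
    return out


# B = H = 1, S = D = 2: query is BHSD, key has already been transposed to BHDS.
query_processed = [[[[1, 2], [3, 4]]]]
key_processed = [[[[5, 6], [7, 8]]]]

attention_matrix = bmm(query_processed, key_processed)
print(attention_matrix)  # [[[[19, 22], [43, 50]]]]
```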
Step 2:
norm = broadcast_to(scale, shape_of(attentionMatrix))
Step 3:
Normalize and apply masking:
attentionMatrixNorm = attentionMatrix * norm
Note: in this reference pattern, attention_mask is HSS and broadcasts across the batch dimension.
attentionMatrixNormMasked = attentionMatrixNorm + attention_mask
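Steps 2 and 3 multiply every score by `scale` and then add the mask, so the mask is additive. The sketch below assumes a common convention that the description above does not specify: a causal mask holding 0.0 for visible positions and -inf for hidden ones. For standard scaled dot-product attention, `scale` is usually 1/sqrt(D), but the operator takes it as an explicit argument. The helper `scale_and_mask` is hypothetical.

```python
import math

# Steps 2-3 sketch: scale the BHSS scores, then add an HSS additive mask that
# broadcasts across the batch axis. The 0 / -inf causal mask below is an
# illustrative assumption; the docs only say the mask is added to the scores.

def scale_and_mask(scores, mask, scale):
    """scores: [B][H][S][S]; mask: [H][S][S], reused for every batch entry b."""
    return [
        [
            [
                [x * scale + m for x, m in zip(score_row, mask_row)]
                for score_row, mask_row in zip(scores[b][h], mask[h])
            ]
            for h in range(len(scores[b]))
        ]
        for b in range(len(scores))
    ]


NEG_INF = -math.inf
scores = [
    [[[2.0, 4.0], [6.0, 8.0]]],   # batch 0, head 0
    [[[1.0, 1.0], [1.0, 1.0]]],   # batch 1, head 0
]
causal_mask = [[[0.0, NEG_INF], [0.0, 0.0]]]  # H = 1, S = 2

masked = scale_and_mask(scores, causal_mask, scale=0.5)
print(masked[0][0])  # [[1.0, -inf], [3.0, 4.0]]
print(masked[1][0])  # [[0.5, -inf], [0.5, 0.5]]
```

The -inf entries become zero probability after the softmax in Step 4.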
Step 4:
Apply softmax and reproject the result:
attentionMatrixSoftMax = softmax(attentionMatrixNormMasked)  # softmax over the last (key) axis
answer = attentionMatrixSoftMax @ value_processed             # BHSS @ BHSD --> BHSD
answer = transpose(answer)                                    # BHSD --> BSHD
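For a single (batch, head) slice, Step 4 is a row-wise softmax followed by a matrix product with the value rows. The sketch below subtracts each row's maximum before exponentiating, which is the usual way to keep `exp` from overflowing. The helpers `softmax_rows` and `matmul` are hypothetical.

```python
import math

# Step 4 sketch for a single (batch, head) slice: softmax over the key axis,
# then reproject onto the value rows.

def softmax_rows(m):
    """Numerically stable softmax along the last axis of a 2-D nested list."""
    out = []
    for row in m:
        mx = max(row)  # subtracting the row max keeps exp() from overflowing
        exps = [math.exp(x - mx) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def matmul(a, b):
    """[M][K] @ [K][N] --> [M][N] on nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


masked_scores = [[1.0, -math.inf], [0.0, 0.0]]  # S x S, already scaled and masked
value = [[1.0, 2.0], [3.0, 4.0]]                # S x D

probs = softmax_rows(masked_scores)
answer = matmul(probs, value)
print(probs)   # [[1.0, 0.0], [0.5, 0.5]]
print(answer)  # [[1.0, 2.0], [2.0, 3.0]]
```

A row whose entries are all -inf has no visible keys, and this sketch would return NaN for it. How the kernel treats that case is not covered by this description. The final transpose only swaps the H and S axes back, as in Step 0.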
Compared with the CPU pattern, the notable differences are:
- The mask is rank 3 with shape BSS (one S x S mask per batch entry, shared across heads) rather than HSS.
- The transposes are performed inside the kernel itself.
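With the GPU mask layout, Steps 0 through 4 can be written as one unfused reference, which is handy for checking results on tiny inputs. The sketch below indexes the BSHD inputs directly instead of transposing them, and it adds `mask[b]` for every head. It assumes queries and keys share the sequence length S and that the mask is additive. The function name `reference_masked_attention` is hypothetical and is not the kernel's implementation.

```python
import math

# Unfused reference for the GPU variant: q, k, v are BSHD and the mask is BSS.
# `reference_masked_attention` is a hypothetical helper for small-scale checks.

def reference_masked_attention(q, k, v, mask, scale):
    """Return softmax(scale * Q K^T + mask) V in BSHD layout.

    Indexing q[b][i][h] and k[b][j][h] directly replaces the explicit
    transposes, and the same mask[b] is added for every head h.
    """
    B, S, H, D = len(q), len(q[0]), len(q[0][0]), len(q[0][0][0])
    out = [[[[0.0] * D for _ in range(H)] for _ in range(S)] for _ in range(B)]
    for b in range(B):
        for h in range(H):
            for i in range(S):
                # Steps 1-3: scaled dot products plus the additive mask row.
                logits = [
                    scale * sum(q[b][i][h][d] * k[b][j][h][d] for d in range(D))
                    + mask[b][i][j]
                    for j in range(S)
                ]
                # Step 4: stable softmax over keys, then weight the value rows.
                mx = max(logits)
                w = [math.exp(x - mx) for x in logits]
                total = sum(w)
                for d in range(D):
                    out[b][i][h][d] = sum(w[j] * v[b][j][h][d] for j in range(S)) / total
    return out


# B = 1, S = 2, H = 2, D = 1; all-zero keys make every unmasked key equally likely.
q = [[[[1.0], [1.0]], [[1.0], [1.0]]]]
k = [[[[0.0], [0.0]], [[0.0], [0.0]]]]
v = [[[[2.0], [10.0]], [[4.0], [20.0]]]]
mask = [[[0.0, -math.inf], [0.0, 0.0]]]  # causal, shared by both heads

out = reference_masked_attention(q, k, v, mask, scale=1.0)
print(out)  # [[[[2.0], [10.0]], [[3.0], [15.0]]]]
```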
This pattern also supports grouped attention. If there are G groups, let h = H / G. In that case key and value may have shape BShD instead of BSHD, and if one of them is BShD the other must be as well. The kernel then behaves as if the following step had run before Step 0:
Step -1:
key = concat(key, ...)      # concat BShD --> BSHD
value = concat(value, ...)  # concat BShD --> BSHD
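The concat arguments above are elided, so the sketch below makes an explicit assumption: they are G copies of the same tensor, joined along the head axis. Under that reading, query head i reads key/value head i % h. Some grouped-query implementations instead repeat each head G times in place, which maps query head i to head i // G. The sketch does not establish which ordering this kernel uses. The helper `concat_heads` is hypothetical.

```python
# Step -1 sketch: expand BShD key/value tensors to BSHD by concatenating
# G copies along the head axis. ASSUMPTION: the elided concat arguments are
# G copies of the same tensor, so query head i reads KV head i % h. Some
# grouped-attention implementations instead use KV head i // G.

def concat_heads(x, G):
    """[B][S][h][D] --> [B][S][G*h][D] by tiling the h heads G times."""
    return [
        [[list(head) for _ in range(G) for head in token] for token in seq]
        for seq in x
    ]


key = [[[[1.0], [2.0]]]]      # B = 1, S = 1, h = 2, D = 1
expanded = concat_heads(key, G=2)
print(expanded)  # [[[[1.0], [2.0], [1.0], [2.0]]]]
```

A fused kernel would normally index the shared head directly rather than materialize the copies.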
The underlying fusion follows ideas taken from the 2022 FlashAttention paper by Tri Dao et al.
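The central FlashAttention idea is an online softmax. Keys and values are processed block by block while a running maximum, a running denominator, and a running weighted sum are kept. Earlier partial sums are rescaled whenever a larger logit appears, so the full S x S probability matrix never has to be stored. The sketch below shows this for a single query row and checks it against the direct computation. It illustrates the general technique from the paper, not this kernel's tiling, block sizes, or GPU code. The function names are hypothetical.

```python
import math

# FlashAttention-style online softmax for ONE query row: keys/values are
# consumed block by block, and running statistics are rescaled whenever a
# larger logit appears, so the full row of probabilities is never stored.
# This illustrates the general technique from the paper, not this kernel's code.

def attention_row_direct(logits, values):
    """Reference: materialize every softmax weight, then take the weighted sum."""
    mx = max(logits)
    w = [math.exp(x - mx) for x in logits]
    total = sum(w)
    return [sum(wj * vj[d] for wj, vj in zip(w, values)) / total
            for d in range(len(values[0]))]


def attention_row_online(logits, values, block):
    """Same result, computed one block of keys at a time."""
    m = -math.inf                 # running max of the logits seen so far
    denom = 0.0                   # running sum of exp(logit - m)
    acc = [0.0] * len(values[0])  # running sum of exp(logit - m) * value
    for start in range(0, len(logits), block):
        blk_logits = logits[start:start + block]
        blk_values = values[start:start + block]
        m_new = max(m, max(blk_logits))
        if m_new == -math.inf:
            continue              # every key so far is masked out
        correction = math.exp(m - m_new)  # rescale earlier sums (0.0 when m is -inf)
        denom *= correction
        acc = [a * correction for a in acc]
        for x, v in zip(blk_logits, blk_values):
            p = math.exp(x - m_new)
            denom += p
            acc = [a + p * vd for a, vd in zip(acc, v)]
        m = m_new
    return [a / denom for a in acc]


logits = [0.5, -1.0, 2.0, 0.0, -math.inf]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [1.0, -1.0], [5.0, 5.0]]
expected = attention_row_direct(logits, values)
for block in (1, 2, 5):
    got = attention_row_online(logits, values, block)
    assert all(math.isclose(g, e, rel_tol=1e-9) for g, e in zip(got, expected))
```

Every block size yields the same output up to floating-point rounding. This agreement is what lets a fused kernel trade the large intermediate matrix for a few running values per query row.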