
Flash attention

An optimization technique for computing attention in transformer models. Standard attention materializes large intermediate tensors, in particular the full matrix of attention scores, whose size grows quadratically with sequence length. This creates high memory overhead and slows execution, because those tensors must be moved repeatedly between the GPU's high-bandwidth memory (HBM) and its much faster but smaller on-chip SRAM.
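To make the cost concrete, here is a minimal pure-Python sketch of the standard approach for a single head and a single sequence. It builds the whole n x n score matrix before applying the softmax. The function names `naive_attention` and `score_matrix_bytes` are hypothetical, and the memory helper assumes fp16 storage (2 bytes per element).

```python
import math

def naive_attention(Q, K, V):
    """Standard attention, softmax(Q K^T / sqrt(d)) V, computed directly.

    Q, K, V are lists of rows (n x d). The full n x n score matrix S, and
    then the n x n probability matrix P, are held in memory at once.
    """
    n, d = len(Q), len(Q[0])
    scale = 1.0 / math.sqrt(d)
    # Materialize every query-key score: n * n values.
    S = [[scale * sum(q[i] * k[i] for i in range(d)) for k in K] for q in Q]
    # Row-wise softmax, subtracting the row max for numerical stability.
    P = []
    for row in S:
        m = max(row)
        e = [math.exp(s - m) for s in row]
        z = sum(e)
        P.append([x / z for x in e])
    # Each output row is a probability-weighted sum of the value rows.
    dv = len(V[0])
    return [[sum(p[j] * V[j][c] for j in range(len(V))) for c in range(dv)]
            for p in P]

def score_matrix_bytes(seq_len, bytes_per_elem=2):
    """Memory for one n x n score matrix (default: fp16, 2 bytes per element)."""
    return seq_len * seq_len * bytes_per_elem
```

For example, `score_matrix_bytes(32768)` returns 2147483648, that is 2 GiB for the scores of one head on one 32,768-token sequence. Doubling the sequence length quadruples that figure.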

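Flash attention improves performance and reduces the memory footprint of attention layers by reordering the computation. It uses tiling to compute attention scores block by block, keeps only small blocks of inputs and intermediate results in the fast on-chip SRAM, and combines the partial results with a running (online) softmax, so the full score matrix is never written to HBM. The result is still exact attention, not an approximation. Because the extra memory now grows linearly rather than quadratically with sequence length, the model can process much longer sequences without running into memory limits.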
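The tiling and online-softmax idea can be sketched in pure Python as follows. This is a simplified, forward-pass-only illustration under stated assumptions (single head, no masking, no backward pass, queries processed one row at a time, whereas real kernels also tile the queries). The function name `flash_attention_sketch` and its `block_size` parameter are hypothetical. For each query it keeps only a running max `m`, a running softmax denominator `l`, and an unnormalized output accumulator `acc`, rescaling them whenever a new block raises the max.

```python
import math

def flash_attention_sketch(Q, K, V, block_size=2):
    """Attention computed over key/value blocks with an online softmax.

    Only block_size scores per query exist at any time; the n x n score
    matrix is never built. On a GPU each block would sit in on-chip SRAM;
    here blocks are just list slices.
    """
    d = len(Q[0])
    dv = len(V[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for q in Q:
        m = -math.inf        # running max of the scores seen so far
        l = 0.0              # running sum of exp(score - m)
        acc = [0.0] * dv     # running sum of exp(score - m) * value row
        for start in range(0, len(K), block_size):
            K_blk = K[start:start + block_size]
            V_blk = V[start:start + block_size]
            # Scores for this block only.
            s = [scale * sum(q[i] * k[i] for i in range(d)) for k in K_blk]
            m_new = max(m, max(s))
            # Rescale earlier statistics to the new max (exp(-inf) == 0.0).
            correction = math.exp(m - m_new)
            l *= correction
            acc = [a * correction for a in acc]
            # Fold this block's contributions into the running sums.
            for s_j, v in zip(s, V_blk):
                p = math.exp(s_j - m_new)
                l += p
                for c in range(dv):
                    acc[c] += p * v[c]
            m = m_new
        # Normalize once at the end to get the softmax-weighted average.
        out.append([a / l for a in acc])
    return out
```

Any `block_size` gives the same output as a direct full-softmax computation, up to floating-point rounding; only the amount of intermediate state held at once changes. Smaller blocks need less fast memory per step, which is the trade-off flash attention exploits.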

By making attention layers faster and more memory-efficient, flash attention makes it practical for LLMs to work with much longer contexts, which helps them understand and generate long, complex text.