In order to understand how FlashAttention and its variants improve the compute efficiency of modern LLM training, we first have to dive into the GPU compute model and its memory hierarchy.

GPU Compute Model and Memory Hierarchy

Figure 1 shows the high-level compute model and memory layout of a GPU. Three types of memory affect GPU computation: CPU memory (data loading, etc.), GPU high-bandwidth memory (HBM, the GPU memory we usually refer to), and GPU caches (SRAM). These memories differ greatly in size and bandwidth (read speed). The idea of FlashAttention is to design an IO-aware fused computation kernel that reduces memory accesses and thereby speeds up training.

Figure 1. GPU memory

Figure 2 shows a more detailed hierarchy of GPU memory on an A100. Notice that the SRAM cache is local to each compute unit (streaming multiprocessor).

Figure 2. GPU memory hierarchy

IO-aware Computation

First, let's take a look at the vanilla attention computation, shown below.

Figure 3. Vanilla attention computation

Essentially, each operation follows the three steps below (made concrete in the sketch right after this list).

  • Read op: move tensors from HBM to SRAM
  • Compute op: perform the compute-intensive work in SRAM
  • Write op: move the results from SRAM back to HBM
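To make these round trips concrete, here is a plain PyTorch version of the vanilla computation. This is only a sketch: PyTorch does not expose the HBM/SRAM traffic explicitly, but each line launches its own kernel, so every [N x N] intermediate is written to and re-read from HBM.

```python
import torch

def vanilla_attention(Q, K, V):
    """Naive attention: every intermediate is a full [N, N] tensor in HBM.
    Shapes assumed: Q, K, V are [N, d]."""
    d = Q.shape[-1]
    # read Q, K from HBM -> compute -> write S ([N, N]) back to HBM
    S = Q @ K.T / d**0.5
    # read S from HBM -> compute -> write P ([N, N]) back to HBM
    P = torch.softmax(S, dim=-1)
    # read P, V from HBM -> compute -> write O ([N, d]) back to HBM
    O = P @ V
    return O
```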

The breakdown of these operations is shown in Figure 4. All of the green ops in the vanilla attention can be eliminated.

Figure 4. Vanilla attention computation breakdown

However, the full [N x N] attention matrix is far too large to fit in the cache. The way to solve this challenge is tiling: we slice the matrices into smaller blocks and compute each Q·Kᵀ product one block at a time, so the output of each block fits in the cache (a sketch of the loop structure follows below). This would work perfectly except that the softmax op cannot be computed block by block out of the box. Luckily, there are already studies dealing with this [1-2]. Before getting to those, let's first revisit the numerically stable softmax computation.
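Here is that blocked loop structure as a minimal sketch. The function name and the block sizes Br and Bc are illustrative, and unlike the real kernel this version still materializes the full S; it only shows how the work is split into SRAM-sized tiles.

```python
import torch

def blocked_scores(Q, K, Br=64, Bc=64):
    """Compute S = Q @ K^T one [Br, Bc] tile at a time.
    Each tile is small enough to live in SRAM; here we just
    emulate the loop structure on full tensors."""
    N, d = Q.shape
    S = torch.empty(N, N)
    for i in range(0, N, Br):          # loop over row blocks of Q
        Qi = Q[i:i + Br]               # load one Q block
        for j in range(0, N, Bc):      # loop over column blocks of K
            Kj = K[j:j + Bc]           # load one K block
            S[i:i + Br, j:j + Bc] = Qi @ Kj.T / d**0.5
    return S
```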

Blockwise Softmax

Underflow in numerical computation can cause precision issues, but overflow is usually more problematic because it typically causes the training job to diverge (some may argue that silent errors are even more detrimental :)). The softmax operation involves exponentials, which can easily overflow without careful handling (e.g., exp(2000)). The standard remedy is to subtract the maximum before exponentiating:

$$ \text{softmax}(x)_i = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}} $$
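A tiny numerical illustration of why the max is subtracted; the values are chosen only to trigger overflow in float32.

```python
import torch

x = torch.tensor([2000.0, 1999.0, 1998.0])

# naive softmax overflows: exp(2000) = inf, giving nan
naive = torch.exp(x) / torch.exp(x).sum()          # tensor([nan, nan, nan])

# stable softmax: subtract max(x) first, so every exponent is <= 0
stable = torch.exp(x - x.max()) / torch.exp(x - x.max()).sum()
print(naive, stable)  # stable matches torch.softmax(x, dim=0)
```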

Similarly, the cross-entropy loss (with a one-hot target $p$ and true class $y$) can be computed as

$$ \begin{aligned} H(p, q) &= -\sum_i p_i\log(q_i) \\ &= -1\cdot\log(q_y) -\sum_{i \neq y} 0\cdot\log(q_i) \\ &= -\log(q_y) \\ &= -\log(\text{softmax}(\hat{y})_y) \\ \end{aligned} $$

When $\max(x)$ is much larger than $x_i$, the numerator $e^{x_i - \max(x)}$ can underflow to $0$, and the subsequent $\log$ computation blows up ($\log 0 = -\infty$). To prevent this, we can do one more step and fold the $\log$ into the softmax: $$ \begin{aligned} \log(\text{softmax}(x)_i) &= \log(\frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}) \\ &= x_i - \max(x) - \log(\sum_j e^{x_j - \max(x)}) \end{aligned} $$
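A short sketch of this rearranged form, often called the log-sum-exp trick; the input values are illustrative and chosen to force the underflow.

```python
import torch

def log_softmax(x):
    """log softmax via the log-sum-exp trick: never takes the log of a
    (possibly underflowed) probability."""
    m = x.max()
    return x - m - torch.log(torch.exp(x - m).sum())

x = torch.tensor([0.0, -2000.0])
# two-step version: softmax underflows to 0 for the second entry,
# so log() returns -inf; the fused form gives the finite value -2000
print(torch.log(torch.softmax(x, dim=0)))  # tensor([0., -inf])
print(log_softmax(x))                      # tensor([0., -2000.])
```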

By simply subtracting the max value, we force every exponent to be at most $0$, so the exponentials stay in $(0, 1]$. In the FlashAttention paper, the softmax is represented as follows:

Figure 5. Softmax

Then blockwise softmax can be computed as follows:

Figure 6. Blockwise Softmax

By keeping a few summary statistics (the running max and the running sum), the softmax op can be decomposed into blocks.
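Here is a minimal sketch of that idea for a 1-D softmax, in the spirit of the online softmax of [2]. The running max m, running sum l, and the block size are illustrative; unlike the real kernel, this version keeps the emitted blocks around just to make the rescaling visible, whereas FlashAttention folds the rescaling into the running output instead.

```python
import torch

def blockwise_softmax(x, block=4):
    """Softmax over a long vector computed one block at a time,
    keeping only a running max m and a running sum l between blocks."""
    m = torch.tensor(float("-inf"))   # running max
    l = torch.tensor(0.0)             # running sum of exp(x - m)
    outputs = []
    for start in range(0, x.numel(), block):
        xb = x[start:start + block]
        m_new = torch.maximum(m, xb.max())
        # rescale the running sum and previously emitted blocks to the new max
        scale = torch.exp(m - m_new)
        l = l * scale + torch.exp(xb - m_new).sum()
        outputs = [o * scale for o in outputs]
        outputs.append(torch.exp(xb - m_new))
        m = m_new
    return torch.cat(outputs) / l

x = torch.randn(10)
assert torch.allclose(blockwise_softmax(x), torch.softmax(x, dim=0))
```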

Recomputation in Backpropagation

With a fused kernel, the computation effectively happens outside the PyTorch computation graph, so we can't rely on autograd for gradient computation in backpropagation. Consequently, we have to define the backward pass ourselves. The solution is simple as well: we register our own backward op for the fused kernel and recompute the needed intermediates from the saved inputs, much like gradient checkpointing.
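Below is a minimal sketch of how such a custom backward can be wired up with torch.autograd.Function. This is not the FlashAttention kernel itself: the forward and backward bodies are plain PyTorch stand-ins that only illustrate saving the inputs (no [N x N] tensor) and recomputing the attention matrix in backward.

```python
import torch

class FusedAttention(torch.autograd.Function):
    """Sketch: forward would call the fused kernel; backward recomputes
    the [N, N] attention matrix from the saved inputs instead of storing
    it, like gradient checkpointing."""

    @staticmethod
    def forward(ctx, Q, K, V):
        d = Q.shape[-1]
        P = torch.softmax(Q @ K.T / d**0.5, dim=-1)  # fused kernel in practice
        O = P @ V
        ctx.save_for_backward(Q, K, V)               # no [N, N] tensor saved
        return O

    @staticmethod
    def backward(ctx, dO):
        Q, K, V = ctx.saved_tensors
        d = Q.shape[-1]
        # recompute attention (block by block in the real kernel; full here)
        S = Q @ K.T / d**0.5
        P = torch.softmax(S, dim=-1)
        dV = P.T @ dO
        dP = dO @ V.T
        dS = P * (dP - (dP * P).sum(dim=-1, keepdim=True))  # softmax backward
        dQ = dS @ K / d**0.5
        dK = dS.T @ Q / d**0.5
        return dQ, dK, dV
```

Calling FusedAttention.apply(Q, K, V) then behaves like a regular differentiable op, but nothing quadratic in N is kept alive between the forward and backward passes.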

References

[1] Self-Attention Does Not Need O(n²) Memory
[2] Online Normalizer Calculation for Softmax
[3] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness