To understand how FlashAttention and its variants improve the compute efficiency of modern LLM training, we first have to dive into the GPU compute model and its memory hierarchy.
GPU Compute Model and Memory Hierarchy
Figure 1 shows the high-level compute model and memory layout of a GPU. Three types of memory affect GPU computation: CPU memory (data loading, etc.), GPU high-bandwidth memory (HBM, the GPU memory we usually refer to), and GPU on-chip caches (SRAM). These memories differ widely in size and bandwidth (read speed). The idea of FlashAttention is to design an IO-aware fused computation kernel that saves memory accesses and thereby speeds up training.
Figure 1. GPU memory
Figure 2 shows a more detailed hierarchy of GPU memory on an A100. Notice that the cache (SRAM) is specific to each compute unit (SM).
Figure 2. GPU memory hierarchy
IO-aware Computation
First, let's take a look at the vanilla attention computation, shown below.
Figure 3. Vanilla attention computation
Essentially, each of these operations follows the three steps below (a minimal sketch follows the list).
- Read op - move the input tensors from HBM to SRAM
- Compute op - perform the compute-intensive work in SRAM
- Write op - move the result tensors back from SRAM to HBM
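For reference, here is a minimal PyTorch sketch of the vanilla attention forward pass (the `naive_attention` name is just for illustration, assuming 2-D Q, K, V of shape [N, d]); every intermediate, including the full [N x N] score matrix, takes a round trip through HBM:

```python
import math
import torch

def naive_attention(Q, K, V):
    # Q, K, V: [N, d] tensors living in HBM.
    # Each op below reads its inputs from HBM and writes its output back to HBM.
    S = Q @ K.T / math.sqrt(Q.shape[-1])  # [N, N] scores written to HBM
    P = torch.softmax(S, dim=-1)          # [N, N] probabilities written to HBM
    O = P @ V                             # [N, d] output written to HBM
    return O
```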
The breakdown of these computations is shown below. As the figure highlights, all of the green ops in the vanilla attention can be saved.
Figure 4. Vanilla attention computation breakdown
However, it is hard to fit the giant [N x N] attention matrix in the cache. The idea to solve this challenge is tiling: we slice the matrices into smaller blocks, compute each Q-K product at the scale of a small block, and keep each block's output in the cache. This would work perfectly, except that the softmax op cannot naively be computed on small blocks. Luckily, there are already studies dealing with this [1-2], which we will get to after revisiting the numerically stable softmax computation. First, the sketch below makes the blocking problem concrete.
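Here is a rough PyTorch sketch of the tiled score computation (the `tiled_scores` name and block size are just for illustration); each tile is small enough to live in SRAM, but the softmax at the end still needs entire rows of S:

```python
import math
import torch

def tiled_scores(Q, K, block=128):
    # Compute S = Q @ K^T / sqrt(d) one [block x block] tile at a time.
    N, d = Q.shape
    S = torch.empty(N, K.shape[0])
    for i in range(0, N, block):
        for j in range(0, K.shape[0], block):
            S[i:i + block, j:j + block] = Q[i:i + block] @ K[j:j + block].T / math.sqrt(d)
    # The catch: softmax normalizes each row of S, and a row spans every tile
    # along the K dimension, so a single tile cannot be normalized on its own.
    return torch.softmax(S, dim=-1)
```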
Blockwise Softmax
Underflow in numerical computation can cause precision issues. Overflow is usually more problematic because it typically makes the training job diverge (some may argue silent errors are even more detrimental :)). The softmax operation involves exponentials which, without careful handling, can easily overflow (e.g., exp(2000)). The standard fix is to subtract the maximum before exponentiating:
$$ \text{softmax}(x)_i = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}} $$
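As a quick illustration in PyTorch (the tensor values are made up just to trigger overflow):

```python
import torch

x = torch.tensor([2000.0, 1999.0, 1000.0])

# Naive softmax: exp(2000) overflows to inf, producing NaNs after normalization.
naive = torch.exp(x) / torch.exp(x).sum()               # tensor([nan, nan, nan])

# Stable softmax: subtract the max first, so every exponent is <= 0.
shifted = x - x.max()
stable = torch.exp(shifted) / torch.exp(shifted).sum()  # ~tensor([0.7311, 0.2689, 0.0000])
```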
Similarly, the cross entropy with a one-hot target $p$ (where $p_y = 1$ for the true class $y$) can be computed as
$$ \begin{aligned} H(p, q) &= -\sum_i p_i\log(q_i) \\ &= -1\cdot\log(q_y) -\sum_{i \neq y} 0\cdot\log(q_i) \\ &= -\log(q_y) \\ &= -\log(\text{softmax}(\hat{y})_y) \\ \end{aligned} $$
When $\max(x)$ is much larger than $x_i$, the numerator $e^{x_i - \max(x)}$ can underflow to $0$, and the subsequent $\log$ computation blows up to $-\infty$. To prevent this, we can do one more step: $$ \begin{aligned} \log(\text{softmax}(x)_i) &= \log\left(\frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}\right) \\ &= x_i - \max(x) - \log\left(\sum_j e^{x_j - \max(x)}\right) \end{aligned} $$
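And a quick illustration of why the extra step matters (again with made-up values):

```python
import torch

x = torch.tensor([0.0, -1.0, -2000.0])

# Taking log of the stable softmax directly: the last probability underflows to 0,
# so its log is -inf and the loss for that class would diverge.
p = torch.exp(x - x.max()) / torch.exp(x - x.max()).sum()
bad_log = torch.log(p)                                   # tensor([-0.3133, -1.3133, -inf])

# Log-sum-exp form: the log never sees an underflowed 0.
log_softmax = x - x.max() - torch.log(torch.exp(x - x.max()).sum())
# tensor([-0.3133, -1.3133, -2000.3133])
```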
By simply subtracting the max value, we limit the exponentials to the range [0, 1]. In the FlashAttention paper [3], the softmax is represented as follows:
Figure 5. Softmax
The blockwise softmax can then be computed as follows:
Figure 6. Blockwise Softmax
By carrying a few summary statistics (the running max and the running sum of exponentials), the softmax op can be decomposed into blocks.
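Here is a minimal 1-D sketch of this online softmax in PyTorch (the `blockwise_softmax` name and block size are just for illustration, not the fused kernel); each block only needs the running max `m` and the rescaled running sum `l` carried over from earlier blocks:

```python
import torch

def blockwise_softmax(x, block_size=4):
    # Online softmax over a 1-D tensor, processing one block at a time.
    m = torch.tensor(float("-inf"))  # running max
    l = torch.tensor(0.0)            # running sum of exponentials, scaled by exp(-m)
    for start in range(0, x.numel(), block_size):
        block = x[start:start + block_size]
        m_new = torch.maximum(m, block.max())
        # Rescale the old sum to the new max, then add this block's contribution.
        l = l * torch.exp(m - m_new) + torch.exp(block - m_new).sum()
        m = m_new
    # A second pass turns the final statistics into normalized probabilities.
    return torch.exp(x - m) / l
```

The result matches `torch.softmax(x, dim=0)`; FlashAttention applies the same rescaling per query row, but to whole tiles of the attention matrix at once.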
Recomputation in Backpropagation
With the fused kernel, we effectively do the computation outside the PyTorch computation graph, so we can't rely on autograd for gradient computation in backpropagation. Consequently, we have to define the backward pass ourselves. The way to solve this is simple as well: we define our own backward op for the fused kernel and, as in gradient checkpointing, recompute the attention matrix from the saved inputs during the backward pass instead of storing it.
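A minimal PyTorch sketch of that pattern, assuming a hypothetical `RecomputedAttention` wrapper (this illustrates plain recomputation, not the actual FlashAttention backward kernel):

```python
import math
import torch

class RecomputedAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, Q, K, V):
        # Forward pass: in a real fused kernel, S and P never hit HBM.
        S = Q @ K.transpose(-1, -2) / math.sqrt(Q.shape[-1])
        O = torch.softmax(S, dim=-1) @ V
        ctx.save_for_backward(Q, K, V)  # save only the inputs, not the [N x N] matrix
        return O

    @staticmethod
    def backward(ctx, dO):
        Q, K, V = ctx.saved_tensors
        with torch.enable_grad():
            Q_, K_, V_ = (t.detach().requires_grad_() for t in (Q, K, V))
            # Recompute the forward pass to rebuild the graph, then backprop through it.
            S = Q_ @ K_.transpose(-1, -2) / math.sqrt(Q_.shape[-1])
            O = torch.softmax(S, dim=-1) @ V_
            dQ, dK, dV = torch.autograd.grad(O, (Q_, K_, V_), dO)
        return dQ, dK, dV
```

Usage is `O = RecomputedAttention.apply(Q, K, V)`. The actual kernel also saves the per-row softmax statistics (max and sum) from the forward pass so the recomputation in the backward pass stays cheap and consistent.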
References
[1] Self-Attention Does Not Need O(n²) Memory
[2] Online Normalizer Calculation for Softmax
[3] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
...