Flash Attention 学习笔记
Community Article
Published
August 26, 2024
概览
- 论文 Fast and Memory-Efficient Exact Attention with IO-Awareness
- github Dao-AILab/flash-attention
- 优化效果
- 训练速度提升了 2 到 4 倍;
- 训练时显存占用随序列长度平方增涨优化成线性增涨。
- 优化思路
- fusion 融合计算,节省了多个操作之间存取 HBM 的时间。
- 融合计算不保存中间结果,但后向传播计算梯度需要用中间结果,怎么办?重计算
- 相关知识点
- Attention 的标准计算过程;
- Pytorch 中 Attention 计算过程的实现;
- 制约训练速度的主要瓶颈,Compute-Bound、Memory-Bound;Attention 计算瓶颈属于 Memory-Bound ;
- 显存内的缓存分级;芯片内(SRAM)、芯片外(HBM)、CPU(DRAM),读取速度依次递减、显存大小依次递增。
- Flash Attention 实现关键点
- 通过分块计算,融合多个操作,减少中间结果存取。其中,Softmax 的分块计算较复杂;
- 反向传播时,重新计算中间结果。
详细展开
... 待续