Skip to main content

分布式训练学习记录

分布式训练学习记录

📄️ Flash_attention

​ 首先明确一点,FlashAttention的提出,是为了解决在计算attention矩阵时,计算较慢的问题,因为一般会把整个大矩阵加载到HBM(High Bandwidth Memory)—高带宽内存;而现在希望将其加载到SRAM(Static Random-Access Memory)—静态随机访问存储器,如下图所示,展现了其计算速度。而由于SRAM存储容量较小,但计算Attention时的显存较高,一般需要的显存为O(NN),Q,K,V∈R^N✖d^,因此,提出了FlashAttention,将Q,K,V矩阵分块*,逐块加载到GPU的SRAM中计算,并且提出了在backward过程中,利用forward过程计算得到的softmax的归一化因子,在SRAM中快速重计算注意力矩阵,而不是一致保留forward过程的中间结果,而导致显存被持续较大占用,同时将注意力操作融合到一个GPU kernel中,避免从HBM多次读取和写入;最后进一步将其扩展为稀疏版,仅计算非零块的注意力矩阵,进一步降低内存访问次数和计算复杂度。