24 Star 149 Fork 127

Ascend / AscendSpeed

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
flash-attention.md 1.41 KB
一键复制 编辑 原始数据 按行查看 历史

flash attention 适配

问题分析

由于self-attention的时间计算复杂度和内存复杂度与序列长度成二次方关系,因此transformer在长序列上的处理时间、内存开销较大。近似的注意力方法可以优化这一问题,但会降低模型质量。

解决方案

加速注意力的关键在于优化IO访存,即降低HBM的读/写次数。

解决思路:

Flash Attention 是一种优化IO访存开销的精确注意力方法,原理如下图所示[1],通过Tiling切片、重计算、Kernel Fusion等方式来减少GPU高带宽内存(HBM)和GPU片上SRAM之间的内存读/写次数。

a. Tiling切片:利用更高速的SRAM代替HBM,但SRAM的内存容量较少,无法一次性完成所有数据的完整注意力计算,因此需要进行分块计算。

b. 重计算:放弃中间结果写回,需要使用时重新计算,用计算换访存。

c. Kernel Fusion:将多个操作融合为一个操作,基于Tiling利用一个kernel完成整个计算。

原文链接

使用场景

本方法适用于self-attention相关模型,尤其适用于长序列输入场景。

使用方法

设置--use-flash-attn即可调用该算法。

使用效果

在模型训练时间、模型质量等方面可以提升性能。

1
https://gitee.com/ascend/AscendSpeed.git
git@gitee.com:ascend/AscendSpeed.git
ascend
AscendSpeed
AscendSpeed
master

搜索帮助