Attention_Head(MHA、MQA、GQA、MLA)

多头注意力机制包含MHA、GQA、MQA、MLA以上四种。
1.MHA

最早提出的是MHA,即Multi-Head Attention,每一个Query头对应一组K-V对,在推理过程中,计算Attention时,每次都要计算历史的QKV矩阵,速度较慢,为了加速,提出了KV Cache的策略,但是KV Cache会随着推理的token增加,逐渐增大,占用显存越来越大,因此就提出了其它的方法。

注意Q,K,V的维度此时是dh,而后续的MQA,GQA并不会改变Q,K,V的维度,只是通过共享操作连减少计算量。
2.MQA

在此基础上,既然KV cache如此占显存,那么久减少KV对的使用,即使用一个共享版本,Multi-Query Attention。可以看到,此时每个head的Query都共享K和V矩阵,则KV cache的显存占用直接降低到了1/n,然后这样做容易导致模型的性能下降,严重的还会导致模型的稳定性,因此又选出了折中的方法。
常用模型包括PALM、StarCoder、Gemini等。
3.GQA

GQA是在使用group_nums对Query进行分组,g=1就是MQA,g=n就是MHA。常用模型包括LLAMA2-70B,LLAMA3全系列,DeepSeek-V1,Qwen3-MOE系列,ChatGLM3等。
注意,这里是一组K,V对,对应多个Q
4.MLA

如图是MLA的一个总览图,先在此做一个总结:
1.MLA整个过程包含MHA和MQA,MQA体现在对于position embedding的处理,所有的Q共用同一个RoPE的K。因此图中的会被复制或者说广播和其它进行拼接,后续都是一个Q对应一组KV,也就是MHA。
2.MLA整个过程缓存的是和:
其实是对输入x或者上一层输入的隐藏状态做一个降维的操作,即乘以一个降维矩阵,后续要获取、仅需要通过对这个降维后的矩阵做一个升维的操作即可,而这个过程其实是对原来的,做了降秩的操作。即原来计算量是,现在是。
则是对输入x或者上一层输入的隐藏状态做一个降维的操作,并接着做RoPE,文中降低到了,作者认为仅用这么高的维度即可表现位置信息。
3.MLA仅在推理的Decoding阶段使用
4.从上述结论也可以看出,所有的input_embedding和position_embedding是分为两份矩阵表示,并且每个头分别拼接这两份向量,组成最终的Q,K也同样,组合后,再做attention。
5.MLA的操作可以总结为:
1.矩阵吸收(计算图优化)——利用矩阵乘法的链路最优解,将self_attention以及输出矩阵O的整个计算结合,找出最优计算次数最少的矩阵乘法方式。
2.低秩投影,引入降秩的概念,用两个d*r的矩阵将原来的d*d的矩阵给降秩,从而大大降低计算量。
3.如果仅仅引入计算图优化,相比于KV Cache,显存的确降低了, 但是也增加了一些计算量,即时间换空间,所以引入了降秩的概念,进一步降低显存,并且计算量的增加也会减少很多。1,2的引入,使得MLA可以保存比KV Cache占用显存更小的和,从而解决KV Cache的显存瓶颈问题。
1.MLA概念分析

元素推导

元素计算流程
整个Decoding计算流程如下:
在推理每一个token时,仍然只需要输入计算出的结果的最后一个隐藏层状态,并且对于Q来说,首先与相乘,再乘一个相同大小的矩阵,这个操作实际就是降秩,减少计算量,最后得到的Q的大小仍然没变,与传统方法不同的是,这里会将context与position信息解耦,即不再在计算出的q矩阵上去做RoPE,而是利用同样的降秩操作,最后得到一个hidden_size较小的矩阵,并按照head_nums做分割,分配每个head,然后分别对每个小矩阵做RoPE,最终每个attention头都得到了带有position信息的矩阵,然后将context矩阵和position矩阵进行拼接,从而得到q。
对于k和v,同样会用降秩的思想,即分别乘以一个和从而实现减少计算量,而对于K来说,同样会将context与position信息解耦,但是注意这里计算时,即K的position信息的矩阵,其维度是的维度 ,因为这里的其实是所有attention头共享的,即MQA的思想。最后计算出,同样将context矩阵和position矩阵进行拼接,从而得到k。
通过这样的计算过程可以看到,此时我们只需要保存和这两部分数据,即可在每次forward过程迅速恢复K,V的信息,并且相比于直接保存K,V,整个显存也大大降低了,原来是,而现在变成了,其中与都非常小。
但仅仅如此,会发现我们仍然要根据把K,V这两个矩阵给重新计算出来,仍然存在较大的计算量,那如何优化这一点呢,这就是上面提到的,灰色箭头,实际计算并没有按照灰色箭头的方向,逐渐把K,V矩阵算出来,而是把