采用不同的head數(shù)的參數(shù)量不變:
因?yàn)樵趇n-projection以及out-projection時(shí)是不分組的舒憾,和head數(shù)無關(guān)攒钳。
在計(jì)算attention時(shí),是分組的克胳,將embed_dim劃分成了num_heads組,每組分別計(jì)算attention后再拼接圈匆。
假設(shè)采用self attention漠另,即q==k==v,其shape都是(L,1,E)跃赚,那么MHA的計(jì)算量為:
- in-projection
L*E*E - attention
令E_2 = E // num_heads笆搓,
則劃分為了num_heads組,每組計(jì)算量為L*L*E2,總計(jì)算量為L*L*E砚作。 - out-projection
L*E*E
由此可見窘奏,計(jì)算量也與num_heads無關(guān)。
參考:https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L4953