關(guān)鍵詞:位置編碼
,RoPE
,Transformer
迄埃,大語(yǔ)言模型
前言
旋轉(zhuǎn)位置編碼RoPE(Rotary Position Embedding)是一種Transformer模型中的位置編碼策略蟀架,它廣泛應(yīng)用于LLama,ChatGLM等大模型栏妖,本篇先介紹RoPE的實(shí)現(xiàn)步驟和源碼乱豆,再深入講解RoPE涉及到的數(shù)學(xué)原理,力求做到從易到難吊趾,學(xué)習(xí)曲線平滑宛裕。
內(nèi)容摘要
- 位置編碼知識(shí)準(zhǔn)備
- 旋轉(zhuǎn)位置編碼的本質(zhì)和計(jì)算流程
- 旋轉(zhuǎn)位置編碼如何表達(dá)相對(duì)位置信息
- 旋轉(zhuǎn)位置編碼的源碼分析
- 旋轉(zhuǎn)位置編碼的推導(dǎo)
位置編碼知識(shí)準(zhǔn)備
由于Transformer的Self Attention具有排列不變性,因此需要通過引入位置編碼來讓模型感知到輸入序列中每個(gè)單詞的位置信息论泛,位置編碼分為絕對(duì)位置編碼和相對(duì)位置編碼揩尸。
絕對(duì)位置編碼根據(jù)單個(gè)單詞
的絕對(duì)位置來定義位置編碼,每個(gè)位置都會(huì)分配一個(gè)位置編碼屁奏,將位置編碼的表征和單詞本身的表征進(jìn)行融合岩榆,再輸入給Self Attention,相當(dāng)于在輸入層就把位置信息給彌補(bǔ)上去。絕對(duì)位置編碼從實(shí)現(xiàn)方式上又分為固定式和可學(xué)習(xí)式勇边,固定式形如原生的Transformer所采用的三角sin-cos位置編碼犹撒,所謂固定指的是根據(jù)一個(gè)無(wú)參的固定公式就可以推演出位置編碼,而可學(xué)習(xí)式?jīng)]有固定的位置編碼公式粒褒,通過初始化位置向量讓模型根據(jù)上下文數(shù)據(jù)自適應(yīng)地學(xué)習(xí)出來识颊,Bert和GPT采用的可學(xué)習(xí)式。
相對(duì)位置編碼對(duì)兩個(gè)單詞之間
的相對(duì)位置進(jìn)行建模奕坟,并且將相對(duì)位置信息加入到Self Attention模型結(jié)構(gòu)中祥款,形如Transformer-XL,DeBERTa等采用的就是相對(duì)位置編碼月杉。Self Attention的本質(zhì)是兩個(gè)單詞信息的內(nèi)積操作刃跛,相對(duì)位置編碼的思想是對(duì)內(nèi)積的計(jì)算方式進(jìn)行改進(jìn),在內(nèi)積中注入兩個(gè)單詞的相對(duì)位置因素苛萎。
旋轉(zhuǎn)位置編碼的本質(zhì)和計(jì)算流程
旋轉(zhuǎn)位置編碼RoPE是一種固定式
的絕對(duì)位置編碼
策略桨昙,但是它的絕對(duì)位置編碼配合Transformer的Attention內(nèi)積注意力機(jī)制能達(dá)到相對(duì)位置編碼
的效果。RoPE的本質(zhì)是對(duì)兩個(gè)token形成的Query和Key向量做一個(gè)變換
首懈,使得變換后的Query和Key帶有位置信息绊率,進(jìn)一步使得Attention的內(nèi)積操作不需要做任何更改
就能自動(dòng)感知到相對(duì)位置信息。換句話說究履,RoPR的出發(fā)點(diǎn)和策略用的相對(duì)位置編碼思想滤否,但是實(shí)現(xiàn)方式的確用的是絕對(duì)位置編碼。
固定式表明RoPE沒有額外需要模型自適應(yīng)學(xué)習(xí)的參數(shù)最仑,因此RoPE是一種高效的編碼方式藐俺。絕對(duì)位置編碼表明RoPE給文本的每個(gè)位置單詞都分配了一個(gè)位置表征,和三角sin-cos位置編碼一樣泥彤,RoPE通過token在句子中的位置欲芹,token embedding中每個(gè)元素的位置,這兩個(gè)要素一起確定位置編碼的表達(dá)吟吝,先給出RoPE的公式如下
RoPE有一定數(shù)學(xué)推導(dǎo)環(huán)節(jié)菱父,但是最終的公式并不復(fù)雜,因此本篇先從RoPE公式入手介紹RoPE在做什么剑逃,該公式是將一個(gè)原始的token向量改造為一個(gè)注入位置信息之后的新向量的過程浙宜。
其中第一項(xiàng)代表某個(gè)位置為m的token的原始Query向量,0~d-1代表向量每個(gè)位置的元素蛹磺,d代表向量的維度粟瞬,第二項(xiàng)為一個(gè)同樣長(zhǎng)度是d的帶有cos三角函數(shù)的向量,它和Query向量逐位相乘萤捆,第三項(xiàng)由原始Query變換而來裙品,第四項(xiàng)和第二項(xiàng)類似區(qū)別是將cos替換為sin俗批。
該公式的目的是將原始Query向量改造成一個(gè)帶有位置信息的新向量,位置信息由參數(shù)m和θ進(jìn)行表征市怎,其中m為token在句子中的位置岁忘,θ的下標(biāo)和向量中各元素的位置直接相關(guān),公式如下
因此只要給到某個(gè)token的輸入Query向量焰轻,知道token在上下文窗口下處于第幾位臭觉,就可以將它的Query向量通過RoPE的公式改造為一個(gè)新的向量形式,新形成的向量和原向量維度完全一致辱志。以“我愛你”這句話中的第二個(gè)詞“愛”為例,設(shè)詞向量的維度d=4狞膘,詞向量表征為[0.2, 0.1, -0.3, 0.7]揩懒,則經(jīng)過RoPE變化的計(jì)算示意圖如下
公式中的第三項(xiàng)由原始向量變換而來,對(duì)于原始輸入向量挽封,將前后兩個(gè)元素位置構(gòu)成一對(duì)
已球,交換兩者的位置,并且對(duì)于偶數(shù)位取了相反數(shù)辅愿,因此每個(gè)元素位的注入位置信息的過程智亮,可以看成是該元素和它相鄰的元素,分別經(jīng)過sin点待,cos三角函數(shù)加權(quán)求和的結(jié)果阔蛉,比如q0的RoPE結(jié)果是q0和q1這一對(duì)元素經(jīng)過三角函數(shù)變換的結(jié)果。在下文的源碼分析中癞埠,我們會(huì)介紹此處的相鄰條件并不是必須的
状原,而是任意不重復(fù)的一對(duì)都滿足這個(gè)變換性質(zhì)
。
在Transformer原生的三角sin-cos位置編碼中苗踪,采用相加的形式將位置編碼融入到詞向量中颠区,而在RoPE中采用的是類似哈達(dá)馬積的乘積形式,讀者可以將以上RoPE公式做的事情類比于Transformer中原始向量表征和sin-cos位置編碼相加的過程通铲。
旋轉(zhuǎn)位置編碼如何表達(dá)相對(duì)位置信息
在之前介紹的sin-cos位置編碼中Transformer系列:快速通俗理解Transformer的位置編碼毕莱,我們知道sin-cos位置編碼因?yàn)槿呛瘮?shù)的性質(zhì),使得它可以表達(dá)相對(duì)位置信息颅夺,具體而言是:給定距離朋截,任意位置的位置編碼都可以表達(dá)為一個(gè)已知位置的位置編碼的關(guān)于距離的線性組合,而RoPE的位置編碼也是同樣的思路碗啄,采用絕對(duì)位置編碼實(shí)現(xiàn)相對(duì)距離的表達(dá)质和,區(qū)別如下
- 實(shí)現(xiàn)相對(duì)位置能力的途徑不同:sin-cos位置編碼由于三角函數(shù)的性質(zhì),導(dǎo)致它本身就具備表達(dá)相對(duì)距離的能力稚字,而RoPE位置編碼本身不能表達(dá)相對(duì)距離饲宿,需要結(jié)合Attention的內(nèi)積才能激發(fā)相對(duì)距離的表達(dá)能力
- 和原輸入的融合計(jì)算方式不同:sin-cos位置編碼直接和原始輸入相加厦酬,RoPE位置編碼采用類似哈達(dá)馬積相乘的形式
在知識(shí)準(zhǔn)備模塊我們介紹的相對(duì)位置編碼,其主要的思想是原始輸入不變瘫想,將相對(duì)位置信息注入Attention模塊仗阅,采用對(duì)Attention的網(wǎng)絡(luò)結(jié)構(gòu)進(jìn)行修改方式,將位置表征因素也額外的加入Attention計(jì)算国夜,使得Attention模塊能夠把輸入層丟失的位置信息彌補(bǔ)回來减噪。
RoPE參考相對(duì)位置編碼的思想,它也是在Attention模塊讓模型感知到相對(duì)位置车吹,但是它是不改變Attention的結(jié)構(gòu)
筹裕,反而像絕對(duì)位置編碼一樣在輸入層做文章,對(duì)輸入向量做改造窄驹,改造后Attention模塊能夠重新感知到相對(duì)位置朝卒,同樣能把位置信息彌補(bǔ)回來,因此RoPE可是說是使用絕對(duì)位置編碼的方式實(shí)現(xiàn)了相對(duì)位置編碼乐埠,是兩者的融合
抗斤。
至于為什么RoPE可以通過Attention來激發(fā)相對(duì)位置信息,原因是帶有RoPE位置編碼兩個(gè)token丈咐,它們形成的Quey向量和Key向量進(jìn)入Self Attention層之后瑞眼,Attention內(nèi)積的結(jié)果可以恒等轉(zhuǎn)化一個(gè)函數(shù),該函數(shù)只和Quey向量棵逊,Key向量伤疙,以及兩個(gè)token位置之差有關(guān)
,細(xì)節(jié)推導(dǎo)將在下文的進(jìn)行介紹歹河,讀者先對(duì)這個(gè)結(jié)論有個(gè)初步印象掩浙。
旋轉(zhuǎn)位置編碼的源碼分析
在前文已經(jīng)通過公式和一個(gè)具體的例子說明了RoPE的計(jì)算方式,下面結(jié)合HuggingFace的LLaMA大模型實(shí)現(xiàn)類LlamaForCausalLM
中RoPE的源碼再鞏固一下秸歧。先給到源碼實(shí)現(xiàn)的步驟厨姚,分為三步
-
初始化cos向量和sin向量
:根據(jù)給定的上下文窗口大小作為m,多頭下每個(gè)頭的向量的維度大小作為d键菱,生成cos向量和sin向量谬墙,也就是RoPE公式中的第二項(xiàng)和第四項(xiàng)。在LLaMA2中上下文窗口為m=4096经备,每個(gè)頭下的向量維度為d=128拭抬。
-
-
截取對(duì)應(yīng)長(zhǎng)度的cos向量和sin向量
:根據(jù)輸入Query的實(shí)際長(zhǎng)度,截取步驟一中生成的cos向量和sin向量侵蒙,例如上下文窗口為4096造虎,但是實(shí)際輸入句子長(zhǎng)度僅為10,則截取出前10個(gè)位置的cos向量和sin向量纷闺。
-
- 3.
使用cos向量和sin向量改造Query和Key
:根據(jù)步驟二產(chǎn)出的cos向量和sin向量,套用RoPE的公式,對(duì)原始Query和Key分別計(jì)算出注入位置信息之后的Query和Key六剥。
我們順著這三個(gè)步驟查看LlamaForCausalLM中RoPE的實(shí)現(xiàn),RoPE在Attention操作類LlamaAttention
中實(shí)現(xiàn)
class LlamaAttention(nn.Module):
def __init__(self, config: LlamaConfig):
...
# 步驟一:初始化
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
def forward(...):
...
# 步驟二:截取長(zhǎng)度
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
# 步驟三:改造Query婚夫,Key
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
...
最關(guān)鍵的三行代碼分別對(duì)應(yīng)步驟一二三,在LlamaAttention的初始化模塊通過LlamaRotaryEmbedding
子模塊實(shí)現(xiàn)對(duì)RoPE的初始化署鸡,具體為對(duì)公式中的第二項(xiàng)cos向量和第四項(xiàng)sin向量進(jìn)行初始化案糙。
class LlamaRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
# TODO dim=128, max_position_embeddings=4096, 遠(yuǎn)程衰減
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)
# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
# 4096
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
# [4096, 64] => [4096, 128]
emb = torch.cat((freqs, freqs), dim=-1)
# TODO [1, 1, 4096, 128]
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
由于第二項(xiàng)和第四項(xiàng)僅僅是三角函數(shù)不同,三角函數(shù)的右側(cè)參數(shù)是相同的靴庆,都是mθ时捌,因此只需要將所有的mθ生成好,再對(duì)結(jié)果分別取cos和sin即可炉抒。在實(shí)現(xiàn)上作者通過m向量和θ向量的笛卡爾積相乘構(gòu)造出來了mθ組合矩陣匣椰,核心代碼為以下5行,freqs即為mθ的組合結(jié)果
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
以m=4096端礼,θ=128為例,可以通過m和θ的羅列將這個(gè)過程展現(xiàn)出來入录,每個(gè)格子中的結(jié)果為m和θ相乘的結(jié)果
θ只生成了64種情況蛤奥,作者將兩個(gè)freqs在θ拼接,形成了最終的128種情況僚稿,代碼備注中作者說這個(gè)地方和論文的公式不一樣凡桥,但是最終的效果是相同的,不一樣體現(xiàn)在θ下標(biāo)的排布順序和論文公式不一樣
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
因此最終的mθ組合為一個(gè)[4096蚀同,128]的二維矩陣缅刽,模擬如下
緊接著作者分別用cos和sin生成了兩個(gè)結(jié)果向量,并且將它們從二維矩陣變成了四維蠢络,原因是在多頭注意力中衰猛,Query和Key都是四維的形式存在,分別是[batch_size, num_heads, seq_len, head_dim]
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
初始化完畢之后刹孔,在LlamaRotaryEmbedding的forward階段根據(jù)seq_len完成截取操作啡省,對(duì)第三維就是上下文窗口m這個(gè)維度進(jìn)行截取
def forward(self, x, seq_len=None):
...
return (
# TODO [1, 1, seq_len, emb_size=128]
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
其中seq_len為輸入文本的實(shí)際長(zhǎng)度,在調(diào)用的時(shí)候它等于Key向量的實(shí)際長(zhǎng)度髓霞,如果每次輸入的是一部分token卦睹,有前文past_key_value狀態(tài),則文本長(zhǎng)度會(huì)和之前進(jìn)行拼接相加方库,最終得到的cos结序,sin就是截取之后公式中的第二項(xiàng)和第四項(xiàng)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
# [batch_size, num_headsm, kv_seq_len, head_dim] => kv_seq_len
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
進(jìn)入步驟三,將原始的Query纵潦,Key向量徐鹤,cos垃环,sin輸入到apply_rotary_pos_emb中,輸出的query_states, key_states就是注入位置信息之后的Query凳干,Key向量結(jié)果
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
在apply_rotary_pos_emb中出現(xiàn)了RoPE公式晴裹,第一項(xiàng)為Query,第二項(xiàng)為cos向量救赐,第三項(xiàng)通過rotate_half方法對(duì)Query進(jìn)行變換涧团,第四項(xiàng)為sin向量,通過逐位相乘再相加的形式得到結(jié)果经磅,分別對(duì)Query和Key用同樣的方式進(jìn)行改造
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
...
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
進(jìn)一步看rotate_half是否和論文公式中給定的變換一致泌绣,答案是否定的,而在前文中對(duì)于cos和sin向量的實(shí)現(xiàn)和論文也不一致预厌,這兩處代碼的不一致恰好使得最終的效果和論文一致
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
# TODO 前64個(gè)embedding位置 x=[batch_size, num_heads, seq_len, emb_size] => [batch_size, num_heads, seq_len, emb_size/2]
x1 = x[..., : x.shape[-1] // 2]
# TODO 后64個(gè)embedding位置 x=[batch_size, num_heads, seq_len, emb_size] => [batch_size, num_heads, seq_len, emb_size/2]
x2 = x[..., x.shape[-1] // 2 :]
# TODO 后64embedding位置取負(fù)號(hào)阿迈,和前64embedding位置拼接
return torch.cat((-x2, x1), dim=-1)
HuggingFace的代碼邏輯它實(shí)現(xiàn)的計(jì)算公式實(shí)際為
該公式和RoPE論文公式在第二,三轧叽,四項(xiàng)上都有些許差異苗沧,具體為元素位置排列上的差異,在原RoPE公式中q0的結(jié)果是q0和q1這一對(duì)元素經(jīng)過三角函數(shù)變換而成的炭晒,但是在實(shí)際公式中q0是由q0和q64這一對(duì)形成的待逞,只需要把q1想像成q64則兩個(gè)公式完全等價(jià),那q1和q64互換對(duì)最終的結(jié)果影響嗎网严?答案是沒有影響识樱,RoPE對(duì)原始向量的改造本質(zhì)上是以一對(duì)元素為單位經(jīng)過旋轉(zhuǎn)矩陣運(yùn)算,將所有對(duì)的結(jié)果進(jìn)行拼接的過程震束,而到底是選擇連續(xù)的元素作為一對(duì)怜庸,還是其他的挑選方式都是可以的,只要是embedding維度為偶數(shù)垢村,且挑選的策略為不重復(fù)的一對(duì)割疾,最終Attention的內(nèi)積結(jié)果都能感知到相對(duì)位置信息,因?yàn)锳ttention滿足內(nèi)積線性疊加性肝断,至于誰(shuí)和誰(shuí)一組進(jìn)行疊加并不重要
杈曲。
在改造完Query和Key之后,將他們灌入注意力網(wǎng)絡(luò)胸懈,計(jì)算注意力權(quán)重再攜帶Value信息担扑,代碼如下
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
注意此處的注意力并沒有做任何的結(jié)構(gòu)調(diào)整
,和傳統(tǒng)的Transformer注意力的結(jié)構(gòu)一模一樣趣钱,RoPE的相對(duì)位置改造對(duì)天然適配下游注意力網(wǎng)絡(luò)涌献,另外Value信息沒有參加RoPE改造
,RoPE只對(duì)內(nèi)積過程中的Query和Key做改造首有。
旋轉(zhuǎn)位置編碼的推導(dǎo)
直接使用RoPE的結(jié)論在網(wǎng)絡(luò)結(jié)構(gòu)中使用起來不復(fù)雜燕垃,RoPE怎么來的需要經(jīng)過一系列公式推導(dǎo)枢劝,其中涉及復(fù)數(shù)的概念,包括復(fù)數(shù)的坐標(biāo)表示和三角表示卜壕,復(fù)數(shù)相乘運(yùn)算您旁,共軛復(fù)數(shù),歐拉公式和旋轉(zhuǎn)矩陣轴捎。本篇的講解會(huì)直接引用RoPE的作者博客Transformer升級(jí)之路:2鹤盒、博采眾長(zhǎng)的旋轉(zhuǎn)式位置編碼,對(duì)于作者在原文中省略的部分細(xì)節(jié)會(huì)做一定的補(bǔ)充侦副。
有了前文的鋪墊侦锯,作者的出發(fā)點(diǎn)是想通過一種絕對(duì)位置編碼的方式,讓Attention能夠自動(dòng)感知到相對(duì)位置秦驯,而不需要對(duì)Attention的結(jié)構(gòu)進(jìn)行改造尺碰。由于Attention是對(duì)兩個(gè)位置的token的向量進(jìn)行運(yùn)算,因此只需要在Attention之前译隘,對(duì)Query和Key向量進(jìn)行絕對(duì)位置編碼改造即可亲桥,跟Value沒有關(guān)系,我們?cè)O(shè)一個(gè)m位置的token1固耘,它的Query為q两曼,n位置的token2,它的Key為k,改造函數(shù)為f玻驻,則經(jīng)過注入位置信息改造之后的新向量為
這樣改造的目的是使得Attention內(nèi)積能夠自動(dòng)感知到相對(duì)位置信息,即內(nèi)積可以恒等轉(zhuǎn)化為一個(gè)函數(shù)偿枕,這個(gè)函數(shù)只和原始的Query璧瞬,Key,以及兩個(gè)token之間的距離m-n相關(guān)渐夸,令g為這個(gè)恒等變換函數(shù)嗤锉,則有以下公式
下面就是要找到一個(gè)改造函數(shù)f,使得以上這個(gè)恒等變換g成立墓塌。
作者首先從最簡(jiǎn)單的二維角度考慮瘟忱,假設(shè)q和k的embedding維度都是2維,將變換后的q苫幢,k用復(fù)數(shù)進(jìn)行表示访诱,其中第一維為復(fù)數(shù)的實(shí)部,第二維為復(fù)數(shù)的虛部韩肝,以一個(gè)[-2.1, 3.2]的二維向量為例触菜,復(fù)數(shù)形式表示如下
則兩者的內(nèi)積等于q和k的共軛復(fù)數(shù)相乘的實(shí)部,公式如下
公式中的Re代表復(fù)數(shù)的實(shí)部哀峻,f*代表共軛復(fù)數(shù)涡相,這里涉及復(fù)數(shù)的乘法和共軛復(fù)數(shù)
復(fù)數(shù)和共軛復(fù)數(shù)
復(fù)數(shù)z的坐標(biāo)表示為z=a+bi哲泊,其中a是復(fù)數(shù)的實(shí)部,b是復(fù)述的虛部催蝗,z的共軛復(fù)數(shù)是a-bi切威,即實(shí)部不變,虛部取相反數(shù)
復(fù)數(shù)的乘法
兩個(gè)復(fù)數(shù)相乘直接展開相乘即可丙号,z1=a+bi先朦,z2=c+di,則z1×z2=(ac-bd)+(bc+ad)i
根據(jù)以上兩個(gè)性質(zhì)槽袄,等式右側(cè)等于(ac+bd)+(bc-ad)i烙无,其實(shí)部為ac+bd,真好為兩向量對(duì)應(yīng)位置元素相乘再相加遍尺,因此該內(nèi)積公式成立截酷,等式聯(lián)立可得
把實(shí)部Re拿掉,f(q,m)和f(k,n)共軛相乘的結(jié)果是一個(gè)復(fù)數(shù)乾戏,設(shè)其結(jié)果為g迂苛,該復(fù)數(shù)也必定和q,k鼓择,m-n相關(guān)三幻,令下式為公式一
我們將三個(gè)復(fù)數(shù)用復(fù)數(shù)的三角形式表示,表示為向量的模長(zhǎng)和幅角形式呐能,令下式為公式二
其中R代表向量的模長(zhǎng)念搬,e的iθ次冪為歐拉公式,歐拉公式展開如下
和向量的模R相乘歐拉公式對(duì)應(yīng)復(fù)數(shù)的三角表示摆出,其中θ為幅角
下面的推導(dǎo)需要用到復(fù)數(shù)相乘的性質(zhì)
復(fù)數(shù)三角形式相乘
復(fù)數(shù)的三角形式朗徊,兩個(gè)復(fù)數(shù)相乘,模長(zhǎng)相乘偎漫,幅角相加爷恳。這個(gè)可以用三角表示的相乘展開證明,這里舉一個(gè)例子:復(fù)數(shù)z=1+√3i象踊,其中模長(zhǎng)為温亲,幅叫我為60度,如果z和z相乘杯矩,根據(jù)性質(zhì)栈虚,相乘的結(jié)果映射到坐標(biāo)系應(yīng)該模長(zhǎng)為4,幅角為120度史隆,因此z×z=-2+2√3i节芥,在坐標(biāo)系下的可視化如下
復(fù)數(shù)相乘的性質(zhì)
復(fù)數(shù)z再乘以z,在坐標(biāo)系上相當(dāng)于將z的模長(zhǎng)乘以2,并且逆時(shí)針旋轉(zhuǎn)了z的幅角60度头镊。
根據(jù)復(fù)數(shù)相乘的性質(zhì)蚣驼,因此等式一左邊兩個(gè)復(fù)數(shù)相乘的模相乘,角相加相艇,等式右邊也是一個(gè)復(fù)數(shù)颖杏,因此兩邊的模和角度應(yīng)該相等,則有
注意第二行為兩個(gè)θ角度相減坛芽,原因是f(k,n)取了共軛復(fù)數(shù)留储,因此幅角取負(fù)。接下來我們?nèi)∫粋€(gè)特例m=n=0的時(shí)候咙轩,令初始化階段0位置的向量就是向量本身不做任何變化获讳,則對(duì)于第一個(gè)式子有
同樣將m=n=0帶入第二個(gè)式子,則有
可得θ是一個(gè)關(guān)于位置參數(shù)m的函數(shù)活喊,且滿足關(guān)于m的等差數(shù)列關(guān)系丐膝,將求解的R和θ代入改造函數(shù)f的三角表示可得f的一般形式
e的imθ次冪根據(jù)歐拉公式展開實(shí)際該變換對(duì)應(yīng)著向量的旋轉(zhuǎn),所以稱之為“旋轉(zhuǎn)式位置編碼”钾菊,改寫成矩陣相乘的形式如下
將mθ看作一個(gè)參數(shù)帅矗,將旋轉(zhuǎn)矩陣以函數(shù)形式實(shí)現(xiàn),令二維向量坐標(biāo)為[1, 2]煞烫,將其旋轉(zhuǎn)60度的numpy實(shí)現(xiàn)如下
>>> import numpy as np
>>> def rotary_matrix(xita):
matrix = np.array([[np.cos(xita), -np.sin(xita)], [np.sin(xita), np.cos(xita)]])
return matrix
>>> m = rotary_matrix(np.pi / 3)
>>> one = np.array([[1], [2]])
>>> two = np.matmul(m, one)
>>> print(two)
array([[-1.23205081],
[ 1.8660254 ]])
以上代碼定義個(gè)參數(shù)xita浑此,若xita等于60度,則代表將原始的二維向量逆時(shí)針旋轉(zhuǎn)60度滞详,可以通過兩個(gè)向量?jī)?nèi)積除以向量的乘積的模來驗(yàn)證旋轉(zhuǎn)之后兩個(gè)向量的夾角凛俱,首先驗(yàn)證旋轉(zhuǎn)前后向量的模長(zhǎng)不變
>>> np.linalg.norm(one)
2.23606797749979
>>> np.linalg.norm(one)
2.23606797749979
旋轉(zhuǎn)之后兩個(gè)向量的內(nèi)積除以模乘積等于0.5,因此旋轉(zhuǎn)的夾角為60度
>>> np.dot(one.T, two) / (np.linalg.norm(one) * np.linalg.norm(two))
array([[0.5]])
整個(gè)旋轉(zhuǎn)過程可視化如下
當(dāng)向量為二維時(shí)料饥,θ下標(biāo)為0最冰,因此θ的實(shí)際結(jié)果為1,此時(shí)單詞位置m控制了旋轉(zhuǎn)的幅度稀火,m越大旋轉(zhuǎn)幅度越大
token | 位置 | 逆時(shí)針旋轉(zhuǎn)角度 |
---|---|---|
我 | 0 | 0度 |
愛 | 1 | 57度 |
中 | 2 | 114度 |
國(guó) | 3 | 171度 |
... | ... | ... |
從旋轉(zhuǎn)矩陣的角度,本質(zhì)上赌朋,RoPE是對(duì)各個(gè)位置的token向量根據(jù)自身位置m計(jì)算角度做逆時(shí)針旋轉(zhuǎn)凰狞,在Attention的內(nèi)積操作中,內(nèi)積能夠感知到旋轉(zhuǎn)之后兩個(gè)向量之間的夾角沛慢,這個(gè)夾角就是相對(duì)位置信息
赡若。
此時(shí)二維向量的RoPE得證,由于內(nèi)積滿足線性疊加性团甲,因此任意偶數(shù)維的向量都可以表示為二維情形的拼接逾冬,因此RoPE的最終公式如下,回到開頭介紹RoPE的實(shí)現(xiàn)公式
全文完畢。