LLM面面觀之MoE

1. 背景

根據(jù)本qiang~最新的趨勢觀察鸵鸥,基于MoE架構(gòu)的開源大模型越來越多,比如馬斯克的Grok-1(314B), Qwen1.5-MoE-A2.7B等丹皱,因此想探究一下MoE里面的部分細(xì)節(jié)脂男。

此文是本qiang~針對大語言模型的MoE的整理,包括原理种呐、流程及部分源碼。

2. MoE原理

MoE的流行源于”歐洲的OpenAI”

Mistral AI發(fā)布的論文及模型《Mixtral of Experts》弃甥,評測集上的效果吊打眾多開源模型爽室,如Llama 2 70B和GPT3.5。

《Mixtral of Experts》基礎(chǔ)模型使用的是Mistral

AI自研的Mistral 7B淆攻,該模型的特點(diǎn)包括:滑窗注意力(Sliding Window Aattention), 滾動緩沖區(qū)緩存(Rolling

Buffer Cache)以及預(yù)填充-分塊(Pre-fill and Chunking)阔墩,具體細(xì)節(jié)可以查閱文末的論文地址。

本文以《Mixtral of

Experts》為引子瓶珊,探究MoE的相關(guān)細(xì)節(jié)啸箫,MoE的原理如下圖所示:


圖2.1 MoE的原理

(1) Transformers架構(gòu)中的每一層中的FFN網(wǎng)絡(luò)均替換為了8個FFN(專家),且由一個網(wǎng)關(guān)路由(gate

router)進(jìn)行控制

(2) 針對每一個token伞芹,每一層的網(wǎng)關(guān)路由僅選擇其中的2個FFN(專家)來處理當(dāng)前狀態(tài)并進(jìn)行加權(quán)輸出

(3) 結(jié)果就是忘苛,每一個token訪問了47B參數(shù)蝉娜,但是在推理階段僅僅使用了13B的激活參數(shù)(即,只使用2個專家扎唾,凍結(jié)其他6個專家)召川。

(4) 與Dropout機(jī)制對比,Dropout讓部分神經(jīng)元失活胸遇,而MoE是讓部分專家失活荧呐。

3. 源碼

本qiang~研讀并嘗試執(zhí)行了Mistral官網(wǎng)的github推理代碼,該代碼框架非常適合新手纸镊,無他倍阐,只因其幾乎只是在torch上層做的封裝,很少引擎其他第三方庫逗威,不像transformers峰搪,功能強(qiáng)大,但不適合新手研讀代碼…

為了普適性庵楷,下面的代碼截取了transformers框架中的代碼罢艾。

首先看下通用Transformers中FFN中的代碼模塊,代碼位置在transformers.models.mistral.modeling_mistral,主要流程是:

(1) 先經(jīng)過gate_proj和up_proj的2個[hidden_size,

intermediate_size]的線性轉(zhuǎn)換

(2) 使用激活函數(shù)對gate_proj進(jìn)行激活

(3) 二者的內(nèi)積再經(jīng)過down_proj線性轉(zhuǎn)換尽纽。


class MistralMLP(nn.Module):

??? def __init__(self,? config):

??????? super().__init__()

??????? self.config = config

??????? self.hidden_size =? config.hidden_size


? self.intermediate_size = config.intermediate_size

??????? self.gate_proj =? nn.Linear(self.hidden_size, self.intermediate_size, bias=False)

??????? self.up_proj =? nn.Linear(self.hidden_size, self.intermediate_size, bias=False)

??????? self.down_proj =? nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

??? ????self.act_fn = ACT2FN[config.hidden_act]


??? def forward(self, x):

??????? return? self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


再來看下MoE中的專家模塊咐蚯,代碼位置在transformers.models.mixtral.modeling_mixtral,主要流程是:

(1) 首先經(jīng)過網(wǎng)關(guān)路由self.gate

(2) 然后選擇其中2個專家弄贿,并歸一化

(3) 之后遍歷每個專家網(wǎng)絡(luò)春锋,并按照expert_mask進(jìn)行篩選

(4) 如果expert_mask有值,則選擇指定部分的隱藏層進(jìn)行FFN操作差凹,且輸出結(jié)果進(jìn)行加權(quán)

(5) 最后原地增加先前初始化的最終結(jié)果變量final_hidden_states


class MixtralSparseMoeBlock(nn.Module):


??? def __init__(self,? config):

??????? super().__init__()

??????? self.hidden_dim =? config.hidden_size

??????? self.ffn_dim =? config.intermediate_size

??????? self.num_experts =? config.num_local_experts

??????? self.top_k =? config.num_experts_per_tok


??????? # gating

??????? self.gate =? nn.Linear(self.hidden_dim, self.num_experts, bias=False)


??????? self.experts =? nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in? range(self.num_experts)])


??? def forward(self,? hidden_states: torch.Tensor) -> torch.Tensor:

??????? """? """

??????? batch_size,? sequence_length, hidden_dim = hidden_states.shape

??????? hidden_states =? hidden_states.view(-1, hidden_dim)

??????? # router_logits:? (batch * sequence_length, n_experts)

??????? router_logits =? self.gate(hidden_states)


??????? routing_weights =? F.softmax(router_logits, dim=1, dtype=torch.float)

??????? routing_weights,? selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)

??????? routing_weights /=? routing_weights.sum(dim=-1, keepdim=True)

??????? # we cast back to? the input dtype

??????? routing_weights =? routing_weights.to(hidden_states.dtype)


??????? final_hidden_states? = torch.zeros(

??????????? (batch_size *? sequence_length, hidden_dim), dtype=hidden_states.dtype,? device=hidden_states.device

??????? )


??????? # One hot encode the? selected experts to create an expert mask

??????? # this will be used? to easily index which expert is going to be sollicitated

??????? expert_mask =? torch.nn.functional.one_hot(selected_experts,? num_classes=self.num_experts).permute(2, 1, 0)


??????? # Loop over all? available experts in the model and perform the computation on each expert

??????? for expert_idx in? range(self.num_experts):

??????????? expert_layer =? self.experts[expert_idx]

??????????? idx, top_x =? torch.where(expert_mask[expert_idx])


??????????? if? top_x.shape[0] == 0:

??????????????? continue


??????????? # in torch it is? faster to index using lists than torch tensors

??????????? top_x_list =? top_x.tolist()

??????????? idx_list =? idx.tolist()


??????????? # Index the? correct hidden states and compute the expert hidden state for

??????????? # the current? expert. We need to make sure to multiply the output hidden

??????????? # states by? `routing_weights` on the corresponding tokens (top-1 and top-2)

??????????? current_state =? hidden_states[None, top_x_list].reshape(-1, hidden_dim)


? current_hidden_states = expert_layer(current_state) *? routing_weights[top_x_list, idx_list, None]


??????????? # However? `index_add_` only support torch tensors for indexing so we'll use

??????????? # the `top_x`? tensor here.


? final_hidden_states.index_add_(0, top_x,? current_hidden_states.to(hidden_states.dtype))

??????? final_hidden_states? = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)

???? ???return final_hidden_states, router_logits


其中MixtralBlockSparseTop2MLP代碼如下期奔,可以看到和傳統(tǒng)MistralMLP內(nèi)容完全一致。


class MixtralBlockSparseTop2MLP(nn.Module):

??? def __init__(self,? config: MixtralConfig):

??????? super().__init__()

??????? self.ffn_dim =? config.intermediate_size

??????? self.hidden_dim =? config.hidden_size


??????? self.w1 =? nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)

??????? self.w2 =? nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)

??????? self.w3 = nn.Linear(self.hidden_dim,? self.ffn_dim, bias=False)


??????? self.act_fn =? ACT2FN[config.hidden_act]


??? def forward(self,? hidden_states):


? current_hidden_states = self.act_fn(self.w1(hidden_states)) *? self.w3(hidden_states)

??????? current_hidden_states? = self.w2(current_hidden_states)

??????? return? current_hidden_states


4. MoE微調(diào)

由于MoE只是將每一層的FFN改變?yōu)榱嗣恳粚拥膅ate網(wǎng)關(guān)路由+8個FFN專家危尿,且gate網(wǎng)關(guān)路由和8個專家內(nèi)部均為線性運(yùn)算牺弹,所以可以無縫地結(jié)合LoRA、QLoRA進(jìn)行指令微調(diào)按灶。

可以參考開源項目:https://github.com/yangjianxin1/Firefly

5. 答疑解惑

(1) 問:MoE 8*7B的模型是56B參數(shù)眶拉?

答:MoE 8*7B的參數(shù)量是47B,而不是56B济欢,原因是每一層除了8個專家網(wǎng)絡(luò)外赠堵,其他層均是復(fù)用的。

(2) 問:MoE的基礎(chǔ)模型是Mistral7B?

答:不是法褥,MoE的模型架構(gòu)與Mistral

7B相同茫叭,但其中的FFN替換為了8個FFN,且MoE是基于多語言數(shù)據(jù)集預(yù)訓(xùn)練而來的半等。

(3) MoE的稀疏性(sparse)體現(xiàn)在哪里揍愁?

答:在訓(xùn)練和推理時呐萨,同時只有兩個專家網(wǎng)絡(luò)會被激活,進(jìn)行前向計算吗垮,其它專家網(wǎng)絡(luò)處于失活狀態(tài)垛吗。

6. 總結(jié)

一句話足矣~

本文主要針對大語言模型的MoE,包括原理及部分源碼烁登。

此外怯屉,建議大家可以針對源碼進(jìn)行運(yùn)行,關(guān)于源碼饵沧,歡迎大家一塊交流锨络。

7. 參考

(1) Mistral 7B:https://arxiv.org/pdf/2310.06825v1.pdf

(2) MoE:https://arxiv.org/pdf/2401.04088v1.pdf

(3) MoE開源指令微調(diào)框架Firefly:https://github.com/yangjianxin1/Firefly

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市狼牺,隨后出現(xiàn)的幾起案子羡儿,更是在濱河造成了極大的恐慌,老刑警劉巖是钥,帶你破解...
    沈念sama閱讀 206,839評論 6 482
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件掠归,死亡現(xiàn)場離奇詭異,居然都是意外死亡悄泥,警方通過查閱死者的電腦和手機(jī)虏冻,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 88,543評論 2 382
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來弹囚,“玉大人厨相,你說我怎么就攤上這事∨葛模” “怎么了蛮穿?”我有些...
    開封第一講書人閱讀 153,116評論 0 344
  • 文/不壞的土叔 我叫張陵,是天一觀的道長毁渗。 經(jīng)常有香客問我践磅,道長,這世上最難降的妖魔是什么灸异? 我笑而不...
    開封第一講書人閱讀 55,371評論 1 279
  • 正文 為了忘掉前任音诈,我火速辦了婚禮,結(jié)果婚禮上绎狭,老公的妹妹穿的比我還像新娘。我一直安慰自己褥傍,他們只是感情好儡嘶,可當(dāng)我...
    茶點(diǎn)故事閱讀 64,384評論 5 374
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著恍风,像睡著了一般蹦狂。 火紅的嫁衣襯著肌膚如雪誓篱。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 49,111評論 1 285
  • 那天凯楔,我揣著相機(jī)與錄音窜骄,去河邊找鬼。 笑死摆屯,一個胖子當(dāng)著我的面吹牛邻遏,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播虐骑,決...
    沈念sama閱讀 38,416評論 3 400
  • 文/蒼蘭香墨 我猛地睜開眼准验,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了廷没?” 一聲冷哼從身側(cè)響起糊饱,我...
    開封第一講書人閱讀 37,053評論 0 259
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎颠黎,沒想到半個月后另锋,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 43,558評論 1 300
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡狭归,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,007評論 2 325
  • 正文 我和宋清朗相戀三年夭坪,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片唉铜。...
    茶點(diǎn)故事閱讀 38,117評論 1 334
  • 序言:一個原本活蹦亂跳的男人離奇死亡台舱,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出潭流,到底是詐尸還是另有隱情竞惋,我是刑警寧澤,帶...
    沈念sama閱讀 33,756評論 4 324
  • 正文 年R本政府宣布灰嫉,位于F島的核電站拆宛,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏讼撒。R本人自食惡果不足惜浑厚,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,324評論 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望根盒。 院中可真熱鬧钳幅,春花似錦、人聲如沸炎滞。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,315評論 0 19
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽册赛。三九已至钠导,卻和暖如春震嫉,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背牡属。 一陣腳步聲響...
    開封第一講書人閱讀 31,539評論 1 262
  • 我被黑心中介騙來泰國打工票堵, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人逮栅。 一個月前我還...
    沈念sama閱讀 45,578評論 2 355
  • 正文 我出身青樓悴势,卻偏偏與公主長得像,于是被迫代替她去往敵國和親证芭。 傳聞我的和親對象是個殘疾皇子瞳浦,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 42,877評論 2 345

推薦閱讀更多精彩內(nèi)容