1. Background
Judging from 本qiang~'s observation of recent trends, more and more open-source large models are adopting the MoE architecture, for example Musk's Grok-1 (314B) and Qwen1.5-MoE-A2.7B, which motivated a closer look at some of the details inside MoE.
This article is 本qiang~'s write-up on MoE for large language models, covering the principle, the processing flow, and part of the source code.
2. MoE Principle
The popularity of MoE took off with the paper and model "Mixtral of Experts" released by Mistral AI, the company often dubbed "Europe's OpenAI"; on evaluation benchmarks it handily beats many strong models, such as Llama 2 70B and GPT-3.5.
The base model behind "Mixtral of Experts" is Mistral AI's own Mistral 7B, whose notable features include Sliding Window Attention, Rolling Buffer Cache, and Pre-fill and Chunking; for details, see the paper linked at the end of this article.
Using "Mixtral of Experts" as the entry point, this article explores the details of MoE. The principle of MoE is illustrated in the figure below:
(1) In every layer of the Transformer architecture, the single FFN is replaced by 8 FFNs (experts), controlled by a gate router.
(2) For every token, the gate router in each layer selects only 2 of the FFNs (experts) to process the current hidden state, and combines their outputs with a weighted sum (a minimal toy sketch of this top-2 routing follows this list).
(3) As a result, every token has access to 47B parameters, but only about 13B parameters are activated during inference (i.e., only 2 experts are used while the other 6 stay idle).
(4) Compared with the Dropout mechanism: Dropout deactivates part of the neurons, whereas MoE deactivates part of the experts.
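To make step (2) concrete, here is a minimal toy sketch of top-2 gating written by 本qiang~ purely for illustration (the dimensions are made up and the gate weights are random; in the real model they come from training):

import torch
import torch.nn.functional as F

hidden_dim, num_experts, top_k = 16, 8, 2
token = torch.randn(1, hidden_dim)                           # hidden state of a single token
gate = torch.nn.Linear(hidden_dim, num_experts, bias=False)  # the gate router

logits = gate(token)                                  # (1, 8): one score per expert
weights = F.softmax(logits, dim=-1)                   # scores -> probabilities
top_w, top_idx = torch.topk(weights, top_k, dim=-1)   # keep only the 2 largest
top_w = top_w / top_w.sum(dim=-1, keepdim=True)       # renormalize so the 2 weights sum to 1

print(top_idx)   # indices of the 2 selected experts
print(top_w)     # their mixing weights; the other 6 experts are skipped entirely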
3. Source Code
本qiang~研讀并嘗試執(zhí)行了Mistral官網(wǎng)的github推理代碼,該代碼框架非常適合新手纸镊,無他倍阐,只因其幾乎只是在torch上層做的封裝,很少引擎其他第三方庫逗威,不像transformers峰搪,功能強(qiáng)大,但不適合新手研讀代碼…
為了普適性庵楷,下面的代碼截取了transformers框架中的代碼罢艾。
First, look at the generic FFN module used by Transformers-style models; the code lives in transformers.models.mistral.modeling_mistral. The main flow is:
(1) Pass the input through two linear projections, gate_proj and up_proj, each of shape [hidden_size, intermediate_size].
(2) Apply the activation function to the gate_proj output.
(3) Take the element-wise product of the two results and pass it through the down_proj linear projection.
A quick shape-check example follows the code below.
class MistralMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
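A quick shape-check sketch (toy sizes chosen by 本qiang~; the real Mistral config is much larger, and the import paths may differ slightly across transformers versions), just to confirm that the FFN preserves the hidden dimension:

import torch
from transformers.models.mistral.configuration_mistral import MistralConfig
from transformers.models.mistral.modeling_mistral import MistralMLP

config = MistralConfig(hidden_size=32, intermediate_size=64)   # toy sizes for illustration only
mlp = MistralMLP(config)

x = torch.randn(2, 5, 32)      # (batch, seq_len, hidden_size)
print(mlp(x).shape)            # torch.Size([2, 5, 32]): the FFN maps hidden_size back to hidden_size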
Next, look at the expert module in MoE; the code lives in transformers.models.mixtral.modeling_mixtral. The main flow is:
(1) First, the hidden states go through the gate router self.gate.
(2) The top-2 experts are selected and their routing weights are renormalized.
(3) The code then loops over every expert network and uses expert_mask to pick out the tokens routed to it.
(4) If expert_mask selects any tokens for the current expert, the corresponding hidden states are run through that expert's FFN and the output is weighted by the routing weight.
(5) Finally, the weighted expert outputs are accumulated in place into the pre-initialized result tensor final_hidden_states.
A short usage example follows the code below.
class MixtralSparseMoeBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok

        # gating
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

        self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """ """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be sollicitated
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            if top_x.shape[0] == 0:
                continue

            # in torch it is faster to index using lists than torch tensors
            top_x_list = top_x.tolist()
            idx_list = idx.tolist()

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]

            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits
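A minimal usage sketch by 本qiang~ with a toy MixtralConfig (the real Mixtral-8x7B config is far larger, and exact import paths may vary slightly with the transformers version):

import torch
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

config = MixtralConfig(hidden_size=32, intermediate_size=64,
                       num_local_experts=8, num_experts_per_tok=2)   # toy sizes for illustration
moe = MixtralSparseMoeBlock(config)

hidden_states = torch.randn(2, 5, 32)   # (batch, seq_len, hidden)
out, router_logits = moe(hidden_states)
print(out.shape)             # torch.Size([2, 5, 32])
print(router_logits.shape)   # torch.Size([10, 8]): one logit per expert for each of the 10 tokens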
The code of MixtralBlockSparseTop2MLP is shown below; as you can see, it is structurally identical to the traditional MistralMLP above (w1, w3, and w2 play the roles of gate_proj, up_proj, and down_proj).
class MixtralBlockSparseTop2MLP(nn.Module):
    def __init__(self, config: MixtralConfig):
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size

        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states
4. Fine-tuning MoE
Since MoE only replaces the FFN in each layer with a gate router plus 8 FFN experts, and both the gate router and the 8 experts are internally just linear layers, MoE can be combined seamlessly with LoRA or QLoRA for instruction fine-tuning; a minimal LoRA sketch follows the project link below.
For reference, see the open-source project: https://github.com/yangjianxin1/Firefly
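Purely as an illustration (this is not Firefly's actual configuration), here is a minimal LoRA sketch using the peft library; the model id, hyper-parameters, and the target_modules list are assumptions by 本qiang~ based on the layer names shown in the code above (w1/w2/w3 for the experts, gate for the router), so adjust them to your own setup:

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Loading the full 47B model takes a lot of memory; in practice you would add quantization (QLoRA).
model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

lora_config = LoraConfig(
    r=16,                      # placeholder rank
    lora_alpha=32,             # placeholder scaling
    lora_dropout=0.05,
    # Attention projections plus the expert/router linear layers named in the MoE code above.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "w1", "w2", "w3", "gate"],
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()   # only the LoRA adapters are trainable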
5. Q&A
(1) Q: Does the MoE 8*7B model have 56B parameters?
A: The MoE 8*7B model has roughly 47B parameters, not 56B, because in each layer only the 8 expert FFNs are duplicated; everything else (attention, router, embeddings, and so on) is shared.
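A rough back-of-envelope check by 本qiang~, using the published Mixtral-8x7B config values as assumptions (hidden_size 4096, intermediate_size 14336, 32 layers, 8 experts with top-2 routing, grouped-query attention with 8 KV heads, vocab 32000); norm weights are ignored, and the paper reports about 46.7B total and 12.9B active parameters:

hidden, inter, n_layers, n_experts, top_k = 4096, 14336, 32, 8, 2
kv_dim, vocab = 8 * 128, 32000        # 8 KV heads * head_dim 128 (grouped-query attention)

attn_per_layer = 2 * hidden * hidden + 2 * hidden * kv_dim   # q_proj, o_proj, k_proj, v_proj
gate_per_layer = hidden * n_experts                          # the router is tiny
expert_ffn = 3 * hidden * inter                              # w1 + w2 + w3 of a single expert
embeddings = 2 * vocab * hidden                              # input embedding + lm_head

total = n_layers * (attn_per_layer + gate_per_layer + n_experts * expert_ffn) + embeddings
active = n_layers * (attn_per_layer + gate_per_layer + top_k * expert_ffn) + embeddings

print(f"total  = {total / 1e9:.1f}B")    # about 46.7B, i.e. the "47B" above rather than 8 * 7B = 56B
print(f"active = {active / 1e9:.1f}B")   # about 12.9B, i.e. the "13B" activated parameters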
(2) Q: Is the base model of MoE Mistral 7B?
A: No. The MoE model shares the same architecture as Mistral 7B, but each FFN is replaced by 8 FFNs, and Mixtral was pre-trained on multilingual data.
(3) Q: Where does the sparsity of MoE show up?
A: During both training and inference, only two expert networks are activated at a time to take part in the forward computation; the remaining expert networks stay inactive.
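To see the sparsity concretely, here is a tiny illustration (the routing results are made up by 本qiang~) of the expert_mask / torch.where indexing used in MixtralSparseMoeBlock above: each expert only ever processes the tokens routed to it, and experts that receive no tokens are skipped entirely.

import torch

# Toy routing result for 4 tokens: row i holds the top-2 expert ids chosen for token i.
selected_experts = torch.tensor([[0, 3],
                                 [3, 1],
                                 [0, 1],
                                 [2, 0]])

# Same transformation as in MixtralSparseMoeBlock: shape (num_experts, top_k, num_tokens)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=8).permute(2, 1, 0)

for expert_idx in range(8):
    idx, top_x = torch.where(expert_mask[expert_idx])
    if top_x.shape[0] == 0:
        continue   # experts 4-7 receive no tokens here and do no work at all
    # top_x: which tokens this expert processes; idx: whether it was their top-1 (0) or top-2 (1) slot
    print(expert_idx, top_x.tolist(), idx.tolist())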
6. Summary
One sentence is enough~
This article covered MoE for large language models, including the principle and part of the source code.
Beyond that, readers are encouraged to actually run the source code; questions and discussion about the code are welcome.
7. References
(1) Mistral 7B: https://arxiv.org/pdf/2310.06825v1.pdf
(2) Mixtral of Experts (MoE): https://arxiv.org/pdf/2401.04088v1.pdf
(3) Firefly, an open-source instruction fine-tuning framework supporting MoE: https://github.com/yangjianxin1/Firefly