1. 背景
最近本qiang~老看到一些關(guān)于大語(yǔ)言模型的DPO询兴、RLHF算法,但都有些云里霧里起趾,因此靜下心來收集資料诗舰、研讀論文,并執(zhí)行了下開源代碼训裆,以便加深印象眶根。
此文是本qiang~針對(duì)大語(yǔ)言模型的DPO算法的整理,包括原理边琉、流程及部分源碼属百。
2. DPO vs RLHF
上圖左邊是RLHF算法,右邊為DPO算法变姨,兩圖的差異對(duì)比即可體現(xiàn)出DPO的改進(jìn)之處族扰。
1. RLHF算法包含獎(jiǎng)勵(lì)模型(reward
model)和策略模型(policy model,也稱為演員模型钳恕,actor model)别伏,基于偏好數(shù)據(jù)以及強(qiáng)化學(xué)習(xí)不斷迭代優(yōu)化策略模型的過程。
2. DPO算法不包含獎(jiǎng)勵(lì)模型和強(qiáng)化學(xué)習(xí)過程忧额,直接通過偏好數(shù)據(jù)進(jìn)行微調(diào)厘肮,將強(qiáng)化學(xué)習(xí)過程直接轉(zhuǎn)換為SFT過程,因此整個(gè)訓(xùn)練過程簡(jiǎn)單睦番、高效类茂,主要的改進(jìn)之處體現(xiàn)在于損失函數(shù)。
PS:
1. 偏好數(shù)據(jù)托嚣,可以表示為三元組(提示語(yǔ)prompt, 良好回答chosen, 一般回答rejected)巩检。論文中的chosen表示為下標(biāo)w(即win),rejected表示為下標(biāo)l(即lose)
2. RLHF常使用PPO作為基礎(chǔ)算法示启,整體流程包含了4個(gè)模型兢哭,且通常訓(xùn)練過程中需要針對(duì)訓(xùn)練的actor model進(jìn)行采樣,因此訓(xùn)練起來夫嗓,穩(wěn)定性迟螺、效率廊谓、效果不易控制竿裂。
1) actor model/policy
model: 待訓(xùn)練的模型,通常是SFT訓(xùn)練后的模型作為初始化
2) reference
model: 參考模型链韭,也是經(jīng)SFT訓(xùn)練后的模型進(jìn)行初始化排霉,且通常與actor model是同一個(gè)模型窍株,且模型凍結(jié),不參與訓(xùn)練,其作用是在強(qiáng)化學(xué)習(xí)過程中球订,保障actor model與reference model的分布差異不宜過大后裸。
3) reward model:獎(jiǎng)勵(lì)模型,用于提供每個(gè)狀態(tài)或狀態(tài)動(dòng)作對(duì)的即時(shí)獎(jiǎng)勵(lì)信號(hào)辙售。
4) Critic model:作用是估計(jì)狀態(tài)或狀態(tài)動(dòng)作對(duì)的長(zhǎng)期價(jià)值轻抱,也稱為狀態(tài)值函數(shù)或動(dòng)作值函數(shù)。
3. DPO算法僅包含RLHF中的兩個(gè)模型旦部,即演員模型(actor
model)以及參考(reference model)祈搜,且訓(xùn)練過程中不需要進(jìn)行數(shù)據(jù)采樣。
4. RLHF可以參考附件中的引文
3. DPO的損失函數(shù)
如何將RLHF的Reward model過程簡(jiǎn)化為上式士八,作者花了大量篇幅進(jìn)行了推導(dǎo)容燕,感興趣的讀者可以參考附件DPO的論文。
DPO算法的目的是最大化獎(jiǎng)勵(lì)模型(此處的獎(jiǎng)勵(lì)模型即為訓(xùn)練的策略)婚度,使得獎(jiǎng)勵(lì)模型對(duì)chosen和rejected數(shù)據(jù)的差值最大蘸秘,進(jìn)而學(xué)到人類偏好。
上式的后半部分通過對(duì)數(shù)函數(shù)運(yùn)算規(guī)則蝗茁,可以進(jìn)行如下轉(zhuǎn)化醋虏。
轉(zhuǎn)化后的公式和源代碼中的計(jì)算函數(shù)中的公式是一致的。
其中左半部分是訓(xùn)練的policy模型選擇chosen優(yōu)先于rejected哮翘,右半部分是凍結(jié)的reference模型選擇chosen優(yōu)先于rejected颈嚼,二者的差值可類似于KL散度,保障actor模型的分布與reference模型的分布不會(huì)有較大的差異饭寺。
4. 微調(diào)流程
上圖展示了DPO微調(diào)的大致流程阻课,其中Trained
LM即為策略模型,F(xiàn)rozen LM即為參考模型艰匙,二者均是先進(jìn)行SFT微調(diào)得到的模型進(jìn)行初始化限煞,其中Trained LM需要進(jìn)行訓(xùn)練,F(xiàn)rozen LM不參與訓(xùn)練员凝。
兩個(gè)模型分別針對(duì)chosen和rejected進(jìn)行預(yù)測(cè)獲取對(duì)應(yīng)的得分署驻,再通過DPO的損失函數(shù)進(jìn)行損失計(jì)算,進(jìn)而不斷的迭代優(yōu)化健霹。
5. 源碼
源碼參考代碼:https://github.com/eric-mitchell/direct-preference-optimization
5.1 DPO損失函數(shù)
def preference_loss(policy_chosen_logps: torch.FloatTensor,
? policy_rejected_logps: torch.FloatTensor,
? reference_chosen_logps: torch.FloatTensor,
? reference_rejected_logps: torch.FloatTensor,
??????????????????? beta:? float,
? label_smoothing: float = 0.0,
??????????????????? ipo: bool? = False,
? reference_free: bool = False) -> Tuple[torch.FloatTensor,? torch.FloatTensor, torch.FloatTensor]:
??? # policy_chosen_logps:訓(xùn)練模型對(duì)于chosen經(jīng)過log后logits
???????? #? policy_rejected_logps:訓(xùn)練模型對(duì)于rejected經(jīng)過log后logits
???????? #? reference_chosen_logps:訓(xùn)練模型對(duì)于chosen經(jīng)過log后logits
???????? #? reference_rejected_logps:訓(xùn)練模型對(duì)于rejected經(jīng)過log后logits
???????? # beta: policy和reference的差異性控制參數(shù)
???????? # actor模型選擇chosen優(yōu)先于rejected
??? pi_logratios =? policy_chosen_logps - policy_rejected_logps
???????? # reference模型選擇chosen優(yōu)先于rejected
??? ref_logratios =? reference_chosen_logps - reference_rejected_logps
??? if reference_free:
??????? ref_logratios = 0
???????? #差值可類似于KL散度旺上,保障actor模型的分布與reference模型的分布不會(huì)有較大的差異
??? logits = pi_logratios -? ref_logratios? # also known as? h_{\pi_\theta}^{y_w,y_l}
??? if ipo:
??????? losses = (logits -? 1/(2 * beta)) ** 2? # Eq. 17 of? https://arxiv.org/pdf/2310.12036v2.pdf
??? else:
??????? # Eq. 3? https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7? of https://arxiv.org/pdf/2305.18290.pdf)
????????????????? #? label_smoothing為0,對(duì)應(yīng)的DPO論文的算法
??????? losses =? -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta *? logits) * label_smoothing
???????? # chosen和rejected的獎(jiǎng)勵(lì)
??? chosen_rewards = beta *? (policy_chosen_logps - reference_chosen_logps).detach()
??? rejected_rewards = beta? * (policy_rejected_logps - reference_rejected_logps).detach()
??? return losses,? chosen_rewards, rejected_rewards
5.2 批次訓(xùn)練過程
def get_batch_metrics(self, batch: Dict[str, Union[List,? torch.LongTensor]], loss_config: DictConfig, train=True):
???????? """Compute? the SFT or DPO loss and other metrics for the given batch of? inputs."""
???????? if loss_config.name? in {'dpo', 'ipo'}:
????????????????? # policy模型針對(duì)chosen和rejected進(jìn)行預(yù)測(cè)
????????????????? policy_chosen_logps,? policy_rejected_logps = self.concatenated_forward(self.policy, batch)
????????????????? with? torch.no_grad():
????????????????????????? #? reference模型針對(duì)chosen和rejected進(jìn)行預(yù)測(cè)
????????????????????????? reference_chosen_logps,? reference_rejected_logps = self.concatenated_forward(self.reference_model,? batch)
????????????????? if? loss_config.name == 'dpo':
????????????????????????? loss_kwargs? = {'beta': loss_config.beta, 'reference_free': loss_config.reference_free,? 'label_smoothing': loss_config.label_smoothing, 'ipo': False}
????????????????? elif? loss_config.name == 'ipo':
????????????????????????? loss_kwargs? = {'beta': loss_config.beta, 'ipo': True}
????????????????? else:
????????????????????????? raise? ValueError(f'unknown loss {loss_config.name}')
????????????????? #損失計(jì)算
????????????????? losses,? chosen_rewards, rejected_rewards = preference_loss(
????????????????????????? policy_chosen_logps,? policy_rejected_logps, reference_chosen_logps, reference_rejected_logps,? **loss_kwargs)
????????????????? reward_accuracies? = (chosen_rewards > rejected_rewards).float()
???????? elif? loss_config.name == 'sft':
????????????????? policy_chosen_logits? = self.policy(batch['chosen_input_ids'],? attention_mask=batch['chosen_attention_mask']).logits.to(torch.float32)
????????????????? policy_chosen_logps? = _get_batch_logps(policy_chosen_logits, batch['chosen_labels'],? average_log_prob=False)
????????????????? losses =? -policy_chosen_logps
???????? return losses.mean()
5.3 LM的交叉熵計(jì)算
def _get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor,? average_log_prob: bool = False) -> torch.FloatTensor:
??? #經(jīng)模型后的logits進(jìn)行批量計(jì)算logps
??? assert logits.shape[:-1]? == labels.shape
???????? #基于先前的token預(yù)測(cè)下一個(gè)token
??? labels = labels[:,? 1:].clone()
??? logits = logits[:, :-1,? :]
??? loss_mask = (labels !=? -100)
??? # dummy token; we'll? ignore the losses on these tokens later
??? labels[labels == -100] =? 0
???????? #交叉熵函數(shù)
??? per_token_logps =? torch.gather(logits.log_softmax(-1), dim=2,? index=labels.unsqueeze(2)).squeeze(2)
??? if average_log_prob:
??????? return? (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
??? else:
??????? return? (per_token_logps * loss_mask).sum(-1)
5.4 其他注意
1. hugging face設(shè)置代理
源碼會(huì)從hugging face中下載英文語(yǔ)料和模型骤公,由于網(wǎng)絡(luò)限制,因此設(shè)置代理映射扬跋,將HF_ENDPOINT設(shè)置為https://hf-mirror.com阶捆,即設(shè)置: os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
2. 如果僅想要熟悉DPO整體流程,可以下載較小的生成式模型,如BLOOM 560M洒试,GPT2等
6. 總結(jié)
一句話足矣~
本文主要針對(duì)大語(yǔ)言模型的DPO算法的整理倍奢,包括原理、流程及部分源碼垒棋。
此外卒煞,建議大家可以針對(duì)源碼進(jìn)行運(yùn)行,源碼的歡迎大家一塊交流叼架。
7. 參考
(1) RLHF:https://blog.csdn.net/v_JULY_v/article/details/128579457
(2) DPO論文: https://arxiv.org/pdf/2305.18290v2.pdf
(3) DPO代碼: https://github.com/eric-mitchell/direct-preference-optimization
(4) DPO理解1:https://medium.com/@joaolages/direct-preference-optimization-dpo-622fc1f18707
(5) DPO理解2: https://zhuanlan.zhihu.com/p/669825918