LLM面面觀之RLHF平替算法DPO

1. 背景

最近本qiang~老看到一些關(guān)于大語(yǔ)言模型的DPO询兴、RLHF算法,但都有些云里霧里起趾,因此靜下心來收集資料诗舰、研讀論文,并執(zhí)行了下開源代碼训裆,以便加深印象眶根。

此文是本qiang~針對(duì)大語(yǔ)言模型的DPO算法的整理,包括原理边琉、流程及部分源碼属百。

2. DPO vs RLHF

RLHF vs DPO

上圖左邊是RLHF算法,右邊為DPO算法变姨,兩圖的差異對(duì)比即可體現(xiàn)出DPO的改進(jìn)之處族扰。

1. RLHF算法包含獎(jiǎng)勵(lì)模型(reward

model)和策略模型(policy model,也稱為演員模型钳恕,actor model)别伏,基于偏好數(shù)據(jù)以及強(qiáng)化學(xué)習(xí)不斷迭代優(yōu)化策略模型的過程。

2. DPO算法不包含獎(jiǎng)勵(lì)模型和強(qiáng)化學(xué)習(xí)過程忧额,直接通過偏好數(shù)據(jù)進(jìn)行微調(diào)厘肮,將強(qiáng)化學(xué)習(xí)過程直接轉(zhuǎn)換為SFT過程,因此整個(gè)訓(xùn)練過程簡(jiǎn)單睦番、高效类茂,主要的改進(jìn)之處體現(xiàn)在于損失函數(shù)。

PS:

1. 偏好數(shù)據(jù)托嚣,可以表示為三元組(提示語(yǔ)prompt, 良好回答chosen, 一般回答rejected)巩检。論文中的chosen表示為下標(biāo)w(即win),rejected表示為下標(biāo)l(即lose)

2. RLHF常使用PPO作為基礎(chǔ)算法示启,整體流程包含了4個(gè)模型兢哭,且通常訓(xùn)練過程中需要針對(duì)訓(xùn)練的actor model進(jìn)行采樣,因此訓(xùn)練起來夫嗓,穩(wěn)定性迟螺、效率廊谓、效果不易控制竿裂。

1) actor model/policy

model: 待訓(xùn)練的模型,通常是SFT訓(xùn)練后的模型作為初始化

2) reference

model: 參考模型链韭,也是經(jīng)SFT訓(xùn)練后的模型進(jìn)行初始化排霉,且通常與actor model是同一個(gè)模型窍株,且模型凍結(jié),不參與訓(xùn)練,其作用是在強(qiáng)化學(xué)習(xí)過程中球订,保障actor model與reference model的分布差異不宜過大后裸。

3) reward model:獎(jiǎng)勵(lì)模型,用于提供每個(gè)狀態(tài)或狀態(tài)動(dòng)作對(duì)的即時(shí)獎(jiǎng)勵(lì)信號(hào)辙售。

4) Critic model:作用是估計(jì)狀態(tài)或狀態(tài)動(dòng)作對(duì)的長(zhǎng)期價(jià)值轻抱,也稱為狀態(tài)值函數(shù)或動(dòng)作值函數(shù)。

3. DPO算法僅包含RLHF中的兩個(gè)模型旦部,即演員模型(actor

model)以及參考(reference model)祈搜,且訓(xùn)練過程中不需要進(jìn)行數(shù)據(jù)采樣。

4. RLHF可以參考附件中的引文

3. DPO的損失函數(shù)

DPO的損失函數(shù)

如何將RLHF的Reward model過程簡(jiǎn)化為上式士八,作者花了大量篇幅進(jìn)行了推導(dǎo)容燕,感興趣的讀者可以參考附件DPO的論文。

DPO算法的目的是最大化獎(jiǎng)勵(lì)模型(此處的獎(jiǎng)勵(lì)模型即為訓(xùn)練的策略)婚度,使得獎(jiǎng)勵(lì)模型對(duì)chosen和rejected數(shù)據(jù)的差值最大蘸秘,進(jìn)而學(xué)到人類偏好。

上式的后半部分通過對(duì)數(shù)函數(shù)運(yùn)算規(guī)則蝗茁,可以進(jìn)行如下轉(zhuǎn)化醋虏。

Loss公式轉(zhuǎn)化

轉(zhuǎn)化后的公式和源代碼中的計(jì)算函數(shù)中的公式是一致的。

其中左半部分是訓(xùn)練的policy模型選擇chosen優(yōu)先于rejected哮翘,右半部分是凍結(jié)的reference模型選擇chosen優(yōu)先于rejected颈嚼,二者的差值可類似于KL散度,保障actor模型的分布與reference模型的分布不會(huì)有較大的差異饭寺。

4. 微調(diào)流程

DPO微調(diào)流程

上圖展示了DPO微調(diào)的大致流程阻课,其中Trained

LM即為策略模型,F(xiàn)rozen LM即為參考模型艰匙,二者均是先進(jìn)行SFT微調(diào)得到的模型進(jìn)行初始化限煞,其中Trained LM需要進(jìn)行訓(xùn)練,F(xiàn)rozen LM不參與訓(xùn)練员凝。

兩個(gè)模型分別針對(duì)chosen和rejected進(jìn)行預(yù)測(cè)獲取對(duì)應(yīng)的得分署驻,再通過DPO的損失函數(shù)進(jìn)行損失計(jì)算,進(jìn)而不斷的迭代優(yōu)化健霹。

5. 源碼

源碼參考代碼:https://github.com/eric-mitchell/direct-preference-optimization

5.1 DPO損失函數(shù)


def preference_loss(policy_chosen_logps: torch.FloatTensor,


? policy_rejected_logps: torch.FloatTensor,


? reference_chosen_logps: torch.FloatTensor,


? reference_rejected_logps: torch.FloatTensor,

??????????????????? beta:? float,


? label_smoothing: float = 0.0,

??????????????????? ipo: bool? = False,


? reference_free: bool = False) -> Tuple[torch.FloatTensor,? torch.FloatTensor, torch.FloatTensor]:

??? # policy_chosen_logps:訓(xùn)練模型對(duì)于chosen經(jīng)過log后logits

???????? #? policy_rejected_logps:訓(xùn)練模型對(duì)于rejected經(jīng)過log后logits

???????? #? reference_chosen_logps:訓(xùn)練模型對(duì)于chosen經(jīng)過log后logits

???????? #? reference_rejected_logps:訓(xùn)練模型對(duì)于rejected經(jīng)過log后logits

???????? # beta: policy和reference的差異性控制參數(shù)


???????? # actor模型選擇chosen優(yōu)先于rejected

??? pi_logratios =? policy_chosen_logps - policy_rejected_logps

???????? # reference模型選擇chosen優(yōu)先于rejected

??? ref_logratios =? reference_chosen_logps - reference_rejected_logps


??? if reference_free:

??????? ref_logratios = 0


???????? #差值可類似于KL散度旺上,保障actor模型的分布與reference模型的分布不會(huì)有較大的差異

??? logits = pi_logratios -? ref_logratios? # also known as? h_{\pi_\theta}^{y_w,y_l}


??? if ipo:

??????? losses = (logits -? 1/(2 * beta)) ** 2? # Eq. 17 of? https://arxiv.org/pdf/2310.12036v2.pdf

??? else:

??????? # Eq. 3? https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7? of https://arxiv.org/pdf/2305.18290.pdf)

????????????????? #? label_smoothing為0,對(duì)應(yīng)的DPO論文的算法

??????? losses =? -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta *? logits) * label_smoothing


???????? # chosen和rejected的獎(jiǎng)勵(lì)

??? chosen_rewards = beta *? (policy_chosen_logps - reference_chosen_logps).detach()

??? rejected_rewards = beta? * (policy_rejected_logps - reference_rejected_logps).detach()


??? return losses,? chosen_rewards, rejected_rewards


5.2 批次訓(xùn)練過程


def get_batch_metrics(self, batch: Dict[str, Union[List,? torch.LongTensor]], loss_config: DictConfig, train=True):

???????? """Compute? the SFT or DPO loss and other metrics for the given batch of? inputs."""


???????? if loss_config.name? in {'dpo', 'ipo'}:

????????????????? # policy模型針對(duì)chosen和rejected進(jìn)行預(yù)測(cè)

????????????????? policy_chosen_logps,? policy_rejected_logps = self.concatenated_forward(self.policy, batch)

????????????????? with? torch.no_grad():

????????????????????????? #? reference模型針對(duì)chosen和rejected進(jìn)行預(yù)測(cè)

????????????????????????? reference_chosen_logps,? reference_rejected_logps = self.concatenated_forward(self.reference_model,? batch)


????????????????? if? loss_config.name == 'dpo':

????????????????????????? loss_kwargs? = {'beta': loss_config.beta, 'reference_free': loss_config.reference_free,? 'label_smoothing': loss_config.label_smoothing, 'ipo': False}

????????????????? elif? loss_config.name == 'ipo':

????????????????????????? loss_kwargs? = {'beta': loss_config.beta, 'ipo': True}

????????????????? else:

????????????????????????? raise? ValueError(f'unknown loss {loss_config.name}')

????????????????? #損失計(jì)算

????????????????? losses,? chosen_rewards, rejected_rewards = preference_loss(

????????????????????????? policy_chosen_logps,? policy_rejected_logps, reference_chosen_logps, reference_rejected_logps,? **loss_kwargs)


????????????????? reward_accuracies? = (chosen_rewards > rejected_rewards).float()


???????? elif? loss_config.name == 'sft':

????????????????? policy_chosen_logits? = self.policy(batch['chosen_input_ids'],? attention_mask=batch['chosen_attention_mask']).logits.to(torch.float32)

????????????????? policy_chosen_logps? = _get_batch_logps(policy_chosen_logits, batch['chosen_labels'],? average_log_prob=False)


????????????????? losses =? -policy_chosen_logps


???????? return losses.mean()


5.3 LM的交叉熵計(jì)算


def _get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor,? average_log_prob: bool = False) -> torch.FloatTensor:

??? #經(jīng)模型后的logits進(jìn)行批量計(jì)算logps


??? assert logits.shape[:-1]? == labels.shape


???????? #基于先前的token預(yù)測(cè)下一個(gè)token

??? labels = labels[:,? 1:].clone()

??? logits = logits[:, :-1,? :]

??? loss_mask = (labels !=? -100)


??? # dummy token; we'll? ignore the losses on these tokens later

??? labels[labels == -100] =? 0


???????? #交叉熵函數(shù)

??? per_token_logps =? torch.gather(logits.log_softmax(-1), dim=2,? index=labels.unsqueeze(2)).squeeze(2)


??? if average_log_prob:

??????? return? (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)

??? else:

??????? return? (per_token_logps * loss_mask).sum(-1)


5.4 其他注意

1. hugging face設(shè)置代理

源碼會(huì)從hugging face中下載英文語(yǔ)料和模型骤公,由于網(wǎng)絡(luò)限制,因此設(shè)置代理映射扬跋,將HF_ENDPOINT設(shè)置為https://hf-mirror.com阶捆,即設(shè)置: os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

2. 如果僅想要熟悉DPO整體流程,可以下載較小的生成式模型,如BLOOM 560M洒试,GPT2等

6. 總結(jié)

一句話足矣~

本文主要針對(duì)大語(yǔ)言模型的DPO算法的整理倍奢,包括原理、流程及部分源碼垒棋。

此外卒煞,建議大家可以針對(duì)源碼進(jìn)行運(yùn)行,源碼的歡迎大家一塊交流叼架。

7. 參考

(1) RLHF:https://blog.csdn.net/v_JULY_v/article/details/128579457

(2) DPO論文: https://arxiv.org/pdf/2305.18290v2.pdf

(3) DPO代碼: https://github.com/eric-mitchell/direct-preference-optimization

(4) DPO理解1:https://medium.com/@joaolages/direct-preference-optimization-dpo-622fc1f18707

(5) DPO理解2: https://zhuanlan.zhihu.com/p/669825918

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末畔裕,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子乖订,更是在濱河造成了極大的恐慌扮饶,老刑警劉巖,帶你破解...
    沈念sama閱讀 206,968評(píng)論 6 482
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件乍构,死亡現(xiàn)場(chǎng)離奇詭異甜无,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī)哥遮,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 88,601評(píng)論 2 382
  • 文/潘曉璐 我一進(jìn)店門岂丘,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人眠饮,你說我怎么就攤上這事奥帘。” “怎么了君仆?”我有些...
    開封第一講書人閱讀 153,220評(píng)論 0 344
  • 文/不壞的土叔 我叫張陵翩概,是天一觀的道長(zhǎng)。 經(jīng)常有香客問我返咱,道長(zhǎng)钥庇,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 55,416評(píng)論 1 279
  • 正文 為了忘掉前任咖摹,我火速辦了婚禮评姨,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘萤晴。我一直安慰自己吐句,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 64,425評(píng)論 5 374
  • 文/花漫 我一把揭開白布店读。 她就那樣靜靜地躺著嗦枢,像睡著了一般。 火紅的嫁衣襯著肌膚如雪屯断。 梳的紋絲不亂的頭發(fā)上文虏,一...
    開封第一講書人閱讀 49,144評(píng)論 1 285
  • 那天侣诺,我揣著相機(jī)與錄音,去河邊找鬼氧秘。 笑死年鸳,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的丸相。 我是一名探鬼主播搔确,決...
    沈念sama閱讀 38,432評(píng)論 3 401
  • 文/蒼蘭香墨 我猛地睜開眼,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼灭忠!你這毒婦竟也來了膳算?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 37,088評(píng)論 0 261
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤更舞,失蹤者是張志新(化名)和其女友劉穎畦幢,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體缆蝉,經(jīng)...
    沈念sama閱讀 43,586評(píng)論 1 300
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡宇葱,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,028評(píng)論 2 325
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了刊头。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片黍瞧。...
    茶點(diǎn)故事閱讀 38,137評(píng)論 1 334
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖原杂,靈堂內(nèi)的尸體忽然破棺而出印颤,到底是詐尸還是另有隱情,我是刑警寧澤穿肄,帶...
    沈念sama閱讀 33,783評(píng)論 4 324
  • 正文 年R本政府宣布年局,位于F島的核電站,受9級(jí)特大地震影響咸产,放射性物質(zhì)發(fā)生泄漏矢否。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,343評(píng)論 3 307
  • 文/蒙蒙 一脑溢、第九天 我趴在偏房一處隱蔽的房頂上張望僵朗。 院中可真熱鬧,春花似錦屑彻、人聲如沸验庙。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,333評(píng)論 0 19
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)粪薛。三九已至,卻和暖如春搏恤,著一層夾襖步出監(jiān)牢的瞬間违寿,已是汗流浹背让禀。 一陣腳步聲響...
    開封第一講書人閱讀 31,559評(píng)論 1 262
  • 我被黑心中介騙來泰國(guó)打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留陨界,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 45,595評(píng)論 2 355
  • 正文 我出身青樓痛阻,卻偏偏與公主長(zhǎng)得像菌瘪,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子阱当,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 42,901評(píng)論 2 345

推薦閱讀更多精彩內(nèi)容