強(qiáng)化學(xué)習(xí)框架TRL源碼——談?wù)凱PO和RLOO的異同

PPO(Proximal Policy Optimization)是rlhf經(jīng)典算法,RLOO (REINFORCE Leave One-Out) 則是基于 PPO 改進(jìn)的算法假消,TRL分別提供了PPOTrainerRLOOTrainer的實現(xiàn)酸役。下面我們分析下二者的異同什湘。

1.關(guān)于模型

PPO需要加載四個模型:1) 策略模型(policy model)耿导,2) 參考策略模型(reference policy model)骑祟,3) 獎勵模型(reward model)回懦,以及 4) 價值模型(value model),而RLOO沒有4) 價值模型(value model)次企,只有其他三個模型怯晕。所以從顯存來說RLOO肯定比PPO更省。

PPO將policy和value兩個模型包裹在一起缸棵,不僅前饋的時候二者都有輸出舟茶,而且在訓(xùn)練的時候兩個模型也會同時進(jìn)行訓(xùn)練。

PPOTrainer
class PolicyAndValueWrapper(nn.Module):
    def __init__(self, policy, value_model) -> None:
        super().__init__()
        self.policy = policy
        self.value_model = value_model
        self.critic_backbone = getattr(value_model, value_model.base_model_prefix)

    def forward(self, **kwargs):
        output = self.critic_backbone(
            **kwargs,
        )
        logits = self.value_model.score(output.hidden_states[-1])
        return self.policy(**kwargs), logits

2.計算Reward

兩種方法的獎勵reward都包含了環(huán)境獎勵堵第,即reward model的輸出和KL散度約束懲罰吧凉,但二者的計算方式不同。


PPO和RLOO reward計算模式對比

從上圖我們可以看出踏志,PPO在計算獎勵的時候?qū)⒚總€補(bǔ)全 token 視為單獨的動作阀捅,但只有EOS token獲得真正的獎勵(score),輸出格式為[batch_size, seq_len]针余。

PPOTrainer
# 4. compute rewards
kl = logprobs - ref_logprobs
non_score_reward = -args.kl_coef * kl
rewards = non_score_reward.clone()
actual_start = torch.arange(rewards.size(0), device=rewards.device)
actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
rewards[[actual_start, actual_end]] += scores

而 RLOO 將整個補(bǔ)全視為單一動作饲鄙, EOS 獎勵歸因于整個補(bǔ)全。因此RLOO rewards的格式是[batch_size, 1]圆雁。

RLOOTrainer
# 4. compute rewards
kl = logprobs - ref_logprobs
non_score_reward = (-args.kl_coef * kl).sum(1)
rlhf_reward = scores + non_score_reward

3.計算Advantage

在PPO算法里面傍妒,優(yōu)勢函數(shù)=動作價值函數(shù)-狀態(tài)價值函數(shù),即A(s, a) = Q(s, a) - V(s)摸柄。優(yōu)勢函數(shù)advantage是通過泛化優(yōu)勢估計算法(GAE)得來的颤练,同時可以計算得到動作價值函數(shù)return。

PPOTrainer
# 6. compute advantages and returns
lastgaelam = 0
advantages_reversed = []
gen_length = responses.shape[1]
for t in reversed(range(gen_length)):
    nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
    delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
    lastgaelam = delta + args.gamma * args.lam * lastgaelam
    advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], axis=1)
returns = advantages + values
advantages = masked_whiten(advantages, ~padding_mask)
advantages = torch.masked_fill(advantages, padding_mask, 0)

而在RLOO里面,優(yōu)勢函數(shù)=總獎勵-虛擬基線嗦玖。虛擬基線是多次采樣后的除了該采樣本身的平均獎勵患雇,這也是Leave One-Out的由來。該采樣的獎勵-其他平均采樣的獎勵宇挫,和基于該動作的價值-所有動作的平均價值在理論上是一致的苛吱。這里的rloo_k是指總采樣次數(shù)。

RLOOTrainer
# vectorized RLOO advantages implementation
rlhf_reward = rlhf_reward.reshape(args.rloo_k, -1)
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (args.rloo_k - 1)
advantages = rlhf_reward - baseline
advantages = advantages.flatten()

4.計算Loss

首先兩種方法在計算policy model loss的時候都使用了clip方法器瘪。


policy model loss計算公式

PPO除此之外還會計算value model loss


value model loss計算公式

下面是PPO的流程圖翠储,可以看出policy model和value model都會進(jìn)行訓(xùn)練。
PPO算法流程圖
PPOTrainer
vf_losses1 = torch.square(vpred - mb_return)
vf_losses2 = torch.square(vpredclipped - mb_return)
vf_loss_max = torch.max(vf_losses1, vf_losses2)
vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
vf_clipfrac = masked_mean(
    (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
)
logprobs_diff = new_logprobs - mb_logprobs
ratio = torch.exp(logprobs_diff)
pg_losses = -mb_advantage * ratio
pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
pg_loss_max = torch.max(pg_losses, pg_losses2)
pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
loss = pg_loss + args.vf_coef * vf_loss

而RLOO只計算policy model的loss橡疼。

RLOOTrainer
new_ratio = (new_logprobs - mb_logprobs).exp()
new_logprobs = new_logprobs.sum(1)
mb_logprobs = mb_logprobs.sum(1)
logprobs_diff = new_logprobs - mb_logprobs
ratio = torch.exp(logprobs_diff)
pg_losses = -mb_advantage * ratio
pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
pg_loss_max = torch.max(pg_losses, pg_losses2)
pg_loss = pg_loss_max.mean()
loss = pg_loss
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末援所,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子欣除,更是在濱河造成了極大的恐慌住拭,老刑警劉巖,帶你破解...
    沈念sama閱讀 218,204評論 6 506
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件历帚,死亡現(xiàn)場離奇詭異滔岳,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī)挽牢,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,091評論 3 395
  • 文/潘曉璐 我一進(jìn)店門谱煤,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人禽拔,你說我怎么就攤上這事刘离。” “怎么了奏赘?”我有些...
    開封第一講書人閱讀 164,548評論 0 354
  • 文/不壞的土叔 我叫張陵寥闪,是天一觀的道長太惠。 經(jīng)常有香客問我磨淌,道長,這世上最難降的妖魔是什么凿渊? 我笑而不...
    開封第一講書人閱讀 58,657評論 1 293
  • 正文 為了忘掉前任梁只,我火速辦了婚禮,結(jié)果婚禮上埃脏,老公的妹妹穿的比我還像新娘搪锣。我一直安慰自己,他們只是感情好彩掐,可當(dāng)我...
    茶點故事閱讀 67,689評論 6 392
  • 文/花漫 我一把揭開白布构舟。 她就那樣靜靜地躺著,像睡著了一般堵幽。 火紅的嫁衣襯著肌膚如雪狗超。 梳的紋絲不亂的頭發(fā)上弹澎,一...
    開封第一講書人閱讀 51,554評論 1 305
  • 那天,我揣著相機(jī)與錄音努咐,去河邊找鬼苦蒿。 笑死,一個胖子當(dāng)著我的面吹牛渗稍,可吹牛的內(nèi)容都是我干的佩迟。 我是一名探鬼主播,決...
    沈念sama閱讀 40,302評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼竿屹,長吁一口氣:“原來是場噩夢啊……” “哼报强!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起羔沙,我...
    開封第一講書人閱讀 39,216評論 0 276
  • 序言:老撾萬榮一對情侶失蹤躺涝,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后扼雏,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體坚嗜,經(jīng)...
    沈念sama閱讀 45,661評論 1 314
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,851評論 3 336
  • 正文 我和宋清朗相戀三年诗充,在試婚紗的時候發(fā)現(xiàn)自己被綠了苍蔬。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 39,977評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡蝴蜓,死狀恐怖碟绑,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情茎匠,我是刑警寧澤格仲,帶...
    沈念sama閱讀 35,697評論 5 347
  • 正文 年R本政府宣布,位于F島的核電站诵冒,受9級特大地震影響凯肋,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜汽馋,卻給世界環(huán)境...
    茶點故事閱讀 41,306評論 3 330
  • 文/蒙蒙 一侮东、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧豹芯,春花似錦悄雅、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,898評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至,卻和暖如春容诬,著一層夾襖步出監(jiān)牢的瞬間围辙,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,019評論 1 270
  • 我被黑心中介騙來泰國打工放案, 沒想到剛下飛機(jī)就差點兒被人妖公主榨干…… 1. 我叫王不留姚建,地道東北人。 一個月前我還...
    沈念sama閱讀 48,138評論 3 370
  • 正文 我出身青樓吱殉,卻偏偏與公主長得像掸冤,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子友雳,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 44,927評論 2 355

推薦閱讀更多精彩內(nèi)容