PPO (Proximal Policy Optimization) is the classic RLHF algorithm, and RLOO (REINFORCE Leave-One-Out) is an algorithm that builds on and simplifies the PPO recipe. TRL provides implementations of both as PPOTrainer and RLOOTrainer. Below we walk through where the two agree and where they differ.
1. Models
PPO needs to load four models: 1) the policy model, 2) the reference policy model, 3) the reward model, and 4) the value model. RLOO has no value model and uses only the other three, so in terms of GPU memory RLOO is strictly cheaper than PPO.
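As a rough sketch of what gets loaded (the checkpoint name and variable names below are placeholders, not TRL's actual loading code), the four models could look like this, with RLOO simply omitting the last one:

# Minimal sketch of the four models (checkpoint name is a placeholder)
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification

policy = AutoModelForCausalLM.from_pretrained("gpt2")                                    # trained
ref_policy = AutoModelForCausalLM.from_pretrained("gpt2")                                # frozen, anchors the KL penalty
reward_model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)  # frozen, produces the score
value_model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)   # PPO only; RLOO drops it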
PPOTrainer wraps the policy and value models into a single module: a single forward pass returns outputs from both, and both models are updated together during training.
PPOTrainer
class PolicyAndValueWrapper(nn.Module):
    def __init__(self, policy, value_model) -> None:
        super().__init__()
        self.policy = policy
        self.value_model = value_model
        # reuse the value model's backbone (e.g. `transformer` for GPT-2) as the critic trunk
        self.critic_backbone = getattr(value_model, value_model.base_model_prefix)

    def forward(self, **kwargs):
        # one forward pass yields both the policy output and the per-token value estimates
        output = self.critic_backbone(**kwargs)
        logits = self.value_model.score(output.hidden_states[-1])
        return self.policy(**kwargs), logits
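A hypothetical call makes the joint forward concrete (this reuses the wrapper above and the policy / value_model from the earlier sketch; the token ids are made up):

# Hypothetical usage of the wrapper above (token ids are made up)
import torch

model = PolicyAndValueWrapper(policy, value_model)
input_ids = torch.tensor([[464, 3290, 318, 257, 922]])
policy_output, vpred = model(input_ids=input_ids, output_hidden_states=True)
# policy_output.logits -> next-token logits of the policy
# vpred                -> per-token value estimates from the value head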
2. Computing the reward
In both methods the reward combines the environment reward, i.e. the reward model's score, with a KL-divergence penalty against the reference policy, but the two trainers compute it differently.
PPO treats every completion token as a separate action, but only the final (EOS) token position receives the actual reward-model score; the resulting reward tensor has shape [batch_size, seq_len].
PPOTrainer
# 4. compute rewards
kl = logprobs - ref_logprobs                      # per-token log-ratio against the reference policy
non_score_reward = -args.kl_coef * kl             # KL penalty at every token
rewards = non_score_reward.clone()
actual_start = torch.arange(rewards.size(0), device=rewards.device)
actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
rewards[[actual_start, actual_end]] += scores     # reward-model score added only at the last token
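A toy example with made-up numbers shows the shape: every token carries the KL penalty, and only the position of the last generated token additionally receives the reward-model score (the kl_coef, scores and sequence lengths below are invented):

# Toy illustration: batch_size=2, seq_len=4, kl_coef=0.05 (all values made up)
import torch

kl = torch.tensor([[0.1, 0.2, 0.1, 0.0],
                   [0.3, 0.1, 0.0, 0.0]])
scores = torch.tensor([1.0, -0.5])             # one scalar per completion from the reward model
last_token_idx = torch.tensor([3, 1])          # position of the last generated token
rewards = -0.05 * kl                           # KL penalty at every position -> shape [2, 4]
rewards[torch.arange(2), last_token_idx] += scores
# rewards stays [2, 4]; only positions 3 and 1 carry the score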
RLOO, in contrast, treats the whole completion as a single action and attributes the reward to the completion as a whole, so the RLOO reward is a single scalar per completion ([batch_size] rather than [batch_size, seq_len]).
RLOOTrainer
# 4. compute rewards
kl = logprobs - ref_logprobs                      # per-token log-ratio against the reference policy
non_score_reward = (-args.kl_coef * kl).sum(1)    # KL penalty summed over the whole completion
rlhf_reward = scores + non_score_reward           # one scalar reward per completion
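With the same made-up numbers as the PPO sketch above, the RLOO reward collapses to one scalar per completion:

# Toy illustration (same made-up numbers as above), kl_coef=0.05
import torch

kl = torch.tensor([[0.1, 0.2, 0.1, 0.0],
                   [0.3, 0.1, 0.0, 0.0]])
scores = torch.tensor([1.0, -0.5])
non_score_reward = (-0.05 * kl).sum(1)            # KL penalty over the whole completion -> shape [2]
rlhf_reward = scores + non_score_reward           # shape [2], one scalar per completion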
3. Computing the advantage
In PPO, the advantage is the action-value minus the state-value, A(s, a) = Q(s, a) - V(s). The advantage is estimated with Generalized Advantage Estimation (GAE), and the return used as the value model's regression target is then recovered as return = advantage + value.
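Written out, the recurrence that the loop below implements is (gamma = args.gamma, lam = args.lam, r_t the per-token reward, V_t the value prediction at step t, with V_{t+1} and A_{t+1} taken as 0 past the last generated token):

delta_t  = r_t + gamma * V_{t+1} - V_t
A_t      = delta_t + gamma * lam * A_{t+1}
return_t = A_t + V_t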
PPOTrainer
# 6. compute advantages and returns
lastgaelam = 0
advantages_reversed = []
gen_length = responses.shape[1]
for t in reversed(range(gen_length)):
    nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
    delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]   # TD error
    lastgaelam = delta + args.gamma * args.lam * lastgaelam          # GAE recursion
    advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], axis=1)
returns = advantages + values                                        # regression target for the value model
advantages = masked_whiten(advantages, ~padding_mask)                # normalize over non-padded tokens
advantages = torch.masked_fill(advantages, padding_mask, 0)
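masked_whiten is a TRL helper that normalizes the advantages to zero mean and unit variance over the non-padded positions; the gist is roughly the following (a simplified sketch, not TRL's exact implementation):

# Simplified sketch of whitening under a mask (not TRL's exact implementation)
import torch

def masked_whiten_sketch(values: torch.Tensor, mask: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    mask = mask.float()
    mean = (values * mask).sum() / mask.sum()
    var = (((values - mean) ** 2) * mask).sum() / mask.sum()
    return (values - mean) * torch.rsqrt(var + eps)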
In RLOO, the advantage is the total reward minus a "virtual" baseline. For each sample, the baseline is the average reward of the other samples drawn for the same prompt, excluding the sample itself, which is where the name Leave-One-Out comes from. Subtracting the other samples' average reward from a sample's own reward plays the same role, in expectation, as subtracting the average value of all actions from the value of the chosen action. Here rloo_k is the number of completions sampled per prompt.
RLOOTrainer
# vectorized RLOO advantages implementation
rlhf_reward = rlhf_reward.reshape(args.rloo_k, -1)                   # [rloo_k, num_prompts]
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (args.rloo_k - 1)    # mean reward of the other k-1 samples
advantages = rlhf_reward - baseline
advantages = advantages.flatten()
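A toy example with rloo_k = 4 completions of a single prompt (the rewards are made up) shows the leave-one-out baseline at work:

# Toy illustration: rloo_k=4 completions of one prompt, rewards made up
import torch

rlhf_reward = torch.tensor([[1.0], [2.0], [3.0], [6.0]])             # already reshaped to [rloo_k, num_prompts]
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (4 - 1)              # average of the *other* three samples
advantages = rlhf_reward - baseline
# baseline   ~ [[3.67], [3.33], [3.00], [2.00]]
# advantages ~ [[-2.67], [-1.33], [0.00], [4.00]]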
4. Computing the loss
Both methods use the clipped surrogate objective for the policy-model loss. PPO additionally computes a value-model loss, so in the PPO update below both the policy model and the value model receive gradients.
PPOTrainer
vf_losses1 = torch.square(vpred - mb_return)           # unclipped value loss
vf_losses2 = torch.square(vpredclipped - mb_return)    # clipped value loss
vf_loss_max = torch.max(vf_losses1, vf_losses2)
vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
vf_clipfrac = masked_mean(
    (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
)
logprobs_diff = new_logprobs - mb_logprobs
ratio = torch.exp(logprobs_diff)                       # per-token importance ratio
pg_losses = -mb_advantage * ratio
pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
pg_loss_max = torch.max(pg_losses, pg_losses2)         # clipped surrogate objective
pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
loss = pg_loss + args.vf_coef * vf_loss                # joint policy + value loss
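vpredclipped, which appears above, is produced a few lines earlier in the trainer by clipping the new value prediction so it cannot move more than cliprange_value away from the value recorded at rollout time; a self-contained sketch of that step with toy numbers:

# Sketch of the value-clipping step that produces vpredclipped (toy values)
import torch

vpred = torch.tensor([0.8, 1.5, -0.2])        # new value predictions
mb_values = torch.tensor([1.0, 1.0, 0.0])     # values recorded at rollout time
cliprange_value = 0.2
vpredclipped = torch.clamp(vpred, mb_values - cliprange_value, mb_values + cliprange_value)
# -> tensor([0.8000, 1.2000, -0.2000])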
RLOO, by contrast, only computes the policy-model loss.
RLOOTrainer
new_ratio = (new_logprobs - mb_logprobs).exp()         # per-token ratio, kept for logging stats
new_logprobs = new_logprobs.sum(1)                     # sequence-level log-probability
mb_logprobs = mb_logprobs.sum(1)
logprobs_diff = new_logprobs - mb_logprobs
ratio = torch.exp(logprobs_diff)                       # one importance ratio per completion
pg_losses = -mb_advantage * ratio
pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
pg_loss_max = torch.max(pg_losses, pg_losses2)         # clipped surrogate objective
pg_loss = pg_loss_max.mean()
loss = pg_loss                                         # no value loss term
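The practical difference is the granularity of the importance ratio: PPO clips per-token ratios of shape [batch_size, seq_len], while RLOO first sums the log-probabilities so the ratio (like the advantage) is a single number per completion. A quick shape check with random numbers:

# Toy shape check (random log-probs): batch_size=2, seq_len=3
import torch

new_logprobs = torch.randn(2, 3)
mb_logprobs = torch.randn(2, 3)
ppo_ratio = (new_logprobs - mb_logprobs).exp()                   # [2, 3] -> one ratio per token
rloo_ratio = (new_logprobs.sum(1) - mb_logprobs.sum(1)).exp()    # [2]    -> one ratio per completion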