"""
Loss Implementation based upon
https://github.com/eric-mitchell/direct-preference-optimization
"""

import logging
from typing import Any, KeysView

import torch
import torch.nn.functional as F
from torch import nn

__all__ = ["Losses"]

logger = logging.getLogger(__name__)


class DPOLoss(nn.Module):
    """
    Implements
    "Direct Preference Optimization:
    Your Language Model is Secretly a Reward Model"
    from https://arxiv.org/abs/2305.18290
    """

    def __init__(self, cfg: Any):
        super().__init__()
        self.cfg = cfg

    def forward(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
    ):
        # Log-ratios between chosen and rejected completions for the policy
        # and the frozen reference model.
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps

        losses = self.get_losses(logits=pi_logratios - ref_logratios)
        # Implicit rewards: beta-scaled policy/reference log-ratios, detached so
        # they carry no gradients and can be used for monitoring.
        chosen_rewards = (
            self.cfg.training.beta
            * (policy_chosen_logps - reference_chosen_logps).detach()
        )
        rejected_rewards = (
            self.cfg.training.beta
            * (policy_rejected_logps - reference_rejected_logps).detach()
        )

        return losses.mean(), chosen_rewards.mean(), rejected_rewards.mean()

    def get_losses(self, logits):
        # label_smoothing = 0 recovers the standard DPO loss; a value in
        # (0, 0.5) would give the conservative DPO variant, which treats the
        # preference labels as noisy with probability label_smoothing.
        label_smoothing = 0

        losses = (
            -F.logsigmoid(self.cfg.training.beta * logits) * (1 - label_smoothing)
            - F.logsigmoid(-self.cfg.training.beta * logits) * label_smoothing
        )
        return losses


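# A minimal usage sketch (not part of the original module): the SimpleNamespace
# config below is a stand-in for the project's real cfg object and only exposes
# `training.beta`, the single field these loss classes read. The batch size and
# beta value are arbitrary.
def _example_dpo_usage() -> None:
    from types import SimpleNamespace

    cfg = SimpleNamespace(training=SimpleNamespace(beta=0.1))
    loss_fn = DPOLoss(cfg)
    # Per-sample log-probabilities of the chosen/rejected completions under the
    # trained policy and the frozen reference model (dummy values here).
    loss, chosen_rewards, rejected_rewards = loss_fn(
        policy_chosen_logps=torch.randn(4),
        policy_rejected_logps=torch.randn(4),
        reference_chosen_logps=torch.randn(4),
        reference_rejected_logps=torch.randn(4),
    )
    logger.info(
        "loss=%.4f chosen_reward=%.4f rejected_reward=%.4f",
        loss.item(),
        chosen_rewards.item(),
        rejected_rewards.item(),
    )

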
class KTOPairLoss(nn.Module):
    """
    Implements the original paired KTO loss
    Adapted from https://github.com/ContextualAI/HALOs
    and https://github.com/huggingface/trl
    """

    def __init__(self, cfg: Any):
        super().__init__()
        self.cfg = cfg

    def forward(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
    ):
        # Batch-level KL estimates, clamped at zero as in the reference
        # implementations.
        chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
        rejected_KL = (
            (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)
        )

        chosen_logratios = policy_chosen_logps - reference_chosen_logps
        rejected_logratios = policy_rejected_logps - reference_rejected_logps
        # Chosen and rejected terms are concatenated, so each batch contributes
        # 2 * batch_size per-example losses before the final mean.
        losses = torch.cat(
            (
                1
                - torch.sigmoid(
                    self.cfg.training.beta * (chosen_logratios - rejected_KL)
                ),
                1
                - torch.sigmoid(
                    self.cfg.training.beta * (chosen_KL - rejected_logratios)
                ),
            ),
            0,
        )

        chosen_rewards = (
            self.cfg.training.beta
            * (policy_chosen_logps - reference_chosen_logps).detach()
        ).float()
        rejected_rewards = (
            self.cfg.training.beta
            * (policy_rejected_logps - reference_rejected_logps).detach()
        ).float()

        return losses.mean(), chosen_rewards.mean(), rejected_rewards.mean()


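# Interchangeability sketch (not part of the original module): KTOPairLoss takes
# the same four log-probability tensors as DPOLoss and likewise returns scalar
# means, so the losses can be swapped without touching the training loop. The
# SimpleNamespace config is again a stand-in exposing only `training.beta`.
def _example_kto_pair_usage() -> None:
    from types import SimpleNamespace

    cfg = SimpleNamespace(training=SimpleNamespace(beta=0.1))
    batch = {
        name: torch.randn(8)
        for name in (
            "policy_chosen_logps",
            "policy_rejected_logps",
            "reference_chosen_logps",
            "reference_rejected_logps",
        )
    }
    loss, chosen_rewards, rejected_rewards = KTOPairLoss(cfg)(**batch)
    # All three outputs are 0-dim tensors; internally the chosen and rejected
    # terms were concatenated, so the loss mean runs over 2 * batch_size entries.
    assert loss.dim() == chosen_rewards.dim() == rejected_rewards.dim() == 0

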
class HingeLoss(DPOLoss):
    """
    Hinge loss on the beta-scaled log-ratio margin.
    """

    def get_losses(self, logits):
        losses = torch.relu(1 - self.cfg.training.beta * logits)
        return losses


class IPOLoss(DPOLoss):
    """
    Implements "A General Theoretical Paradigm
    to Understand Learning from Human Preferences"
    from https://arxiv.org/pdf/2310.12036.pdf
    """

    def get_losses(self, logits):
        # Squared loss that regresses the (policy - reference) log-ratio
        # difference towards 1 / (2 * beta), following the IPO objective.
        losses = (logits - 1 / (2 * self.cfg.training.beta)) ** 2
        return losses


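# Extension sketch (hypothetical, not part of the original module): new pairwise
# losses only need to subclass DPOLoss and override `get_losses`, which receives
# the (policy - reference) log-ratio difference; DPOLoss.forward supplies the
# reward bookkeeping. The class name and formula below are illustrative only.
class _ExampleSquaredHingeLoss(DPOLoss):
    def get_losses(self, logits):
        # Squared hinge on the beta-scaled margin.
        return torch.relu(1 - self.cfg.training.beta * logits) ** 2

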
class Losses:
    """Losses factory."""

    _losses = {
        "DPOLoss": DPOLoss,
        "HingeLoss": HingeLoss,
        "IPOLoss": IPOLoss,
        "KTOPairLoss": KTOPairLoss,
    }

    @classmethod
    def names(cls) -> KeysView:
        return cls._losses.keys()

    @classmethod
    def get(cls, name: str) -> Any:
        """Access to Losses.

        Args:
            name: loss name

        Returns:
            The loss class; unknown names fall back to DPOLoss
        """
        return cls._losses.get(name, DPOLoss)


LOSS_REDUCTION = {
    "DPOLoss": False,
    "KTOPairLoss": False,
    "HingeLoss": True,
    "IPOLoss": True,
}

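# Factory usage sketch (not part of the original module): losses are looked up by
# name, unknown names fall back to DPOLoss, and LOSS_REDUCTION carries a per-loss
# flag consumed elsewhere in the codebase (presumably selecting how per-token
# log-probabilities are reduced per sample). The SimpleNamespace config is a
# stand-in exposing only `training.beta`.
def _example_factory_usage(loss_name: str = "IPOLoss") -> None:
    from types import SimpleNamespace

    cfg = SimpleNamespace(training=SimpleNamespace(beta=0.1))
    loss_cls = Losses.get(loss_name)
    loss_fn = loss_cls(cfg)
    reduction_flag = LOSS_REDUCTION.get(loss_name, False)
    logger.info(
        "Selected %s (available: %s, reduction flag: %s)",
        type(loss_fn).__name__,
        list(Losses.names()),
        reduction_flag,
    )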