import logging
from typing import Any, Dict

import torch
from torch import nn
from transformers import AutoModelForCausalLM

from llm_studio.src.metrics.text_causal_language_modeling_metrics import Perplexity
from llm_studio.src.utils.data_utils import batch_padding
from llm_studio.src.utils.modeling_utils import (
    create_nlp_backbone,
    generate,
    prepare_lora,
)

logger = logging.getLogger(__name__)
class ValueHead(nn.Module):
    """
    The ValueHead class implements a head for GPT2 that returns a scalar for each
    output token.

    Based on the implementation of the trl library:
    https://github.com/lvwerra/trl/blob/main/trl/models/modeling_value_head.py
    """

    def __init__(self, config):
        super().__init__()
        if not hasattr(config, "summary_dropout_prob"):
            summary_dropout_prob = 0.1
        else:
            summary_dropout_prob = config.summary_dropout_prob
        self.dropout = (
            nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity()
        )

        # some models such as OPT have a projection layer before the word embeddings
        # e.g. OPT-350m
        if hasattr(config, "word_embed_proj_dim"):
            hidden_size = config.word_embed_proj_dim
        else:
            hidden_size = config.hidden_size

        self.summary = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states):
        output = self.dropout(hidden_states)

        # cast to the dtype of the summary weights if needed (e.g. when the hidden
        # states arrive in half precision) to keep the head numerically stable
        if output.dtype != self.summary.weight.dtype:
            output = output.to(self.summary.weight.dtype)

        output = self.summary(output)
        return output
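
# Shape sketch for ValueHead (illustrative note, not part of the original trl code;
# hidden_size depends on the backbone config):
#   hidden_states: (batch_size, seq_len, hidden_size)
#     -> dropout:  (batch_size, seq_len, hidden_size)
#     -> summary, nn.Linear(hidden_size, 1): (batch_size, seq_len, 1)
# Model.forward below squeezes the trailing dimension to get one scalar per token.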
class Model(nn.Module):
    """
    Model for causal language modeling problem type.
    """

    def __init__(self, cfg: Any):
        """
        Args:
            cfg: config with all the hyperparameters
        """
        super(Model, self).__init__()

        self.cfg = cfg
        assert cfg.training.lora, "LoRA must be True for RLHF"

        self.backbone, self.backbone_config = create_nlp_backbone(
            cfg, model_class=AutoModelForCausalLM
        )
        self.backbone = prepare_lora(cfg=self.cfg, backbone=self.backbone)

        if self.cfg.prediction.metric == "Perplexity":
            self.perplexity = Perplexity(self.cfg, reduce=False)

        self.value_head = ValueHead(self.backbone_config)
        self.value_head.summary.bias.data.zero_()
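    # Config fields this class relies on (summary of the accesses in this file;
    # the exact option names come from the experiment cfg object passed in):
    #   cfg.training.lora                        - must be True (asserted above)
    #   cfg.prediction.metric                    - "Perplexity" enables the eval metric
    #   cfg.architecture.gradient_checkpointing  - toggles backbone use_cache in forward()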
    def forward(
        self,
        batch: Dict,
        padding: bool = True,
    ) -> Dict:
        """
        Runs the backbone on an (optionally padded) batch and returns a dict with
        logits and per-token values during training, or perplexity during evaluation.
        """
        # disable cache if gradient checkpointing is enabled
        if self.cfg.architecture.gradient_checkpointing:
            self.backbone.config.use_cache = False

        outputs: Dict = {}
        mask_key = "attention_mask"
        pad_keys = [
            "input_ids",
            "attention_mask",
            "special_tokens_mask",
            "labels",
        ]

        if padding:
            batch = batch_padding(
                self.cfg,
                batch,
                self.training,
                mask_key=mask_key,
                pad_keys=pad_keys,
            )

        output = self.backbone(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            output_hidden_states=True,
        )

        if self.cfg.prediction.metric == "Perplexity" and not self.training:
            outputs["perplexity"] = self.perplexity(output.logits, batch["labels"])

        if self.training:
            last_hidden_state = output.hidden_states[-1]

            # force upcast in fp32 if logits are in half-precision
            if output.logits.dtype != torch.float32:
                output.logits = output.logits.float()

            outputs["logits"] = output.logits
            outputs["value"] = self.value_head(last_hidden_state).squeeze(-1)

        # enable cache again if gradient checkpointing is enabled
        if self.cfg.architecture.gradient_checkpointing:
            self.backbone.config.use_cache = True

        return outputs
    def generate(self, batch: Dict, cfg: Any, streamer=None):
        return generate(self.backbone, batch, cfg, streamer)
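
# Minimal usage sketch (illustrative only; `cfg` stands for the experiment config
# expected by create_nlp_backbone/prepare_lora and `batch` for the dict of tensors
# produced by the dataset collator; both are assumptions, not defined in this file):
#
#   model = Model(cfg)
#
#   model.train()
#   out = model(batch)
#   out["logits"]   # (batch_size, seq_len, vocab_size), upcast to fp32
#   out["value"]    # (batch_size, seq_len), one scalar per token from the value head
#
#   model.eval()
#   out = model(batch)
#   out["perplexity"]  # only when cfg.prediction.metric == "Perplexity"
#
#   generated_ids = model.generate(batch, cfg)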