```python
from typing import List

import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer

BASE_MODEL = "CarperAI/stable-vicuna-13b-delta"
RM_PATH = "vicuna-v0-rm.pt"


class GPTRewardModel(nn.Module):
    def __init__(self):
        super().__init__()
        model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
        self.config = model.config
        # LLaMA-style configs expose `hidden_size`; GPT-2-style configs expose `n_embd`.
        self.config.n_embd = (
            self.config.hidden_size if hasattr(self.config, "hidden_size") else self.config.n_embd
        )
        self.transformer = model.model
        # Scalar value head: maps each token's final hidden state to a single reward.
        self.v_head = nn.Linear(self.config.n_embd, 1, bias=False)
        self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        self.PAD_ID = self.tokenizer.pad_token_id

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        labels=None,
        return_dict=False,
        output_attentions=False,
        output_hidden_states=False,
    ):
        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
        )
        hidden_states = transformer_outputs[0]

        # Per-token rewards, shape (batch_size, seq_len).
        rewards = self.v_head(hidden_states).squeeze(-1)

        # Score each sample by the reward at its last non-padding token
        # (inputs are right-padded, so the first PAD marks the end of the text).
        end_scores = []
        bs = input_ids.shape[0]
        for i in range(bs):
            c_inds = (input_ids[i] == self.PAD_ID).nonzero()
            c_ind = c_inds[0].item() if len(c_inds) > 0 else input_ids.shape[1]
            end_scores.append(rewards[i, c_ind - 1])

        chosen_end_scores = torch.stack(end_scores)
        return {"end_scores": chosen_end_scores}


rw_device = torch.device("cuda")  # reward-model device; not defined in the original snippet, adjust as needed
rw_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
rw_tokenizer.padding_side = "right"
rw_model = GPTRewardModel()
rw_model.load_state_dict(torch.load(RM_PATH)["module"])
rw_model.half()
rw_model.eval()
rw_model.to(rw_device)


def get_scores(samples: List[str]):
    """Score a list of texts with the reward model, a few samples at a time."""
    scores_list = []
    batch_size = 2
    for i in range(0, len(samples), batch_size):
        sub_samples = samples[i : i + batch_size]
        encodings_dict = rw_tokenizer(
            sub_samples,
            truncation=True,
            max_length=config.train.seq_length,  # `config` is the training config defined elsewhere in the script
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = encodings_dict["input_ids"].to(rw_device)
        attn_masks = encodings_dict["attention_mask"].to(rw_device)
        with torch.no_grad():
            sub_scores = rw_model(input_ids=input_ids, attention_mask=attn_masks)
        scores_list.append(sub_scores["end_scores"])
    scores = torch.cat(scores_list, dim=0)
    return scores
```
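
As a quick sanity check on how the pieces fit together, here is a minimal usage sketch of `get_scores`. The sample strings and their Vicuna-style prompt format are illustrative assumptions, not part of the original setup, and the call presumes the `config` object and a CUDA device are available as in the snippet above.

```python
# Hypothetical samples in an assumed Vicuna-style prompt format.
samples = [
    "### Human: What is the capital of France?### Assistant: The capital of France is Paris.",
    "### Human: What is the capital of France?### Assistant: No idea, sorry.",
]

scores = get_scores(samples)  # tensor of shape (len(samples),); higher means the RM prefers the text
print(scores)
```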