# vicuna-13b-v0-rm
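
This reward model adds a scalar value head on top of the stable-vicuna 13B causal LM and scores a text by the value-head output at its last non-padding token. The snippet below loads the checkpoint and exposes a `get_scores` helper for batched scoring.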

```python
from typing import List

import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer

BASE_MODEL = "CarperAI/stable-vicuna-13b-delta"
RM_PATH = "vicuna-v0-rm.pt"
MAX_LENGTH = 1024  # assumed maximum tokenized length for scoring; adjust to match your training setup


class GPTRewardModel(nn.Module):
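    """Causal-LM backbone from BASE_MODEL with a scalar value head that assigns a per-token reward."""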
    def __init__(self):
        super().__init__()
        model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
        self.config = model.config
        self.config.n_embd = self.config.hidden_size if hasattr(self.config, "hidden_size") else self.config.n_embd
        self.transformer = model.model
        self.v_head = nn.Linear(self.config.n_embd, 1, bias=False)
        self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        self.PAD_ID = self.tokenizer.pad_token_id


    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        labels=None,
        return_dict=False,
        output_attentions=False,
        output_hidden_states=False,
    ):
        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
        )

        hidden_states = transformer_outputs[0]

        rewards = self.v_head(hidden_states).squeeze(-1)
        end_scores = []
        bs = input_ids.shape[0]

        # The reward for each sample is the value-head output at the last
        # non-padding token (or at the final position if there is no padding).
        for i in range(bs):
            c_inds = (input_ids[i] == self.PAD_ID).nonzero()
            c_ind = c_inds[0].item() if len(c_inds) > 0 else input_ids.shape[1]
            end_scores.append(rewards[i, c_ind - 1])

        chosen_end_scores = torch.stack(end_scores)
        return {"end_scores": chosen_end_scores}


rw_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
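# With right padding, pad tokens follow the real text, so the position just
# before the first pad token is the end of the sample (used for the end score).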
rw_tokenizer.padding_side = "right"
rw_model = GPTRewardModel()
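# The checkpoint is assumed to store the weights under a "module" key
# (e.g. a DeepSpeed-style state dict).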
rw_model.load_state_dict(torch.load(RM_PATH)["module"])
rw_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rw_model.half()
rw_model.eval()
rw_model.to(rw_device)


def get_scores(samples: List[str]):
    scores_list = []
    batch_size = 2
    for i in range(0, len(samples), batch_size):
        sub_samples = samples[i : i + batch_size]
        encodings_dict = rw_tokenizer(
            sub_samples,
            truncation=True,
            max_length=MAX_LENGTH,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = encodings_dict["input_ids"].to(rw_device)
        attn_masks = encodings_dict["attention_mask"].to(rw_device)
        with torch.no_grad():
            sub_scores = rw_model(input_ids=input_ids, attention_mask=attn_masks)
        scores_list.append(sub_scores["end_scores"])
    scores = torch.cat(scores_list, dim=0)
    return scores
```
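
A minimal usage sketch; the sample strings below are hypothetical, and the `### Human:` / `### Assistant:` prompt format is an assumption based on the stable-vicuna chat format:

```python
# Hypothetical example: score two candidate responses to the same prompt.
samples = [
    "### Human: What is the capital of France?### Assistant: The capital of France is Paris.",
    "### Human: What is the capital of France?### Assistant: I'm not sure.",
]
scores = get_scores(samples)
print(scores)  # one scalar per sample; a higher score means a more preferred response
```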