import torch
from transformers import GPTNeoXForCausalLM, AutoTokenizer

from .model_utils import Hack_no_grad
from .steers import Projected_Adaptor
from .model_base import LMSteerBase


class Switching_GPTNeoXModel(LMSteerBase):
    '''GPT-NeoX causal LM whose output embedding is wrapped in a steering
    adaptor. The base model is frozen; only the Projected_Adaptor
    parameters are trainable.
    '''

    def __init__(self, model_name, adapted_component, adaptor_class,
                 num_steers, rank, epsilon, init_var,
                 low_resource_mode):
        super().__init__()
        self.adapted_component = adapted_component
        if low_resource_mode:
            # fp16 weights and low_cpu_mem_usage halve the memory footprint
            self.model = GPTNeoXForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16, low_cpu_mem_usage=True
            )
        else:
            self.model = GPTNeoXForCausalLM.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # GPT-NeoX tokenizers ship without a pad token; reuse EOS for padding
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.init_var = init_var
        self.num_steers = num_steers
        self.device = torch.device("cpu")
        # embed_out is the LM head; its weight has shape (vocab_size, embed_dim)
        embed_dim = self.model.embed_out.weight.shape[1]
        vocab_size = self.model.embed_out.weight.shape[0]
        self.low_resource_mode = low_resource_mode

        # Freeze the base model; only the steering adaptor below is trained
        for _param in self.model.parameters():
            _param.requires_grad_(False)

        if adapted_component == "final_layer":
            # Block gradient flow through the transformer stack, then swap
            # the output embedding for the steering adaptor
            self.model.gpt_neox = Hack_no_grad(self.model.gpt_neox)
            self.steer = Projected_Adaptor(
                self.model.embed_out, adaptor_class, num_steers, embed_dim,
                vocab_size, rank, epsilon, init_var, "output")
            self.model.set_output_embeddings(self.steer)
        else:
            raise NotImplementedError()

    def generate(self, prompt, steer_values, min_length=20, max_length=100,
                 seed=None, num_beams=1, num_beam_groups=1, do_sample=True,
                 temperature=1, top_p=1):
        '''
        prompt: a string
        steer_values: steering strengths, one value per steer dimension
        min_length: minimum generation length
        max_length: maximum generation length
        seed: random seed for generation; None if not specified
        Remaining arguments follow the standard Hugging Face generation API.
        '''
        # Delegate to the shared low-resource generation path in LMSteerBase
        return super().generate_low_resource(
            prompt, steer_values, min_length, max_length, seed,
            num_beams, num_beam_groups, do_sample, temperature, top_p)
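

# --- Minimal usage sketch (illustrative; not part of the training pipeline) ---
# The values below are assumptions, not values prescribed by this module:
# "EleutherAI/pythia-70m" stands in for any GPT-NeoX checkpoint, and
# adaptor_class, rank, epsilon, and init_var must match whatever
# Projected_Adaptor in .steers actually accepts.
if __name__ == "__main__":
    model = Switching_GPTNeoXModel(
        model_name="EleutherAI/pythia-70m",  # assumed GPT-NeoX checkpoint
        adapted_component="final_layer",     # only component implemented above
        adaptor_class="multiply",            # assumed adaptor variant
        num_steers=2,
        rank=1000,
        epsilon=1e-3,
        init_var=1e-2,
        low_resource_mode=False,
    )
    # One steering strength per steer dimension (num_steers=2 here)
    output = model.generate("The weather today is", steer_values=[1.0, 0.0])
    print(output)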