import re
from collections import defaultdict
import numpy as np
import torch
from contextlib import AbstractContextManager


# helper functions
def item(x):
    # Convert a length-1 array-like (list, ndarray, tensor) to a Python scalar.
    return np.array(x).item()

def _prompt_to_parts(prompt, repeat=5):
    # To allow easy formatting of prompts, we take string prompts
    # in the format "[INST] [X] [/INST] Sure, I'll summarize this"
    # and split them into a list of strings and placeholder indices:
    # ["[INST]", 0, 0, 0, 0, 0, " [/INST] Sure, I'll summarize this"].
    # Each occurrence of [X] is replaced by `repeat` copies of its placeholder
    # index (the first [X] becomes 0, the second 1, and so on). This is in line
    # with the SELFIE paper, where each interpreted token is inserted 5 times,
    # presumably to make the interpretation less likely to ignore it.
    
    split_prompt = re.split(r' *\[X\]', prompt)
    parts = []
    for i in range(len(split_prompt)):
        cur_part = split_prompt[i]
        if cur_part != '':
            # successive [X]s leave an empty string '' between them in split_prompt
            parts.append(cur_part)
        if i < len(split_prompt) - 1:
            parts.extend([i] * repeat)
    print('Prompt parts:', parts)
    return parts
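
# Illustrative examples (repeat=2 for brevity):
#   _prompt_to_parts("[INST] [X] [/INST]", repeat=2) -> ['[INST]', 0, 0, ' [/INST]']
#   _prompt_to_parts("[X] and [X]", repeat=2)        -> [0, 0, ' and', 1, 1]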


class Hook(AbstractContextManager):
    # Hook could easily be absorbed into SubstitutionHook, but keeping them
    # separate reads better: Hook owns the register/remove lifecycle, and
    # SubstitutionHook owns the substitution logic.
    def __init__(self, module, fn):
        self.registered_hook = module.register_forward_hook(fn)
    
    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()

    def close(self):
        self.registered_hook.remove()
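
# Example (illustrative): temporarily print a layer's output shape during a
# forward pass; 'model.layers.0' is a hypothetical submodule path.
#   with Hook(model.get_submodule('model.layers.0'),
#             lambda mod, inp, out: print(out[0].shape)):
#       model(input_ids)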


class SubstitutionHook(Hook):
    # This is where the substitution takes place; it is used by InterpretationPrompt below.
    # positions_dict maps each placeholder id to the token positions it occupies,
    # and values_dict maps those same ids to the hidden states to inject there
    # (one tensor of shape batch_size x hidden_dim per placeholder).
    def __init__(self, module, positions_dict, values_dict):
        assert set(positions_dict.keys()) == set(values_dict.keys())
        keys = positions_dict.keys()
        
        def fn(module, input, output):
            device = output[0].device
            dtype = output[0].dtype
            
            for key in keys:
                num_positions = len(positions_dict[key])
                values = values_dict[key].unsqueeze(1).expand(-1, num_positions, -1)  # batch_size x num_positions x hidden_dim
                positions = positions_dict[key]
                print(f'{positions=} {values.shape=} {output[0].shape=}')
                output[0][:, positions, :] = values.to(dtype).to(device)
            # With use_cache=True generation, only the first forward pass sees the
            # full prompt (later steps process one token at a time), so the hook
            # must fire exactly once and then remove itself.
            self.registered_hook.remove()
            return output

        super().__init__(module, fn)
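
# Example (illustrative): inject one vector at token positions 3..7 of a layer's
# output, where `layer` is some decoder layer and hidden_dim matches the model:
#   h = torch.randn(1, hidden_dim)
#   SubstitutionHook(layer, positions_dict={0: [3, 4, 5, 6, 7]}, values_dict={0: h})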
        
        
class InterpretationPrompt:
    # Builds the interpretation prompt once, then injects hidden states at the
    # placeholder positions during generation via SubstitutionHook.
    def __init__(self, tokenizer, prompt, placeholder_token=' '):
        prompt_parts = _prompt_to_parts(prompt)
        if placeholder_token is None:
            placeholder_token_id = tokenizer.eos_token_id
        else:
            placeholder_token_id = item(tokenizer.encode(placeholder_token, add_special_tokens=False))
            assert placeholder_token_id != tokenizer.eos_token_id
        # self.tokens is the tokenized prompt with a placeholder token id at every
        # position an [X] occupied; self.placeholders maps each placeholder id to
        # those positions, e.g. {0: [3, 4, 5, 6, 7]}.
        self.tokens = []
        self.placeholders = defaultdict(list)
        for part in prompt_parts:
            if isinstance(part, str):
                self.tokens.extend(tokenizer.encode(part, add_special_tokens=False))
            elif isinstance(part, int):
                self.placeholders[part].append(len(self.tokens))
                self.tokens.append(placeholder_token_id)
            else:
                raise NotImplementedError

    def generate(self, model, embeds, k, layers_format='model.layers.{k}', **generation_kwargs):
        # embeds maps each placeholder id to a batch_size x hidden_dim tensor of
        # hidden states to substitute at layer k.
        num_seqs = len(embeds[0])  # batch size; assumes placeholder 0 exists
        tokens_batch = torch.tensor([self.tokens[:] for _ in range(num_seqs)]).to(model.device)
        module = model.get_submodule(layers_format.format(k=k))
        with SubstitutionHook(module, positions_dict=self.placeholders, values_dict=embeds):
            generated = model.generate(tokens_batch, **generation_kwargs)
        return generated
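

if __name__ == '__main__':
    # Minimal end-to-end sketch. The model name, prompt, and layer indices below
    # are illustrative assumptions, not prescribed by this module; any HuggingFace
    # causal LM with a `model.layers.{k}` layout should work.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = 'meta-llama/Llama-2-7b-chat-hf'  # hypothetical choice
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Capture a hidden state to interpret: the last-token activation of "dog"
    # at (arbitrarily chosen) layer 5, recorded with the Hook class above.
    captured = {}
    def capture(module, input, output):
        captured['h'] = output[0][:, -1, :]  # 1 x hidden_dim

    input_ids = tokenizer('dog', return_tensors='pt').input_ids
    with Hook(model.get_submodule('model.layers.5'), capture), torch.no_grad():
        model(input_ids)

    interp = InterpretationPrompt(tokenizer, "[INST] [X] [/INST] Sure, I'll summarize this")
    # Substitute the captured state at placeholder 0's positions in layer 2.
    out = interp.generate(model, {0: captured['h']}, k=2, max_new_tokens=30)
    print(tokenizer.batch_decode(out))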