File size: 4,191 Bytes
9dd96f2 3f8ef2d 9dd96f2 e651bb1 3f8ef2d 9dd96f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import os
import re
from collections import defaultdict
import numpy as np
import torch
from torch import nn
from contextlib import AbstractContextManager
# helper functions
def item(x):
    """Convert a single-element array-like ``x`` into a plain Python scalar.

    ``x`` is first wrapped with ``np.array``, so lists, tuples, scalars and
    NumPy arrays are all accepted. ``ndarray.item()`` raises ``ValueError``
    when the array does not hold exactly one element.
    """
    as_array = np.array(x)
    return as_array.item()
def _prompt_to_parts(prompt, repeat=5):
# In order to allow easy formatting for prompts, we take string prompts
# in the format "[INST] [X] [/INST] Sure, I'll summarize this"
# and split them into a list of strings ["[INST]", 0, 0, 0, 0, 0, " [/INST] Sure, I'll summarize this"].
# Notice how each instance of [X] is replaced by multiple 0 placeholders (according to `~repeat`).
# This is in line with the SELFIE paper, where each interpreted token is inserted 5 times, probably to make
# the interpretation less likely to avoid it.
split_prompt = re.split(r' *\[X\]', prompt)
parts = []
for i in range(len(split_prompt)):
cur_part = split_prompt[i]
if cur_part != '':
# if we have multiple [X] in procession, there will be a '' between them in split_prompt
parts.append(cur_part)
if i < len(split_prompt) - 1:
parts.extend([0] * repeat)
print('Prompt parts:', parts)
return parts
class Hook(AbstractContextManager):
    """Context manager owning one forward hook registered on ``module``.

    ``fn`` is registered immediately on construction via
    ``module.register_forward_hook``. It is removed by :meth:`close`, which
    also runs when a ``with`` block exits. Exceptions raised inside the
    ``with`` block are not suppressed (``__exit__`` returns ``None``).
    """
    # Hook could easily be absorbed into SubstitutionHook, but keeping the
    # generic register/remove logic apart from the substitution logic reads
    # better.

    def __init__(self, module, fn):
        self.registered_hook = module.register_forward_hook(fn)

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()

    def close(self):
        """Remove the registered forward hook.

        SubstitutionHook relies on a second removal being harmless: its hook
        may already have removed itself before the ``with`` block exits.
        """
        self.registered_hook.remove()


class SubstitutionHook(Hook):
    """Overwrite hidden states at given positions the first time ``module`` runs.

    This is where the substitution for InterpretationPrompt takes place. For
    every key, ``values_dict[key]`` is written into the module's hidden-state
    output at each sequence index listed in ``positions_dict[key]``.

    Args:
        module: the module whose forward output is patched in place.
        positions_dict: maps a placeholder key to a list of sequence positions.
        values_dict: maps the same keys to tensors that are expanded over
            those positions. Assumed to be (batch_size, hidden_dim); TODO
            confirm against callers.

    The module output may be either the hidden-state tensor itself or a
    tuple/sequence whose first element is the hidden-state tensor.

    The hook removes itself after its first call. This assumes generation with
    ``use_cache=True``, where only the first forward pass covers the full
    prompt and later steps process one new token at a time. With
    ``use_cache=False``, later steps would see the un-substituted prompt.
    """

    def __init__(self, module, positions_dict, values_dict):
        assert set(positions_dict.keys()) == set(values_dict.keys())
        keys = positions_dict.keys()

        def fn(mod, inputs, output):
            # Decoder layers usually return a tuple whose first element is the
            # hidden states; some modules return that tensor directly.
            hidden = output if isinstance(output, torch.Tensor) else output[0]
            for key in keys:
                positions = positions_dict[key]
                num_positions = len(positions)
                values = values_dict[key].unsqueeze(1).expand(-1, num_positions, -1)  # batch_size x num_positions x hidden_dim
                print(f'{positions=} {values.shape=} {output[0].shape=}')
                hidden[:, positions, :] = values.to(device=hidden.device, dtype=hidden.dtype)
            # One-shot: later generation steps are incremental (see class docstring).
            self.registered_hook.remove()
            return output

        super().__init__(module, fn)
# main interface: prompt construction and generation with substituted hidden states
class InterpretationPrompt:
    """Tokenized prompt template whose placeholder slots get overwritten by hidden states.

    ``prompt`` uses ``[X]`` markers (see ``_prompt_to_parts``). Each marker
    expands to several placeholder tokens. Their indices in ``self.tokens``
    are recorded per placeholder key in ``self.placeholders``, so that
    ``generate`` can substitute hidden states at exactly those positions.
    """

    def __init__(self, tokenizer, prompt, placeholder_token=' '):
        """Tokenize ``prompt`` and remember where the placeholder tokens sit.

        Args:
            tokenizer: object exposing ``encode(text, add_special_tokens=...)``
                and ``eos_token_id`` (e.g. a Hugging Face tokenizer).
            prompt: template string with ``[X]`` markers.
            placeholder_token: text whose single token id fills every
                placeholder slot. ``None`` means use ``tokenizer.eos_token_id``.

        Raises:
            ValueError: (from ``item``) if ``placeholder_token`` does not
                encode to exactly one token.
            AssertionError: if ``placeholder_token`` encodes to the EOS id.
        """
        prompt_parts = _prompt_to_parts(prompt)
        if placeholder_token is None:
            placeholder_token_id = tokenizer.eos_token_id
        else:
            placeholder_token_id = item(tokenizer.encode(placeholder_token, add_special_tokens=False))
            # EOS is only used when explicitly requested via None; an explicit
            # placeholder text that collides with EOS is rejected.
            assert placeholder_token_id != tokenizer.eos_token_id
        self.tokens = []
        # placeholder key -> list of indices into self.tokens
        self.placeholders = defaultdict(list)
        for part in prompt_parts:
            if isinstance(part, str):
                self.tokens.extend(tokenizer.encode(part, add_special_tokens=False))
            elif isinstance(part, int):
                self.placeholders[part].append(len(self.tokens))
                self.tokens.append(placeholder_token_id)
            else:
                raise NotImplementedError(f'unsupported prompt part: {part!r}')

    def generate(self, model, embeds, k, layers_format='model.layers.{k}', **generation_kwargs):
        """Generate from the prompt with placeholder hidden states replaced by ``embeds``.

        Args:
            model: model exposing ``device``, ``get_submodule`` and ``generate``.
            embeds: dict mapping each placeholder key (``0``) to a tensor whose
                length is the number of sequences to generate. Presumably
                (num_seqs, hidden_dim); TODO confirm against callers.
            k: layer index substituted into ``layers_format``.
            layers_format: ``str.format`` template naming the submodule to hook.
            **generation_kwargs: forwarded unchanged to ``model.generate``.

        Returns:
            Whatever ``model.generate`` returns for the batched prompt.
        """
        num_seqs = len(embeds[0])  # assumes the placeholder 0 exists
        # One identical row per sequence; torch.tensor copies the token ids,
        # so no per-row list copy is needed.
        tokens_batch = torch.tensor([self.tokens] * num_seqs, device=model.device)
        module = model.get_submodule(layers_format.format(k=k))
        with SubstitutionHook(module, positions_dict=self.placeholders, values_dict=embeds):
            generated = model.generate(tokens_batch, **generation_kwargs)
        return generated
|