# ZJUPeng's picture
# add continuous
# d6682b6
# import torch
# from .utils import parent_module, brackets_to_periods
# import transformers
# import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
# def euc(query, key):
# # Euclidean distance
# if len(key.shape) < 2:
# key = key.view(1, -1)
# return torch.cdist(key, query, p=2)
# def perturb_values(chosen_value, num_pert, device):
# # Create a bunch of noised versions of the value, then create batch, then train value
# chosen_value = chosen_value
# noise = torch.normal(0, 1, chosen_value.shape, device=device)
# noise[0] = noise[0]*0
# noise.requires_grad = True
# chosen_value = chosen_value + noise
# return chosen_value
# class GRACE(torch.nn.Module):
# def __init__(self, config, model, device):
# super(GRACE, self).__init__()
# self.config = config
# self.log_dict = {}
# self.model = model
# # self.tokenizer = model.tokenizer
# layer = config.inner_params[0]
# self.device = device
# # --- ensure proper formatting (GRACE edits ~layers~ not weights matrices) ---
# suffixes = [".weight", ".bias"]
# self.layer = layer.rsplit(".", 1)[0] if any(layer.endswith(x) for x in suffixes) else layer
# for n, p in self.model.named_parameters():
# p.requires_grad = False
# if isinstance(self.model, transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel):
# transpose = False
# else:
# transpose = True
# # --- Add GRACE to chosen layers ---
# edit_module = parent_module(self.model, brackets_to_periods(self.layer))
# layer_name = self.layer.rsplit(".", 1)[-1]
# original_layer = getattr(edit_module, layer_name)
# if type(original_layer) is not GRACEAdapter:
# setattr(edit_module, layer_name, GRACEAdapter(config, original_layer, transpose=transpose).to(self.device))
# def __call__(self, **kwargs):
# # if self.config.task == "hallucination":
# # print(kwargs)
# # key_id = (kwargs["labels"] == -100).sum() - 1
# # setattr(eval(f"self.model.{self.layer}"), "key_id", key_id) # Tell GRACE which token to use for its query (default is the last token)
# return self.model(**kwargs)
# def generate(self, *args, **kwargs):
# setattr(eval(f"self.model.{self.layer}"), "key_id", -1)
# return self.model.generate(*args, **kwargs)
# def edit(self, config, tokens):
# key_id = (tokens["labels"] == -100).sum() - 1
# setattr(eval(f"self.model.{self.layer}"), "key_id", key_id)
# # --- pass edit label, training mode, and key_id into GRACE ---
# setattr(eval(f"self.model.{self.layer}"), "training", True)
# setattr(eval(f"self.model.{self.layer}"), "edit_label", tokens["labels"])
# self.losses = []
# # --- train GRACE value ---
# for i in range(config.n_iter):
# # --- insert iteration into each layer (only initiate keys on iteration 1) ---
# setattr(eval(f"self.model.{self.layer}"), "iter", i)
# # --- pass tokens through model (including through the GRACE layer) ---
# outputs = self.model(**tokens)
# if i == 0:
# # --- we only need to create an optimizer for the first iteration (but forward pass instantiates the key, so optimzer is passed after first inference) ---
# optimizer = torch.optim.Adam(self.model.parameters(), config.edit_lr)
# loss = outputs.loss
# loss.backward()
# optimizer.step()
# optimizer.zero_grad()
# self.losses.append(loss.detach().cpu().numpy())
# self.loss = loss # Log final loss
# # --- pull out info we want to log from the GRACE layer ---
# setattr(eval(f"self.model.{self.layer}"), "training", False)
# chosen_key = getattr(eval(f"self.model.{self.layer}"), "chosen_key")
# nkeys = len(getattr(eval(f"self.model.{self.layer}"), "keys"))
# self.log_dict["chosen_key"] = chosen_key
# self.log_dict["nkeys"] = nkeys
# class GRACEAdapter(torch.nn.Module):
# def __init__(self, config, layer, transpose):
# super(GRACEAdapter, self).__init__()
# self.layer = layer
# self.weight = self.layer.weight
# self.init_epsilon = config.eps
# self.dist_fn = config.dist_fn
# self.replacement = config.replacement
# self.device = layer.weight.device
# self.config = config
# self.num_pert = config.num_pert
# self.key_id = -1
# self.ensure_replace_token_loc = False
# if transpose:
# self.key_shape = layer.weight.shape[1]
# self.value_shape = layer.weight.shape[0]
# else:
# self.key_shape = layer.weight.shape[0]
# self.value_shape = layer.weight.shape[1]
# self.training = False
# def add_key(self, new_key, new_value):
# keys = torch.vstack([self.keys, new_key.detach()]) # Add new key to list of keys
# values = torch.nn.Parameter(torch.vstack([self.values, new_value]), requires_grad=True) # Add new value to list of values
# new_epsilon = torch.tensor(self.init_epsilon, device=self.device).view(1)
# epsilons = torch.vstack([self.epsilons, new_epsilon]) # Add new epsilon to list of epsilons
# key_labels = self.key_labels + [self.edit_label] # Add new key_label to list of key_labels
# return keys, values, epsilons, key_labels
# def init_key_value(self, query, value):
# key = query.detach()
# epsilon = torch.tensor(self.init_epsilon, device=self.device, requires_grad=False).view(1)
# key_label = [self.edit_label]
# return key, value, epsilon, key_label
# def label_match(self, edit_label, key_label):
# return edit_label.float().mean() == key_label.float().mean()
# def split_epsilons_in_half(self, nearest_key, smallest_distance):
# self.epsilons[nearest_key] = (smallest_distance / 2) - 1e-5 # Cut nearest epsilon in half
# self.epsilons[-1] = smallest_distance / 2 # Cut new epsilon in half
# def forward(self, *args):
# # Run layer forward and save what it would have returned for this instance
# layer_out = self.layer(*args)
# ### If training, we need to modify the codebook
# if (not self.training) & ('keys' not in self.__dict__):
# # If it's not training time and we haven't added any keys yet (this is before doing any editing)
# # print(self.__dict__)
# return layer_out
# else:
# if not self.training and not self.ensure_replace_token_loc and self.key_id == -1:
# token_to_edit = args[0].shape[1]-1
# self.key_id = args[0].shape[1]-1
# self.ensure_replace_token_loc = True
# else:
# token_to_edit = min(self.key_id, args[0].shape[1]-1) # args[0].shape[1] - 1 is sequence length
# query = args[0][:, token_to_edit, :] # Just use activation for last token
# if self.config.val_init == "cold":
# new_value = torch.nn.Parameter(torch.rand(1, self.value_shape, requires_grad=True, device=self.device))
# elif self.config.val_init == "warm":
# new_value = torch.nn.Parameter(layer_out[:, token_to_edit, :].detach(), requires_grad=True)
# if 'keys' not in self.__dict__:
# # If no keys exist, initialize keys, values, epsilons, and key labels
# self.keys, self.values, self.epsilons, self.key_labels = self.init_key_value(query, new_value)
# elif self.iter == 0:
# # Keys exist, so we have decide whether or not to update them (the fact that we've made it to this point means there was an error!)
# # --- search through keys for a match for query ---
# dists = torch.cdist(self.keys, query, p=2).view(-1, len(query))
# smallest_distance, nearest_key = dists.min(0)
# if smallest_distance > (self.init_epsilon + self.epsilons[nearest_key]):
# # If there's no close key, make a new key
# self.keys, self.values, self.epsilons, self.key_labels = self.add_key(query, new_value)
# else:
# # If there is a close key, we need to handle conflicts
# if not self.label_match(self.edit_label, self.key_labels[nearest_key]):
# self.keys, self.values, self.epsilons, self.key_labels = self.add_key(query, new_value)
# self.split_epsilons_in_half(nearest_key, smallest_distance)
# else:
# # If the current label is the SAME as the nearest label, just make the nearest epsilon bigger
# if smallest_distance > self.epsilons[nearest_key]:
# if self.config.eps_expand== "coverage":
# self.epsilons[nearest_key] = smallest_distance # Replace nearest epsilon with dist between old key and new key
# elif self.config.eps_expand == "moving_average":
# a = 0.5
# self.keys[nearest_key] = a*self.keys[nearest_key] + (1-a)*query # Move old key to be halfway between
# self.epsilons[nearest_key] = smallest_distance
# # self.epsilons[nearest_key] = smallest_distance + self.init_epsilon
# else:
# # If not iter 0, we don't need to change keys, we just need to learn the value
# pass
# # print(token_to_edit)
# # compute distance from query to all keys and find the closest keys
# dists = torch.cdist(self.keys, query, p=2).view(-1, len(query))
# smallest_dist, self.chosen_key = dists.min(0)
# smallest_dist = smallest_dist.view(-1, 1)
# chosen_value = self.values[self.chosen_key]
# eps = self.epsilons[self.chosen_key].view(-1, 1)
# if (self.config.val_train == "adv") and (self.training):
# chosen_value = perturb_values(chosen_value, self.num_pert, self.device)
# if self.replacement == "replace_all":
# layer_out = torch.where((smallest_dist <= eps).view(-1, 1, 1), chosen_value.unsqueeze(1).repeat_interleave(layer_out.shape[1], 1), layer_out)
# elif self.replacement == "replace_last":
# layer_out[:, token_to_edit] = torch.where((smallest_dist <= eps), chosen_value, layer_out[:, token_to_edit])
# elif self.replacement == "replace_prompt":
# layer_out[:, :token_to_edit] = torch.where((smallest_dist <= eps), chosen_value, layer_out[:, :token_to_edit])
# else:
# print("token replacement choice not found")
# return layer_out
import copy
import logging
import os

import torch
import transformers

from .utils import parent_module, brackets_to_periods
# NOTE(review): this runs at import time and affects the whole process. It forces CUDA
# kernel launches to be synchronous (so CUDA errors surface at the failing call), which
# is a debugging aid that can noticeably slow GPU execution. It only takes effect if CUDA
# has not been initialised yet in this process. Consider removing outside of debugging.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
def euc(query, key):
    """Return the Euclidean (p=2) distance matrix between rows of ``key`` and ``query``.

    A 1-D ``key`` is first promoted to a single-row matrix, so the result has one
    row per key and one column per query row.
    """
    promoted = key if key.dim() >= 2 else key.view(1, -1)
    return torch.cdist(promoted, query, p=2)
def perturb_values(chosen_value, num_pert, device):
    """Return ``chosen_value`` plus standard-normal noise, leaving index 0 unperturbed.

    The noise has the same shape as ``chosen_value`` and is drawn on ``device``.
    It is marked ``requires_grad`` so the sum is tracked by autograd.
    ``num_pert`` is accepted for interface compatibility but is not used.
    """
    perturbation = torch.normal(0, 1, chosen_value.shape, device=device)
    perturbation[0].zero_()  # first entry stays the clean value
    perturbation.requires_grad = True
    return chosen_value + perturbation
class GRACE(torch.nn.Module):
    """Attach a GRACEAdapter codebook to one layer of ``model`` and train it per edit.

    The layer to wrap is ``config.inner_params[0]``. A trailing ``.weight`` or
    ``.bias`` is stripped because GRACE wraps whole layers, not weight matrices.
    Every parameter of ``model`` is frozen, so :meth:`edit` only trains the
    adapter's codebook values.
    """

    def __init__(self, config, model, device):
        super(GRACE, self).__init__()
        self.config = config
        self.log_dict = {}
        self.model = model
        layer = config.inner_params[0]
        self.device = device
        self.original_layer = None

        # --- ensure proper formatting (GRACE edits ~layers~ not weights matrices) ---
        suffixes = [".weight", ".bias"]
        self.layer = layer.rsplit(".", 1)[0] if any(layer.endswith(x) for x in suffixes) else layer

        # Freeze the base model. Only the codebook values the adapter creates are trained.
        for p in self.model.parameters():
            p.requires_grad = False

        # GRACEAdapter infers key/value sizes from the weight layout. GPT-2 (Conv1D,
        # weight stored as (in, out)) is special-cased. Every other model is assumed
        # to use nn.Linear-style (out, in) weights.
        if isinstance(self.model, transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel):
            transpose = False
        else:
            transpose = True

        # --- Add GRACE to chosen layers ---
        self.edit_module = parent_module(self.model, brackets_to_periods(self.layer))
        self.layer_name = self.layer.rsplit(".", 1)[-1]
        original_layer = getattr(self.edit_module, self.layer_name)
        if type(original_layer) is not GRACEAdapter:
            setattr(self.edit_module, self.layer_name,
                    GRACEAdapter(config, original_layer, transpose=transpose).to(self.device))
            self.original_layer = copy.deepcopy(original_layer)

    def __call__(self, **kwargs):
        """Forward keyword inputs straight to the wrapped model."""
        return self.model(**kwargs)

    def _grace_layer(self):
        """Return whatever module currently sits at the edited position.

        This is normally the GRACEAdapter, or the restored layer after :meth:`reset_layer`.
        It resolves the same module as ``self.model.<self.layer>`` without ``eval``.
        """
        return getattr(self.edit_module, self.layer_name)

    def get_adapter_layer(self):
        """Return the installed GRACEAdapter.

        Raises:
            AssertionError: if the edited position does not hold a GRACEAdapter.
        """
        adapter_layer = self._grace_layer()
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        if type(adapter_layer) is not GRACEAdapter:
            raise AssertionError('Adapter Layer is not added correctly....')
        return adapter_layer

    def reset_layer(self):
        """Replace the adapter with the copy of the layer it saved before any edit."""
        setattr(self.edit_module, self.layer_name, self.get_adapter_layer().original_layer)

    def generate(self, *args, **kwargs):
        """Generate with the wrapped model.

        ``key_id`` is reset to -1 first, so the adapter re-derives the query position
        from the input length on its first inference call.
        """
        self._grace_layer().key_id = -1
        return self.model.generate(*args, **kwargs)

    def edit(self, config, tokens):
        """Train the adapter's codebook on one edit example.

        Args:
            config: must provide ``n_iter`` (optimisation steps) and ``edit_lr`` (Adam lr).
            tokens: keyword inputs for ``self.model``, including ``labels``.

        Side effects: sets ``self.losses``, ``self.loss`` and the ``chosen_key``/``nkeys``
        entries of ``self.log_dict``. Leaves the adapter with ``training`` False.
        """
        adapter = self._grace_layer()
        # Query position = (number of -100 labels) - 1. This presumably assumes the prompt
        # tokens are masked with -100 and come first (HF convention). TODO confirm.
        adapter.key_id = (tokens["labels"] == -100).sum() - 1
        # --- pass edit label, training mode, and key_id into GRACE ---
        adapter.training = True
        adapter.edit_label = tokens["labels"]
        self.losses = []
        # --- train GRACE value ---
        for i in range(config.n_iter):
            # The adapter only creates/adjusts codebook keys on iteration 0.
            adapter.iter = i
            outputs = self.model(**tokens)
            if i == 0:
                # The first forward pass creates the codebook parameters, so the
                # optimizer can only be built after it.
                optimizer = torch.optim.Adam(self.model.parameters(), config.edit_lr)
            loss = outputs.loss
            try:
                loss.backward()
                optimizer.step()
                self.losses.append(loss.detach().cpu().numpy())
            except Exception as e:
                # Best-effort: a failed step is skipped, but no longer silently.
                logging.getLogger(__name__).warning(
                    "GRACE edit iteration %d: optimisation step skipped (%s: %s)",
                    i, type(e).__name__, e)
            finally:
                # Never carry gradients from a failed step into the next iteration.
                optimizer.zero_grad()
            self.loss = loss  # Log final loss
        # --- pull out info we want to log from the GRACE layer ---
        adapter.training = False
        self.log_dict["chosen_key"] = adapter.chosen_key
        self.log_dict["nkeys"] = len(adapter.keys)
class GRACEAdapter(torch.nn.Module):
    """Codebook (key/value/epsilon) wrapper around a single layer, as used by GRACE.

    The wrapped layer always runs first. Once a codebook exists, the layer *input*
    at position ``key_id`` is used as a query. If it lies within ``epsilon`` of its
    nearest key, the output is overwritten with that key's value according to
    ``config.replacement``. The codebook is created or adjusted only while
    ``self.training`` is True (on iteration ``self.iter == 0``). At inference it is
    only read.
    """

    def __init__(self, config, layer, transpose):
        super(GRACEAdapter, self).__init__()
        self.layer = layer
        # Deep copy of the wrapped layer taken before any edit, so it can be restored.
        self.original_layer = copy.deepcopy(self.layer)
        self.weight = self.layer.weight
        self.init_epsilon = config.eps
        self.dist_fn = config.dist_fn
        self.replacement = config.replacement
        self.device = layer.weight.device
        self.config = config
        self.num_pert = config.num_pert
        # -1 means "not set yet". The first inference forward then locks onto the last position.
        self.key_id = -1
        if transpose:
            # weight laid out as (out, in)
            self.key_shape = layer.weight.shape[1]
            self.value_shape = layer.weight.shape[0]
        else:
            # weight laid out as (in, out)
            self.key_shape = layer.weight.shape[0]
            self.value_shape = layer.weight.shape[1]
        self.training = False

    def add_key(self, new_key, new_value):
        """Return the codebook extended by one entry.

        The new key is detached, ``init_epsilon`` becomes its radius, and the current
        ``edit_label`` becomes its label. Values are re-wrapped as a fresh trainable Parameter.
        """
        keys = torch.vstack([self.keys, new_key.detach()])  # Add new key to list of keys
        values = torch.nn.Parameter(torch.vstack([self.values, new_value]),
                                    requires_grad=True)  # Add new value to list of values
        new_epsilon = torch.tensor(self.init_epsilon, device=self.device).view(1)
        epsilons = torch.vstack([self.epsilons, new_epsilon])  # Add new epsilon to list of epsilons
        key_labels = self.key_labels + [self.edit_label]  # Add new key_label to list of key_labels
        return keys, values, epsilons, key_labels

    def init_key_value(self, query, value):
        """Return a one-entry codebook built from ``query`` (detached) and ``value``."""
        key = query.detach()
        epsilon = torch.tensor(self.init_epsilon, device=self.device, requires_grad=False).view(1)
        key_label = [self.edit_label]
        return key, value, epsilon, key_label

    def label_match(self, edit_label, key_label):
        """Treat two label tensors as equal when their float means are equal."""
        return edit_label.float().mean() == key_label.float().mean()

    def split_epsilons_in_half(self, nearest_key, smallest_distance):
        """Shrink the conflicting and the newest radius to about half their distance so they don't overlap."""
        self.epsilons[nearest_key] = (smallest_distance / 2) - 1e-5  # Cut nearest epsilon in half
        self.epsilons[-1] = smallest_distance / 2  # Cut new epsilon in half

    def _make_new_value(self, layer_out, token_to_edit):
        """Build a trainable candidate value according to ``config.val_init``.

        Raises:
            ValueError: if ``val_init`` is neither ``"cold"`` nor ``"warm"``.
        """
        if self.config.val_init == "cold":
            return torch.nn.Parameter(torch.rand(1, self.value_shape, requires_grad=True, device=self.device))
        if self.config.val_init == "warm":
            return torch.nn.Parameter(layer_out[:, token_to_edit, :].detach(), requires_grad=True)
        raise ValueError(f"Unknown val_init {self.config.val_init!r}; expected 'cold' or 'warm'")

    def _update_codebook(self, query, new_value):
        """Create the codebook, or on iteration 0 add/adjust an entry for ``query``."""
        if 'keys' not in self.__dict__:
            # No keys yet: initialize keys, values, epsilons and key labels.
            self.keys, self.values, self.epsilons, self.key_labels = self.init_key_value(query, new_value)
            return
        if self.iter != 0:
            # Later iterations only train the value, the keys stay fixed.
            return
        dists = torch.cdist(self.keys, query, p=2).view(-1, len(query))
        smallest_distance, nearest_key = dists.min(0)
        if smallest_distance > (self.init_epsilon + self.epsilons[nearest_key]):
            # No nearby key: make a new one.
            self.keys, self.values, self.epsilons, self.key_labels = self.add_key(query, new_value)
        elif not self.label_match(self.edit_label, self.key_labels[nearest_key]):
            # Nearby key with a different label: add a key and split the radii.
            self.keys, self.values, self.epsilons, self.key_labels = self.add_key(query, new_value)
            self.split_epsilons_in_half(nearest_key, smallest_distance)
        elif smallest_distance > self.epsilons[nearest_key]:
            # Same label just outside the radius: grow the nearest entry.
            if self.config.eps_expand == "coverage":
                self.epsilons[nearest_key] = smallest_distance
            elif self.config.eps_expand == "moving_average":
                a = 0.5
                self.keys[nearest_key] = a * self.keys[nearest_key] + (1 - a) * query
                self.epsilons[nearest_key] = smallest_distance

    def forward(self, *args):
        """Run the wrapped layer, then substitute codebook values where the query matches."""
        layer_out = self.layer(*args)
        if not self.training and 'keys' not in self.__dict__:
            # Nothing has been edited yet: behave exactly like the wrapped layer.
            return layer_out

        last_pos = args[0].shape[1] - 1
        if not self.training and self.key_id == -1:
            # First inference call since key_id was reset: use the last position from now on.
            self.key_id = last_pos
        token_to_edit = min(self.key_id, last_pos)
        query = args[0][:, token_to_edit, :]

        if self.training:
            # The codebook is only ever modified during an edit, never at inference.
            new_value = self._make_new_value(layer_out, token_to_edit)
            self._update_codebook(query, new_value)

        # compute distance from query to all keys and find the closest keys
        dists = torch.cdist(self.keys, query, p=2).view(-1, len(query))
        smallest_dist, self.chosen_key = dists.min(0)
        smallest_dist = smallest_dist.view(-1, 1)
        chosen_value = self.values[self.chosen_key]
        eps = self.epsilons[self.chosen_key].view(-1, 1)
        if (self.config.val_train == "adv") and (self.training):
            chosen_value = perturb_values(chosen_value, self.num_pert, self.device)
        hit = smallest_dist <= eps
        if self.replacement == "replace_all":
            layer_out = torch.where(hit.view(-1, 1, 1),
                                    chosen_value.unsqueeze(1).repeat_interleave(layer_out.shape[1], 1), layer_out)
        elif self.replacement == "replace_last":
            layer_out[:, token_to_edit] = torch.where(hit, chosen_value, layer_out[:, token_to_edit])
        elif self.replacement == "replace_prompt":
            layer_out[:, :token_to_edit] = torch.where(hit, chosen_value, layer_out[:, :token_to_edit])
        else:
            print("token replacement choice not found")
        return layer_out