"""WISE model-editing wrapper.

Defines :class:`WISE`, which swaps one layer of a wrapped model for a
:class:`WISEAdapter`, and the adapter itself, which keeps a trainable copy of
that layer's weight (``new_weight``) next to the untouched original layer and
chooses between them at inference time using an activation distance
(:func:`euc`).
"""
import copy
import random

import torch
from torch.nn import functional as F
from .utils import parent_module, brackets_to_periods, EarlyStopMeter, EditingMeanAct
import transformers
import numpy as np
from torch import Tensor
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from .merge import slerp, GTA, linear
import torch.nn as nn
import gc

# Merge strategies selectable through ``config.merge_alg``.
merge_dict = {
    'slerp': slerp(),
    'ties': GTA('magnitude', 'sum', normalize=True),
    'magnitude_norm': GTA('magnitude', None, normalize=True),
    'magnitude': GTA('magnitude', None, normalize=False),
    'sign': GTA(None, 'sum', normalize=True),
    'dare_ties': GTA('rescaled_random', 'sum'),
    'dare_linear': GTA('random', None),
    'linear': linear()
}

# Replay bookkeeping used by ``WISE.edit``.
# ``edit_history`` holds ``[cpu_tokens, is_used]`` pairs for the edits made
# since the last merge; ``merge_group_edit_history`` holds one such list per
# completed merge.
# NOTE(review): these are module-level globals, so they are shared by every
# WISE instance created in the same process.
edit_history = []
merge_group_edit_history = []


def _mean_of_positive(values, device=None):
    """Return the mean of the strictly positive entries of ``values``.

    When no entry is positive a scalar zero tensor is returned instead,
    moved to ``device`` if one is given.
    """
    positive = values[values > 0]
    if positive.numel() > 0:
        return torch.mean(positive)
    zero = torch.tensor(0.)
    return zero if device is None else zero.to(device)


def euc(query, key, config, act_mask=None, infer=False):
    """Distance between two activations after applying the model activation.

    The activation function is looked up in ``ACT2FN`` by
    ``config.hidden_act`` and applied to both inputs; the L2 norm of the
    difference is then taken over the last dimension.

    Args:
        query: first activation tensor.
        key: second activation tensor (same shape as ``query``).
        config: object exposing ``hidden_act``.
        act_mask: optional weights; when given, the norms are averaged over
            dimension 1 weighted by this mask.
        infer: when true and dimension 1 of the norm tensor is longer than
            100, return the mean of the per-row maxima instead.

    Returns:
        A tensor of distances: a scalar in the long-sequence inference case,
        otherwise the norms reduced over dimension 1 (masked) or over the
        last dimension (unmasked).
    """
    act_fn = ACT2FN[config.hidden_act]
    l2_norm = torch.norm(act_fn(key) - act_fn(query), dim=-1)

    if infer and l2_norm.size(1) > 100:
        # Long inputs: keep only the largest distance of each row.
        topk = torch.topk(l2_norm, k=1, largest=True)
        return topk.values.mean()

    if act_mask is not None:
        return torch.sum(l2_norm * act_mask, dim=1) / torch.sum(act_mask, dim=1)
    return torch.mean(l2_norm, dim=-1)


class WISE(torch.nn.Module):
    """Wraps a model and edits it through a :class:`WISEAdapter` layer.

    The layer named by ``config.inner_params[0]`` is replaced in place by a
    ``WISEAdapter``; every parameter of the wrapped model is frozen and only
    the adapter's ``new_weight`` is trained by :meth:`edit`.
    """

    def __init__(self, config, model, device):
        """Freeze ``model`` and install the adapter on the configured layer.

        Args:
            config: editing hyper-parameters. ``hidden_act`` is written onto
                it from the model config (``hidden_act`` or, failing that,
                ``activation_function``).
            model: the model to edit.
            device: stored on the instance as ``self.device``.
        """
        super(WISE, self).__init__()
        self.config = config
        self.model = model

        if hasattr(self.model.config, 'hidden_act'):
            self.config.hidden_act = self.model.config.hidden_act
        elif hasattr(self.model.config, 'activation_function'):
            self.config.hidden_act = self.model.config.activation_function

        layer = config.inner_params[0]
        self.device = device
        self.adapter_layer = None
        self.original_layer = None

        # --- ensure proper formatting (WISE edits weights matrices) ---
        # Strip a trailing ".weight"/".bias" so the name refers to the module.
        suffixes = [".weight", ".bias"]
        self.layer = layer.rsplit(".", 1)[0] if any(layer.endswith(x) for x in suffixes) else layer

        for _name, param in self.model.named_parameters():
            param.requires_grad = False

        # The adapter uses the addmm-with-bias forward only for GPT-2 models.
        conv1D = isinstance(self.model, transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel)

        # --- Add WISE to chosen layers ---
        self.edit_module = parent_module(self.model, brackets_to_periods(self.layer))
        self.layer_name = self.layer.rsplit(".", 1)[-1]
        adapter_layer = getattr(self.edit_module, self.layer_name)

        if type(adapter_layer) is not WISEAdapter:
            setattr(self.edit_module, self.layer_name, WISEAdapter(config, adapter_layer, conv1D=conv1D))
            self.original_layer = copy.deepcopy(adapter_layer)
            print(f"New weights successfully inserted into {layer}")

        gc.collect()
        torch.cuda.empty_cache()
        gc.collect()

    # Forward
    def __call__(self, **kwargs):
        """Run the wrapped model, merging any pending edit first.

        In non-retrieve mode, when the adapter is not currently editing and
        at least ``config.save_freq`` edits have been made, an unsaved
        ``new_weight`` is appended to the adapter memory and the memory is
        merged before the forward pass.
        """
        if not self.config.retrieve:
            adapter = self.get_adapter_layer()
            if hasattr(adapter, 'editing') and not adapter.editing:
                # final merge
                enough_edits = adapter.editing_total_cnt >= self.config.save_freq
                if not adapter.original_layer.weight.equal(adapter.new_weight) and enough_edits:
                    adapter.memory_weight.append(adapter.new_weight)
                if len(adapter.memory_weight) > 0 and enough_edits:
                    print('length of memory is ', len(adapter.memory_weight), '!!!!!!')
                    adapter.merge_weight()
        return self.model(**kwargs)

    def reset_layer(self):
        """Put the adapter's saved ``original_layer`` back into the model."""
        setattr(self.edit_module, self.layer_name, self.get_adapter_layer().original_layer)

    def get_adapter_layer(self):
        """Return the installed :class:`WISEAdapter`.

        Raises:
            AssertionError: if the configured layer is not a ``WISEAdapter``.
        """
        adapter_layer = getattr(self.edit_module, self.layer_name)
        assert type(adapter_layer) is WISEAdapter, 'Adapter Layer is not added correctly....'
        return adapter_layer

    # TODO: generation
    def generate(self, *args, **kwargs):
        """Set ``key_id = -1`` on the edited layer, then call ``model.generate``."""
        # NOTE(review): eval() on a config-derived path; safe only while
        # ``config.inner_params`` comes from a trusted source.
        setattr(eval(f"self.model.{self.layer}"), "key_id", -1)
        return self.model.generate(*args, **kwargs)

    def edit(self, config, tokens, act_mask=None, deact_mask=None):
        """Train the adapter's ``new_weight`` on one edit request.

        Args:
            config: provides ``n_iter``, ``edit_lr`` and the loss margins
                read by the activation loss.
            tokens: mapping of tensors passed to the model as keyword
                arguments; must contain ``"labels"``. The last row of the
                batch is excluded from the fine-tuning loss (see
                ``__cal_ft_loss``).
            act_mask: optional mask forwarded to the activation loss.
            deact_mask: optional mask forwarded to the activation loss.

        Side effects:
            Appends the request (moved to CPU) to the module-level
            ``edit_history``; updates the adapter's counters; saves and/or
            merges weights when ``save_freq`` / ``merge_freq`` are reached.
        """
        # for retrieve ##
        global edit_history
        global merge_group_edit_history
        edit_history.append([{f"{k1}": v1.to('cpu') for k1, v1 in tokens.items()}, False])
        # for retrieve ##

        last_prompt_token_loc = (tokens["labels"] == -100).sum(dim=-1) - 1

        # NOTE(review): eval() on a config-derived path; safe only while
        # ``config.inner_params`` comes from a trusted source. Resolved once
        # per call instead of once per attribute access.
        layer_module = eval(f"self.model.{self.layer}")
        adapter = self.get_adapter_layer()

        layer_module.training = True
        layer_module.editing = True
        adapter.set_parameter_tunable()
        # NOTE(review): this modulo runs before the ``save_freq is not None``
        # check further down, so ``save_freq=None`` raises TypeError here.
        if layer_module.editing_total_cnt % self.config.save_freq == 0:
            adapter.generate_activation_mask(self.config.mask_ratio)

        # Loop-invariant: neither the config nor merge_cnt changes while training.
        use_replay = self.config.retrieve and adapter.merge_cnt > 0 and self.config.replay

        # --- train Wise value ---
        loss_meter = EarlyStopMeter()
        optimizer = torch.optim.SGD([adapter.new_weight], config.edit_lr, weight_decay=1e-5)

        for i in range(config.n_iter):
            ft_loss = self.__cal_ft_loss(tokens, last_prompt_token_loc)

            act_loss = self.__cal_activation_loss(
                adapter.original_layer_output,
                adapter.new_weight_layer_output,
                config=config, act_mask=act_mask, deact_mask=deact_mask,
            )
            loss = ft_loss + act_loss.to(ft_loss.device)

            if loss_meter.stop():
                adapter.save_editing_activation()  # add last gradient
                break
            if i == config.n_iter - 1:
                adapter.save_editing_activation()  # add last gradient

            # for replay Appendix B.3
            if use_replay:
                memory_loss = []
                for group in merge_group_edit_history:
                    # Take the first unused sample of the group; once every
                    # sample has been used, clear the flags and start over.
                    idx = 0
                    while True:
                        memo_input, is_used = group[idx]
                        if not is_used:
                            group[idx][1] = True
                            break
                        idx += 1
                        if idx == len(group):  # re-assign
                            for entry in group:
                                entry[1] = False
                            idx = 0

                    memo_input = {f"{k1}": v1.to(self.config.device) for k1, v1 in memo_input.items()}
                    # Forward pass only to refresh the adapter's cached outputs.
                    self.model(**memo_input)

                    memory_act_loss = self.__cal_memory_neg_activation_loss(
                        adapter.original_layer_output,
                        adapter.new_weight_layer_output,
                        config=config, act_mask=act_mask, deact_mask=deact_mask,
                    )
                    memory_loss.append(memory_act_loss.to(ft_loss.device))
                    del memo_input
                neg_memo_loss = torch.stack(memory_loss).mean()
                loss += neg_memo_loss

                if len(edit_history) > 0:
                    memo_input = random.choice(edit_history)[0]
                    memo_input = {f"{k1}": v1.to(self.config.device) for k1, v1 in memo_input.items()}
                    self.model(**memo_input)

                    pos_memo_loss = self.__cal_memory_pos_activation_loss(
                        adapter.original_layer_output,
                        adapter.new_weight_layer_output,
                        config=config, act_mask=act_mask, deact_mask=deact_mask,
                    )
                    del memo_input
                    loss += pos_memo_loss.to(ft_loss.device)

            optimizer.zero_grad()

            loss.backward()
            adapter.mask_new_weight_gradient()

            if use_replay:
                print(
                    f"loss {np.round(loss.item(), 3)} = {np.round(ft_loss.item(), 3)} + {np.round(act_loss.item(), 3)} + {np.round(neg_memo_loss.item(), 3)} + {np.round(pos_memo_loss.item(), 3)}"
                )
            else:
                print(
                    f"loss {np.round(loss.item(), 3)} = {np.round(ft_loss.item(), 3)} + {np.round(act_loss.item(), 3)}"
                )

            optimizer.step()
            loss_meter.update(loss.item())

            # Exact-type check on purpose: only a plain ``float`` enables clamping.
            if type(self.config.norm_constraint) is float:
                self.__norm_constraint(self.config.norm_constraint)

        # --- pull out info we want to log from the Wise layer ---
        layer_module.editing = False
        layer_module.training = False

        editing_total_cnt = layer_module.editing_total_cnt + 1
        layer_module.editing_total_cnt = editing_total_cnt

        if self.config.save_freq is not None and editing_total_cnt % self.config.save_freq == 0:
            adapter.save_weight()
            print(f'Add New Weight to Memory...')
        if editing_total_cnt % self.config.merge_freq == 0:
            # for retrieve ##
            merge_group_edit_history.append(edit_history)
            edit_history = []
            # for retrieve ##

            adapter.merge_weight()
            print(f'Merge Weight of (New, Original) Matrix... with {self.config.merge_alg}')

    def __norm_constraint(self, norm_constraint):
        """Clamp ``new_weight`` element-wise to ``weight +/- norm_constraint``."""
        adapter = self.get_adapter_layer()
        new_weight = adapter.new_weight
        original_weight = adapter.weight
        with torch.no_grad():
            new_weight[...] = torch.clamp(
                new_weight,
                min=original_weight - norm_constraint,
                max=original_weight + norm_constraint,
            )

    def __cal_ft_loss(self, tokens, last_prompt_token_loc):
        """Negative log-likelihood of the target tokens.

        The last ``k = 1`` row of the batch is left out of the loss. For each
        remaining row, shifted label positions before
        ``last_prompt_token_loc - 1`` are set to ``-100`` and ignored.

        Returns:
            Scalar tensor: summed NLL divided by the number of positions that
            are not ignored.
        """
        k = 1
        logits = self.model(**tokens).logits
        shift_logits = logits[:-k, :-1, :].contiguous()
        shift_labels = tokens['labels'][:-k, 1:].contiguous()

        label_mask = torch.zeros_like(shift_labels, dtype=torch.bool)
        for row, col_index in enumerate(last_prompt_token_loc[:-k]):
            label_mask[row, col_index - 1:] = True
        shift_labels[~label_mask] = -100

        log_probs = -nn.functional.log_softmax(shift_logits, dim=-1)
        if shift_labels.dim() == log_probs.dim() - 1:
            shift_labels = shift_labels.unsqueeze(-1)

        padding_mask = shift_labels.eq(-100)
        # In case the ignore_index is -100, the gather will fail, so we replace
        # labels by 0. The padding_mask will ignore them in any case.
        shift_labels = torch.clamp(shift_labels, min=0)

        nll_loss = log_probs.gather(dim=-1, index=shift_labels)
        nll_loss.masked_fill_(padding_mask, 0.0)

        num_active_elements = padding_mask.numel() - padding_mask.long().sum()
        return nll_loss.sum() / num_active_elements

    def __cal_activation_loss(self, original_layer_output, new_weight_layer_output, config=None,
                              act_mask=None, deact_mask=None):
        """Margin loss on the distance between original and edited activations.

        With masks, both distances are taken over all rows but the last,
        weighted by ``act_mask`` (in scope) and ``deact_mask`` (out of
        scope). Without masks, all rows but the last are in scope and the
        last row is out of scope.

        Returns:
            Sum of three hinge terms, each the mean of its positive entries
            (zero when there are none):
            ``out - in + gamma``, ``out - alpha`` and ``beta - in``.
        """
        k = 1
        head_original = original_layer_output[:-k, ...]
        head_new = new_weight_layer_output[:-k, ...]
        if act_mask is not None:
            in_scope_dist = euc(head_original, head_new, config, act_mask=act_mask)
            out_scope_dist = euc(head_original, head_new, config, act_mask=deact_mask)
        else:
            in_scope_dist = euc(head_original, head_new, config)
            out_scope_dist = euc(original_layer_output[-k:, ...], new_weight_layer_output[-k:, ...], config)

        device = original_layer_output.device
        margin_loss = _mean_of_positive(out_scope_dist.view(-1, 1) - in_scope_dist + config.gamma, device)
        out_loss = _mean_of_positive(out_scope_dist - config.alpha, device)
        in_loss = _mean_of_positive(config.beta - in_scope_dist, device)
        return margin_loss + out_loss + in_loss

    def __cal_memory_pos_activation_loss(self, original_layer_output, new_weight_layer_output, config=None,
                                         act_mask=None, deact_mask=None):
        """Hinge loss ``20 - dist`` on a replayed sample.

        The distance is taken over all rows but the last; the masks are
        accepted for signature parity and not used.
        """
        k = 1
        in_scope_dist = euc(original_layer_output[:-k, ...], new_weight_layer_output[:-k, ...], config)
        return _mean_of_positive(20 - in_scope_dist, original_layer_output.device)

    def __cal_memory_neg_activation_loss(self, original_layer_output, new_weight_layer_output, config=None,
                                         act_mask=None, deact_mask=None):
        """Hinge loss ``dist - 5`` on a replayed sample.

        The distance is taken over all rows but the last; the masks are
        accepted for signature parity and not used.
        """
        k = 1
        in_scope_dist = euc(original_layer_output[:-k, ...], new_weight_layer_output[:-k, ...], config)
        return _mean_of_positive(in_scope_dist - 5, original_layer_output.device)


class WISEAdapter(torch.nn.Module):
    """Layer wrapper holding the original layer, a working layer and edit weights.

    Attributes set in ``__init__``:
        layer: the wrapped layer; its ``weight`` is replaced on merge in
            non-retrieve mode.
        original_layer: deep copy of ``layer`` taken at construction.
        new_weight: trainable copy of the weight, trained by ``WISE.edit``.
        memory_weight / memory_mean_act: saved edit weights and their
            activation statistics (the latter only in retrieve mode).

    Attributes set later:
        weight_mask: flat 0/1 gradient mask, set by the ``generate_*_mask``
            methods.
        original_layer_output / new_weight_layer_output: cached by
            ``forward`` while ``editing`` is true.
    """

    def __init__(self, config, layer, conv1D):
        """Wrap ``layer``.

        Args:
            config: editing configuration (read by the other methods).
            layer: module exposing a frozen ``weight``.
            conv1D: when true, ``new_weight_forward`` uses ``torch.addmm``
                with the original layer's bias instead of ``F.linear``.

        Raises:
            AssertionError: if ``layer.weight`` requires gradients.
        """
        super(WISEAdapter, self).__init__()

        self.layer = layer
        self.weight = self.layer.weight
        self.device = layer.weight.device
        self.config = config
        self.new_weight = copy.deepcopy(self.weight)
        self.original_layer = copy.deepcopy(self.layer)
        self.memory_weight = []
        self.memory_mean_act = []
        self.merge_cnt = 0  # only for retrieve
        assert not self.weight.requires_grad, 'Original Layer can not be tunable....'

        self.used_mask = None

        self.training = False
        self.editing = False
        self.conv1D = conv1D

        self.editing_mean_act = EditingMeanAct()
        self.editing_total_cnt = 0

    def set_parameter_tunable(self):
        """Enable gradients on ``new_weight``."""
        self.new_weight.requires_grad = True

    def save_weight(self):
        """Store a copy of ``new_weight`` in memory and reset it.

        ``new_weight`` is reset to a copy of the original layer's weight. In
        retrieve mode the current activation statistics are stored alongside
        and a fresh ``EditingMeanAct`` is started.
        """
        self.memory_weight.append(copy.deepcopy(self.new_weight))
        self.new_weight = copy.deepcopy(self.original_layer.weight)
        if self.config.retrieve:
            self.memory_mean_act.append(copy.deepcopy(self.editing_mean_act))
            self.editing_mean_act = EditingMeanAct()

    def merge_weight(self):
        """Merge stored edit weights using ``merge_dict[config.merge_alg]``.

        * ``save_freq`` set, non-retrieve: merge every memory weight (plus
          the current layer weight once it differs from the original) into
          ``layer.weight`` and clear the memory.
        * ``save_freq`` set, retrieve: replace the last
          ``merge_freq // save_freq`` memory weights by their merge, keep
          the smallest of their ``min_act()`` values, and bump ``merge_cnt``.
        * ``save_freq`` is ``None``: merge ``new_weight`` into
          ``layer.weight`` with coefficient 0.5.

        In every branch ``new_weight`` is reset to a copy of the original
        layer's weight.
        """
        merge_alg = merge_dict[self.config.merge_alg]
        densities = self.config.densities

        if self.config.save_freq is not None:  # for ties dare dare_ties
            if not self.config.retrieve:
                n_memory = len(self.memory_weight)
                if self.original_layer.weight.equal(self.layer.weight):
                    # Layer not merged into yet: only the memory weights take part.
                    cur_new_weight = merge_alg.execute(
                        [self.config.weights / n_memory for _ in range(n_memory)],
                        self.original_layer.weight,
                        self.memory_weight,
                        densities=densities,
                    )
                else:
                    # Layer already carries earlier merges: it takes part with
                    # coefficient 0.6, the memory weights share 0.4.
                    cur_new_weight = merge_alg.execute(
                        [0.4 / n_memory for _ in range(n_memory)] + [0.6],
                        self.original_layer.weight,
                        self.memory_weight + [self.layer.weight],
                        densities=densities,
                    )
                self.layer.weight = torch.nn.Parameter(
                    cur_new_weight.to(self.layer.weight.device), requires_grad=False
                )
                self.new_weight = copy.deepcopy(self.original_layer.weight)
                self.memory_weight = []
            else:
                merge_num = self.config.merge_freq // self.config.save_freq
                assert len(self.memory_weight) >= merge_num
                new_merge_weight = merge_alg.execute(
                    [self.config.weights / merge_num for _ in range(merge_num)],
                    self.original_layer.weight,
                    self.memory_weight[-merge_num:],
                    densities=densities,
                )
                min_a = 1e9
                for _ in range(merge_num):
                    self.memory_weight.pop()
                    edit_act = self.memory_mean_act.pop()
                    min_a = min(min_a, edit_act.min_act())
                self.new_weight = copy.deepcopy(self.original_layer.weight)
                self.memory_weight.append(new_merge_weight)
                self.memory_mean_act.append(EditingMeanAct(min_a=min_a))
                print(len(self.memory_weight))
                assert len(self.memory_mean_act) == len(self.memory_weight)
                self.merge_cnt += 1
        else:
            cur_new_weight = merge_alg.execute(0.5, self.layer.weight, [self.new_weight], densities=densities)
            self.layer.weight = torch.nn.Parameter(
                cur_new_weight.to(self.layer.weight.device), requires_grad=False
            )
            self.new_weight = copy.deepcopy(self.original_layer.weight)

    def save_editing_activation(self):
        """Record the mean distance of the cached outputs, last row excluded."""
        in_scope_dist = euc(self.original_layer_output[:-1, ...], self.new_weight_layer_output[:-1, ...], self.config)
        self.editing_mean_act.update(in_scope_dist.mean().item())

    def generate_activation_mask(self, mask_ratio):
        """Draw a random flat 0/1 mask; each entry is 1 with probability ``mask_ratio``."""
        p_grad = self.new_weight.reshape(-1)
        p_mask = np.random.choice([1, 0], size=p_grad.size()[0], p=[mask_ratio, 1 - mask_ratio])
        p_mask = torch.from_numpy(p_mask).to(p_grad.device)
        self.weight_mask = p_mask

    def generate_non_overlapping_mask(self, mask_ratio):
        """Draw a mask that avoids every position used by earlier calls.

        Raises:
            ValueError: if fewer than ``int(mask_ratio * numel)`` positions
                remain unused.
        """
        p_grad = self.new_weight.reshape(-1)
        total = p_grad.size()[0]
        mask_size = int(mask_ratio * total)
        if self.used_mask is None:
            self.used_mask = np.zeros(total, dtype=bool)
        available_indices = np.where(~self.used_mask)[0]  # indices not masked so far
        if len(available_indices) < mask_size:
            raise ValueError("Not enough unused elements to generate a new mask.")
        chosen_indices = np.random.choice(available_indices, size=mask_size, replace=False)
        mask_array = np.zeros(total, dtype=int)
        mask_array[chosen_indices] = 1
        self.used_mask[chosen_indices] = True  # remember which positions are taken
        self.weight_mask = torch.from_numpy(mask_array).to(p_grad.device)

    def new_weight_forward(self, input: Tensor, weight) -> Tensor:
        """Apply ``weight`` to ``input`` the way the wrapped layer would.

        With ``conv1D`` set, computes ``bias + input @ weight`` using the
        original layer's bias; otherwise ``F.linear(input, weight)`` with no
        bias.
        """
        if self.conv1D:
            size_out = input.size()[:-1] + (weight.size(1),)
            flat_out = torch.addmm(self.original_layer.bias, input.view(-1, input.size(-1)), weight)
            return flat_out.view(size_out)
        return F.linear(input, weight)

    def mask_new_weight_gradient(self):
        """Multiply the gradient of ``new_weight`` by ``weight_mask``.

        Raises:
            AssertionError: if ``new_weight`` has no gradient yet.
        """
        assert self.new_weight.grad is not None, 'Gradient Collection for New Weight error, gradient not found'
        # Add gradient mask after the loss updates
        grad = self.new_weight.grad
        masked = grad.reshape(-1) * self.weight_mask
        self.new_weight.grad = masked.view(grad.size()).to(grad.dtype)

    def forward(self, *args):
        """Route the input through the original, merged or edited weights.

        While ``editing``: run ``new_weight`` and the original layer, cache
        both outputs on the instance and return the ``new_weight`` output.

        Otherwise, non-retrieve mode: compare the original output with the
        merged layer (``dist1``) and with ``new_weight`` (``dist2``); return
        the original output when both distances fall below the threshold,
        else the output with the larger distance.

        Otherwise, retrieve mode: start from ``new_weight`` (or the original
        output when its distance is below the threshold), then switch to any
        memory weight whose distance exceeds both the current best and that
        memory's own threshold.
        """
        if self.editing:
            layer_out = self.new_weight_forward(*args, self.new_weight)
            self.new_weight_layer_output = layer_out
            self.original_layer_output = self.original_layer(*args)
            return layer_out

        act_ratio = self.config.act_ratio
        original_layer_output = self.original_layer(*args)

        if not self.config.retrieve:
            layer_output = self.layer(*args)
            new_weight_layer_output = self.new_weight_forward(*args, self.new_weight)
            dist2 = euc(original_layer_output, new_weight_layer_output, self.config, infer=True)
            dist1 = euc(original_layer_output, layer_output, self.config, infer=True)
            threshold = self.editing_mean_act.min_act() * act_ratio

            if dist1.item() < threshold and dist2.item() < threshold:
                return original_layer_output
            if dist1.item() > dist2.item():
                return layer_output
            return new_weight_layer_output

        new_weight_layer_output = self.new_weight_forward(*args, self.new_weight)
        min_dist = euc(original_layer_output, new_weight_layer_output, self.config, infer=True)
        threshold = self.editing_mean_act.min_act() * act_ratio
        if min_dist.item() < threshold:
            layer_out = original_layer_output
        else:
            layer_out = new_weight_layer_output

        for memory_retrieve_weight, memory_act in zip(self.memory_weight, self.memory_mean_act):
            memory_weight_layer_output = self.new_weight_forward(*args, memory_retrieve_weight)
            dist = euc(original_layer_output, memory_weight_layer_output, self.config, infer=True)
            memory_threshold = memory_act.min_act() * act_ratio
            if dist > min_dist and dist > memory_threshold:
                layer_out = memory_weight_layer_output
                min_dist = dist
            # NOTE(review): debug print placed at loop level (once per memory
            # weight); confirm against the original layout.
            print(dist, memory_threshold)
        return layer_out