# ZoRA-Refusal-Suppression / exl2_wrapper.py
# Author: llmixer. Commit 4783804 (verified): "Added generator code".
# Hosting page notes: virus scan clean ("No virus"); file size 2.32 kB.
import os
import torch
from safetensors import safe_open
class ExLlamaV2ModuleWrapper:
    """Transparent proxy around one module of a model.

    The proxy projects a stored "suppress direction" out of the module's
    input and output.

    The model object must expose a mutable ``modules`` list and, when loading
    directions, ``config.model_dir`` (presumably an exllamav2 ``ExLlamaV2``
    model -- confirm).

    Attribute access is forwarded to the wrapped module, with three rules:

    * ``forward`` is redirected to :meth:`wrapped_forward`.
    * The wrapper's own attributes (``_OWN_ATTRS``) are resolved locally.
    * ``__class__`` is still forwarded, so ``isinstance`` checks against the
      module's type succeed.

    Shared state lives on the model object, not on the wrapper:

    * ``model._suppress_dir``: ``None`` (suppression off) or a list of
      direction tensors, indexed by ``idx - 2``.
    * ``model._residual``: ``None`` (capture off) or a list that
      :meth:`wrapped_forward` appends CPU copies of module inputs to.
    """

    # Names always resolved on the wrapper rather than forwarded. Without
    # this, a wrapped module defining e.g. its own ``idx`` or ``suppress``
    # would silently replace the wrapper's logic.
    _OWN_ATTRS = frozenset({'model', 'module', 'idx', 'suppress', 'wrapped_forward'})

    @classmethod
    def wrap(cls, model, load=True):
        """Wrap ``model.modules`` in place and optionally load directions.

        Every module except index 0 and the last two is replaced by a wrapper
        (presumably skipping the embedding and the final norm/head -- confirm
        against the model layout).

        Args:
            model: Object with a mutable ``modules`` list. It must also have
                ``config.model_dir`` when ``load`` is true.
            load: When true, read ``suppress_dir.safetensors`` from
                ``model.config.model_dir`` into ``model._suppress_dir``. If
                the file is missing, the modules stay wrapped and
                ``_suppress_dir`` is left unchanged (``None`` unless set
                earlier). In that case ``suppress`` is a no-op.
        """
        last = len(model.modules) - 2
        for idx, module in enumerate(model.modules):
            if idx == 0 or idx >= last:
                continue
            # Replacing an element in place does not change the list length,
            # so mutating during enumerate is safe here.
            model.modules[idx] = cls(model, module, idx)
        if not load:
            return
        suppress_dir_file = os.path.join(model.config.model_dir, 'suppress_dir.safetensors')
        if not os.path.exists(suppress_dir_file):
            # The modules are already wrapped at this point; only the
            # directions are missing.
            print(f'No suppress direction file, suppression disabled. Tried to load: "{suppress_dir_file}"')
            return
        print(f'Loading suppress direction file "{suppress_dir_file}"')
        prefix = '_suppress_dir_'
        with safe_open(suppress_dir_file, framework='pt', device='cpu') as f:
            # Directions are stored as _suppress_dir_0 .. _suppress_dir_{n-1}.
            # Any other keys in the file are ignored instead of causing a
            # KeyError.
            count = sum(1 for key in f.keys() if key.startswith(prefix))
            directions = [f.get_tensor(f'{prefix}{layer}') for layer in range(count)]
        # Assign only after every tensor has loaded. A failed read then never
        # leaves a partially filled list for suppress() to index into.
        model._suppress_dir = directions

    def __init__(self, model, module, idx):
        """Wrap ``module``, the ``idx``-th entry of ``model.modules``.

        The first wrapper created for ``model`` initializes the shared
        ``model._suppress_dir`` and ``model._residual`` slots to ``None``.
        """
        if not hasattr(model, '_suppress_dir'):
            model._suppress_dir = None
        if not hasattr(model, '_residual'):
            model._residual = None
        self.model = model
        self.module = module
        self.idx = idx

    def __getattribute__(self, name):
        """Redirect ``forward``, resolve own attributes, delegate the rest.

        If the wrapped module has no attribute ``name``, the lookup falls
        back to the wrapper itself.
        """
        if name == 'forward':
            return object.__getattribute__(self, 'wrapped_forward')
        if name in ExLlamaV2ModuleWrapper._OWN_ATTRS:
            return object.__getattribute__(self, name)
        try:
            return getattr(object.__getattribute__(self, 'module'), name)
        except AttributeError:
            pass
        return object.__getattribute__(self, name)

    def suppress(self, x):
        """Return ``x`` with its component along this module's direction removed.

        Computes ``x - (x @ r) * r.T``, where ``r`` is the stored direction as
        a column vector. This is an orthogonal projection only if ``r`` is
        unit-norm (presumably normalized when the file was generated --
        confirm). Returns ``x`` unchanged when no directions are loaded.
        """
        directions = self.model._suppress_dir
        if directions is None:
            return x
        # NOTE(review): idx 1 reads directions[-1] via negative indexing.
        # Confirm this idx -> direction mapping matches the generator.
        r = directions[self.idx - 2]
        # Match x's device AND dtype: torch.matmul rejects mixed dtypes
        # (e.g. a float32 direction against float16 hidden states). No
        # in-place ops follow, so the stored direction is never modified.
        r = r.to(device=x.device, dtype=x.dtype).reshape(-1, 1)
        return x - torch.matmul(x, r) * r.transpose(0, 1)

    def wrapped_forward(self, *args, **kwargs):
        """Suppress the input, run the wrapped module, suppress the output.

        ``args[0]`` is treated as the hidden-state tensor. All remaining
        positional and keyword arguments are passed through unchanged.

        If ``model._residual`` is a list, a CPU copy of the *unsuppressed*
        input is appended when both of these hold:

        * The list still holds fewer than ``idx`` entries. With modules
          called in order, this records one input per wrapped module on the
          first qualifying pass.
        * ``shape[1] == 1``. This presumably selects single-token decode
          steps, assuming dim 1 is the sequence axis -- confirm.
        """
        hidden = args[0]
        residuals = self.model._residual
        if residuals is not None and len(residuals) < self.idx and hidden.shape[1] == 1:
            residuals.append(hidden.clone().to('cpu'))
        out = self.module.forward(self.suppress(hidden), *args[1:], **kwargs)
        return self.suppress(out)