import os
import torch
from safetensors import safe_open


class ExLlamaV2ModuleWrapper:

    @classmethod
    def wrap(cls, model, load = True):
        # Wrap every module except the embedding module (idx 0) and the last
        # two modules in the list (typically the final norm and output head).
        for idx, module in enumerate(model.modules):
            if idx == 0 or idx >= (len(model.modules) - 2):
                continue
            model.modules[idx] = ExLlamaV2ModuleWrapper(model, module, idx)

        if not load:
            return

        # Load the per-layer suppression directions if a saved file exists.
        # Without it, the wrappers stay in place but pass activations through
        # unchanged, since model._suppress_dir remains None.
        suppress_dir_file = os.path.join(model.config.model_dir, 'suppress_dir.safetensors')
        if os.path.exists(suppress_dir_file):
            print(f'Loading suppress direction file "{suppress_dir_file}"')
            with safe_open(suppress_dir_file, framework='pt', device='cpu') as f:
                model._suppress_dir = []
                for layer in range(len(f.keys())):
                    model._suppress_dir.append(f.get_tensor(f'_suppress_dir_{layer}'))
        else:
            print(f'No suppress direction file, not wrapping. Tried to load: "{suppress_dir_file}"')
            return

    def __init__(self, model, module, idx):
        # Shared state is stored on the model itself: the per-layer
        # suppression directions and an optional residual-capture list.
        if not hasattr(model, '_suppress_dir'):
            model._suppress_dir = None
        if not hasattr(model, '_residual'):
            model._residual = None
        self.model = model
        self.module = module
        self.idx = idx

    def __getattribute__(self, name):
        # Intercept 'forward' so the wrapper runs; everything else is
        # delegated to the wrapped module, keeping the wrapper transparent.
        if name == 'forward':
            return object.__getattribute__(self, 'wrapped_forward')
        try:
            return getattr(object.__getattribute__(self, 'module'), name)
        except AttributeError:
            pass
        return object.__getattribute__(self, name)

    def suppress(self, x):
        # Project x onto the stored direction r and subtract the projection,
        # removing that component from the hidden state. The projection is not
        # rescaled, so r is assumed to be (approximately) unit-norm.
        if self.model._suppress_dir is not None:
            r = self.model._suppress_dir[self.idx - 2].clone().to(x.device)
            r = r.view(-1, 1)
            proj_scalar = torch.matmul(x, r)
            proj = proj_scalar * r.transpose(0, 1)
            x = x - proj
        return x

    def wrapped_forward(self, *args, **kwargs):
        # If residual capture is enabled (model._residual is a list), record
        # this module's input hidden state once, and only for single-token
        # forward passes.
        if self.model._residual is not None:
            if len(self.model._residual) < self.idx and args[0].shape[1] == 1:
                self.model._residual.append(args[0].clone().to('cpu'))
        # Remove the suppression direction from the input, run the wrapped
        # module, then remove it again from the output.
        x = self.suppress(args[0])
        x = self.module.forward(*((x,) + args[1:]), **kwargs)
        return self.suppress(x)
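

# A minimal sketch (not part of the wrapper above) of writing a
# suppress_dir.safetensors file in the key format the loader expects
# ('_suppress_dir_{layer}'). The helper name save_suppress_dirs and the origin
# of the direction tensors are hypothetical; each direction is assumed to be a
# unit-norm vector of size hidden_size, one per wrapped layer.

from safetensors.torch import save_file

def save_suppress_dirs(directions, model_dir):
    # directions: ordered list of per-layer direction tensors
    tensors = {f'_suppress_dir_{i}': d.contiguous().cpu() for i, d in enumerate(directions)}
    save_file(tensors, os.path.join(model_dir, 'suppress_dir.safetensors'))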
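

# A hedged usage sketch: loading an ExLlamaV2 model and applying the wrapper.
# The exllamav2 calls below (ExLlamaV2Config, config.prepare(), ExLlamaV2,
# model.load()) follow the library's common loading pattern but may vary
# between versions, and wrapping after load() is an assumption, not something
# this file specifies. load_wrapped_model is a hypothetical helper.

from exllamav2 import ExLlamaV2, ExLlamaV2Config

def load_wrapped_model(model_dir):
    config = ExLlamaV2Config()
    config.model_dir = model_dir   # suppress_dir.safetensors is looked up here
    config.prepare()
    model = ExLlamaV2(config)
    model.load()
    ExLlamaV2ModuleWrapper.wrap(model)   # replaces the inner modules with wrappers
    return model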