# Wraps the modules of an ExLlamaV2 model so that a per-layer "suppress"
# direction is projected out of the residual stream on every forward pass.
import os

import torch
from safetensors import safe_open


class ExLlamaV2ModuleWrapper:

    @classmethod
    def wrap(cls, model, load = True):
        # Wrap every transformer block, skipping the embedding module at
        # index 0 and the final two (norm/head) modules.
        for idx, module in enumerate(model.modules):
            if idx == 0 or idx >= (len(model.modules) - 2):
                continue
            model.modules[idx] = ExLlamaV2ModuleWrapper(model, module, idx)

        if not load:
            return

        # Load the per-layer suppress directions from the model directory,
        # if the file is present.
        suppress_dir_file = os.path.join(model.config.model_dir, 'suppress_dir.safetensors')
        if os.path.exists(suppress_dir_file):
            print(f'Loading suppress direction file "{suppress_dir_file}"')
            with safe_open(suppress_dir_file, framework='pt', device='cpu') as f:
                model._suppress_dir = []
                for layer in range(len(f.keys())):
                    model._suppress_dir.append(f.get_tensor(f'_suppress_dir_{layer}'))
        else:
            print(f'No suppress direction file, suppression disabled. Tried to load: "{suppress_dir_file}"')
            return
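
    # For reference, a suppress_dir.safetensors file with the key layout read
    # above could be produced with the standard safetensors helper (a sketch;
    # `directions` is a hypothetical list of per-layer direction tensors):
    #
    #   from safetensors.torch import save_file
    #   save_file({f'_suppress_dir_{i}': d.contiguous() for i, d in enumerate(directions)},
    #             'suppress_dir.safetensors')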

    def __init__(self, model, module, idx):
        if not hasattr(model, '_suppress_dir'):
            model._suppress_dir = None
        if not hasattr(model, '_residual'):
            model._residual = None
        self.model = model
        self.module = module
        self.idx = idx

    def __getattribute__(self, name):
        # Route 'forward' to wrapped_forward; delegate every other attribute
        # lookup to the wrapped module first, falling back to the wrapper.
        if name == 'forward':
            return object.__getattribute__(self, 'wrapped_forward')
        try:
            return getattr(object.__getattribute__(self, 'module'), name)
        except AttributeError:
            pass
        return object.__getattribute__(self, name)

    def suppress(self, x):
        # Remove the component of the hidden states along this layer's
        # suppress direction.
        if self.model._suppress_dir is not None:
            r = self.model._suppress_dir[self.idx - 2].clone().to(x.device)
            r = r.view(-1, 1)
            proj_scalar = torch.matmul(x, r)
            proj = proj_scalar * r.transpose(0, 1)
            x = x - proj
        return x
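
    # suppress() performs directional ablation: assuming each stored direction
    # r is unit-norm (the code does not normalize it), x ← x − (x·r)rᵀ zeroes
    # the hidden-state component along r and leaves the orthogonal part intact.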

    def wrapped_forward(self, *args, **kwargs):
        # Optionally record the incoming residual stream (single-token forward
        # passes only), then apply the suppression both before and after the
        # wrapped module.
        if self.model._residual is not None:
            if len(self.model._residual) < self.idx and args[0].shape[1] == 1:
                self.model._residual.append(args[0].clone().to('cpu'))
        x = self.suppress(args[0])
        x = self.module.forward(*((x,) + args[1:]), **kwargs)
        return self.suppress(x)
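
# Example usage (a sketch, assuming a standard exllamav2 setup; the model
# directory is hypothetical and must contain suppress_dir.safetensors):
#
#   from exllamav2 import ExLlamaV2, ExLlamaV2Config
#
#   config = ExLlamaV2Config()
#   config.model_dir = '/path/to/model'
#   config.prepare()
#   model = ExLlamaV2(config)
#   ExLlamaV2ModuleWrapper.wrap(model)
#   model.load()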