import os

import torch
from safetensors import safe_open


class ExLlamaV2ModuleWrapper:

    @classmethod
    def wrap(cls, model, load=True):
        # Wrap every transformer block, skipping the embedding module (idx 0)
        # and the final norm / output head (the last two modules).
        for idx, module in enumerate(model.modules):
            if idx == 0 or idx >= (len(model.modules) - 2):
                continue
            model.modules[idx] = ExLlamaV2ModuleWrapper(model, module, idx)

        if not load:
            return

        # Load one suppress direction per layer from the model directory, if present.
        suppress_dir_file = os.path.join(model.config.model_dir, 'suppress_dir.safetensors')
        if os.path.exists(suppress_dir_file):
            print(f'Loading suppress direction file "{suppress_dir_file}"')
            with safe_open(suppress_dir_file, framework='pt', device='cpu') as f:
                model._suppress_dir = []
                for layer in range(len(f.keys())):
                    model._suppress_dir.append(f.get_tensor(f'_suppress_dir_{layer}'))
        else:
            print(f'No suppress direction file found; wrapped modules will pass activations through unchanged. Tried to load: "{suppress_dir_file}"')
            return

    def __init__(self, model, module, idx):
        if not hasattr(model, '_suppress_dir'):
            model._suppress_dir = None
        if not hasattr(model, '_residual'):
            model._residual = None
        self.model = model
        self.module = module
        self.idx = idx

    def __getattribute__(self, name):
        # Route 'forward' to the wrapper; delegate every other attribute to the
        # wrapped module, falling back to the wrapper's own attributes.
        if name == 'forward':
            return object.__getattribute__(self, 'wrapped_forward')
        try:
            return getattr(object.__getattribute__(self, 'module'), name)
        except AttributeError:
            pass
        return object.__getattribute__(self, name)

    def suppress(self, x):
        # Remove the component of x along this layer's suppress direction:
        # x <- x - (x . r) r^T, an orthogonal projection when r is unit-norm.
        if self.model._suppress_dir is not None:
            r = self.model._suppress_dir[self.idx - 2].clone().to(x.device)
            r = r.view(-1, 1)
            proj_scalar = torch.matmul(x, r)
            proj = proj_scalar * r.transpose(0, 1)
            x = x - proj
        return x

    def wrapped_forward(self, *args, **kwargs):
        # Optionally record the residual stream entering this block (once per
        # module) while decoding a single token, then suppress the direction on
        # both the input and the output of the wrapped module.
        if self.model._residual is not None:
            if len(self.model._residual) < self.idx and args[0].shape[1] == 1:
                self.model._residual.append(args[0].clone().to('cpu'))
        x = self.suppress(args[0])
        x = self.module.forward(*((x,) + args[1:]), **kwargs)
        return self.suppress(x)
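

# ---------------------------------------------------------------------------
# Usage sketch. This block is illustrative and not part of the wrapper itself:
# it assumes the standard exllamav2 loading flow (ExLlamaV2Config / ExLlamaV2 /
# ExLlamaV2Cache / load_autosplit), which may differ between library versions,
# and uses a placeholder model path. The wrapper forwards attribute access to
# the wrapped modules, so wrapping is done here after the weights are loaded
# and before any forward pass.

if __name__ == '__main__':
    from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache

    config = ExLlamaV2Config()
    config.model_dir = '/path/to/model'  # placeholder; should contain suppress_dir.safetensors
    config.prepare()

    model = ExLlamaV2(config)
    cache = ExLlamaV2Cache(model, lazy=True)
    model.load_autosplit(cache)

    # Replace the transformer blocks with suppressing wrappers in place.
    ExLlamaV2ModuleWrapper.wrap(model)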