import copy
import re

import torch

import util


class FineTunedModel(torch.nn.Module):
    """Wraps a frozen base model plus trainable deep copies of selected submodules.

    Submodules whose names match one of the given regexes are deep-copied and
    unfrozen. Entering the context manager swaps the copies into the base model;
    exiting restores the originals. state_dict()/load_state_dict() operate only
    on the fine-tuned copies.
    """

    def __init__(self, model, modules, frozen_modules=[]):
        super().__init__()

        if isinstance(modules, str):
            modules = [modules]

        self.model = model
        self.ft_modules = {}
        self.orig_modules = {}

        # Freeze the entire base model; only the deep-copied modules will train.
        util.freeze(self.model)

        for module_name, module in model.named_modules():
            for ft_module_regex in modules:
                match = re.search(ft_module_regex, module_name)
                if match is not None:
                    # Keep the original and a trainable copy under the same key.
                    ft_module = copy.deepcopy(module)
                    self.orig_modules[module_name] = module
                    self.ft_modules[module_name] = ft_module
                    util.unfreeze(ft_module)

                    print(f"=> Finetuning {module_name}")

                    # Re-freeze any submodules of the copy that match a
                    # frozen_modules regex.
                    for ft_module_name, submodule in ft_module.named_modules():
                        full_name = f"{module_name}.{ft_module_name}"
                        for freeze_module_regex in frozen_modules:
                            if re.search(freeze_module_regex, full_name):
                                print(f"=> Freezing {full_name}")
                                util.freeze(submodule)

        # Register the copies and originals so they follow .to()/.cuda() calls.
        self.ft_modules_list = torch.nn.ModuleList(self.ft_modules.values())
        self.orig_modules_list = torch.nn.ModuleList(self.orig_modules.values())

    @classmethod
    def from_checkpoint(cls, model, checkpoint, frozen_modules=[]):
        if isinstance(checkpoint, str):
            checkpoint = torch.load(checkpoint)

        # Checkpoint keys are full module names; anchor each regex at the end
        # so exactly that module is selected.
        modules = [f"{key}$" for key in checkpoint.keys()]

        ftm = cls(model, modules, frozen_modules=frozen_modules)
        ftm.load_state_dict(checkpoint)

        return ftm

    def __enter__(self):
        # Swap the fine-tuned copies into the base model.
        for key, ft_module in self.ft_modules.items():
            util.set_module(self.model, key, ft_module)

    def __exit__(self, exc_type, exc_value, tb):
        # Restore the original modules.
        for key, module in self.orig_modules.items():
            util.set_module(self.model, key, module)

    def parameters(self):
        parameters = []
        for ft_module in self.ft_modules.values():
            parameters.extend(list(ft_module.parameters()))
        return parameters

    def state_dict(self):
        # Save only the fine-tuned copies, keyed by the original module names.
        return {key: module.state_dict() for key, module in self.ft_modules.items()}

    def load_state_dict(self, state_dict):
        for key, sd in state_dict.items():
            self.ft_modules[key].load_state_dict(sd)
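

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library). It exercises the class on a
# toy torch model; the model, regexes, and names below are illustrative
# assumptions, not the repo's real training setup, and it assumes
# util.freeze/unfreeze/set_module behave as their names suggest.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Toy stand-in for the real model: two named submodules.
    toy = torch.nn.Sequential()
    toy.add_module("encoder", torch.nn.Linear(4, 4))
    toy.add_module("decoder", torch.nn.Linear(4, 4))

    # Fine-tune only the submodule whose name matches "decoder$".
    finetuner = FineTunedModel(toy, modules=["decoder$"])
    optimizer = torch.optim.Adam(finetuner.parameters(), lr=1e-2)

    # Entering the context swaps the trainable copy into `toy`;
    # leaving it restores the original weights.
    with finetuner:
        loss = toy(torch.randn(2, 4)).pow(2).mean()
        loss.backward()
    optimizer.step()

    # state_dict() holds only the fine-tuned copies; from_checkpoint rebuilds
    # a FineTunedModel around them (round-tripped in memory here).
    restored = FineTunedModel.from_checkpoint(toy, finetuner.state_dict())
    print(list(restored.state_dict().keys()))  # => ['decoder']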