"""
Helpers to load checkpoints for learned feature maps (attentions) or other parameters.
"""
from typing import Any, Optional

import torch
import torch.nn as nn
from omegaconf import OmegaConf

from src.utils.logging import print_header, _format_arg
from .convert_model import convert_attention
from .peft import create_peft_config
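
# Rough usage sketch (the model / config objects and the checkpoint path below are
# illustrative, not fixed names; checkpoints are expected to be saved as
# {'model_state_dict': ...}, matching how they are loaded in this module):
#
#   model, peft_config = load_and_convert_attns(model, model_config,
#                                               checkpoint_path='ckpt.pt', rank=0)
#   model, peft_config = load_and_convert_finetune(model, finetune_config,
#                                                  checkpoint_path=None, rank=0)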


def load_and_convert_attns(model: nn.Module,
                           model_config: dict,
                           attention_type: Optional[str] = None,
                           checkpoint_path: Optional[str] = None,
                           print_model: bool = False,
                           merge_loras: bool = False,
                           train_converted: bool = True,
                           peft_gradient_checkpointing: Optional[bool] = None,
                           train_attention: bool = False,
                           freeze_weights: bool = True,
                           rank: int = 0,
                           remove_base_attn: bool = True,
                           ) -> tuple:
    """
    Convert the model's attention layers and load trained attention kernel
    (feature map) parameter weights.

    Returns the converted model and the PEFT config (None if no PEFT config is specified).
    """
    if freeze_weights:
        # Freeze all base model weights; converted attentions / PEFT parameters
        # are unfrozen separately below
        for p in model.parameters():
            p.requires_grad = False

    if attention_type is not None:  # override the attention type in the config
        model_config['attention']['attention_type'] = attention_type
    model_config['attention']['rank'] = rank

    model = convert_attention(model, model_config['attention'],
                              train_attention, remove_base_attn)
    # Add PEFT (e.g., LoRA) modules if a PEFT config is specified under either
    # the 'peft' or 'peft_config' key of the attention config
    peft_key = 'peft'
    if 'peft_config' in model_config['attention']:
        peft_key = 'peft_config'
    if peft_key in model_config['attention']:
        peft_config = model_config['attention'][peft_key]
        model, peft_config = create_peft_config(
            model, peft_config,
            model_config['model']['torch_dtype'],
            preserve_requires_grad=train_converted,
            use_gradient_checkpointing=peft_gradient_checkpointing,
        )
    else:
        peft_config = None
    if print_model and rank == 0:
        print_header('*** Model before checkpoint load ***')
        print(model)
    # Load trained attention weights (non-strict, so parameters absent from the
    # checkpoint are left untouched)
    if checkpoint_path is not None:
        print(f'Loading weights from {checkpoint_path}...')
        state_dict = torch.load(checkpoint_path)['model_state_dict']
        _keys = model.load_state_dict(state_dict, strict=False)
        try:
            assert len(_keys.unexpected_keys) == 0
            if rank == 0:
                print_header('*** All expected keys matched successfully ***')
                if print_model:
                    for k in state_dict.keys():
                        print(k)
        except Exception as e:
            if rank == 0:
                print(e)
                print_header('*** Error: unexpected keys in checkpoint ***')
                print('Unexpected keys:')
                for k in _keys.unexpected_keys:
                    print(k)
    if print_model and rank == 0:
        print_header('*** Model ***')
        print(model)
    if merge_loras:
        # Merge LoRA adapters back into the base weights
        model = model.merge_and_unload()
        if print_model and rank == 0:
            print_header('*** Model (after merging adapters) ***')
            print(model)
    if print_model and rank == 0:
        print_header('*** Trainable Parameters ***')
        for n, p in model.named_parameters():
            if p.requires_grad:
                print(f'├── {n} (dtype = {p.dtype})')
    return model, peft_config
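
# Illustrative shape of `finetune_config` for the function below (field names follow
# how the config is read here; the values are placeholders, not a required schema):
#
#   finetune_config = OmegaConf.create({
#       'finetune': {
#           'method': 'lora',
#           'kwargs': {'r': 8, 'lora_alpha': 16},   # presence of 'kwargs' gates PEFT setup
#           'trainable_weights': ['norm'],          # optional: substrings of parameter names to unfreeze
#           'layers_to_ignore': [],                 # optional: layer indices to keep frozen
#       }
#   })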


def load_and_convert_finetune(model: nn.Module,
                              finetune_config: dict,
                              checkpoint_path: Optional[str] = None,
                              print_model: bool = False,
                              merge_loras: bool = False,
                              peft_gradient_checkpointing: Optional[bool] = None,
                              rank: int = 0,
                              **peft_kwargs: Any) -> tuple:
    """
    Apply the finetuning setup (e.g., LoRA adapters, unfrozen weights) and load
    trained adapter / model weights.

    Returns the model and the PEFT config (None if PEFT is not used).
    """
    peft_config = None
    if finetune_config.finetune.method == 'lora':
        # Add LoRA adapters if the config specifies PEFT kwargs
        if getattr(finetune_config.finetune, 'kwargs', None) is not None:
            model, peft_config = create_peft_config(
                model, finetune_config.finetune,
                use_gradient_checkpointing=peft_gradient_checkpointing,
                **peft_kwargs,
            )
        # Also unfreeze any explicitly listed weights (matched by substring)
        if 'trainable_weights' in finetune_config.finetune:
            for name in finetune_config.finetune['trainable_weights']:
                for n, p in model.named_parameters():
                    if name in n:
                        p.requires_grad = True
    else:
        # No LoRA: freeze everything first, then unfreeze the listed weights below
        for p in model.parameters():
            p.requires_grad = False
        if 'trainable_weights' in finetune_config.finetune:
            for name in finetune_config.finetune['trainable_weights']:
                for n, p in model.named_parameters():
                    if name in n:
                        if 'layers_to_ignore' in finetune_config.finetune:
                            # Parse the layer index from names like 'model.layers.12.mlp...'
                            layer = int(n.split('layers.')[-1].split('.')[0])
                            if layer not in finetune_config.finetune['layers_to_ignore']:
                                p.requires_grad = True
                        else:
                            p.requires_grad = True
    if checkpoint_path:
        state_dict = torch.load(checkpoint_path)['model_state_dict']
        _keys = model.load_state_dict(state_dict, strict=False)
        try:
            assert len(_keys.unexpected_keys) == 0
            if rank == 0:
                print_header('*** All expected keys matched successfully ***')
        except Exception as e:
            if rank == 0:
                print(e)
                print_header('*** Error: unexpected keys in checkpoint ***')
                print('Unexpected keys:')
                for k in _keys.unexpected_keys:
                    print(k)
    if print_model and rank == 0:
        print_header('*** Model ***')
        print(model)
    if merge_loras:
        try:
            model = model.merge_and_unload()
            if print_model and rank == 0:
                print_header('*** Model (after merging adapters) ***')
                print(model)
        except Exception as e:
            print(e)
    if print_model and rank == 0:
        print_header('*** Trainable Parameters ***')
        count = 0
        for n, p in model.named_parameters():
            if p.requires_grad:
                print(f'├── {n}.requires_grad: {p.requires_grad}')
                count += 1
        if count == 0:
            print('(none)')

    return model, peft_config