""" |
|
Alternative way to load trained models for evaluation |
|
""" |
|
import copy |
|
import sys |
|
from os.path import join |
|
from omegaconf import OmegaConf |
|
|
|
import torch |
|
|
|
from src.utils.logging import print_header, print_config, _format_arg |
|
from .pretrained import get_pretrained_loader |
|
from .peft import create_peft_config |
|
from .load_model import load_and_convert_attns |
|
from .convert_model import remove_base_attention, toggle_attention |
|
|
|
|
|
def get_args_from_checkpoint(fname: str):
    """
    Recover run arguments encoded in a checkpoint filename
    """
    id_to_name = {
        'lk': 'learned_kernel',
        'tqk': 'tie_qk_kernels',
        'tq': 'train_qk',
        'lzi': 'lk_zero_init',
        'lsc': 'lk_skip_connection',
        'pmnop': 'pretrained_model_name_or_path',
    }
    id_to_type = {
        'lk': str,
        'tqk': bool,
        'tq': bool,
        'lzi': bool,
        'lsc': bool,
        'pmnop': str,
    }
    args = {v: None for v in id_to_name.values()}
    args['run_name'] = ''

    # Filenames encode arguments as '-'-separated 'id=value' pairs
    for id_val in fname.split('-'):
        try:
            _id, val = id_val.split('=')
            if val.endswith('_distill.pt'):
                val = val[:-len('_distill.pt')]
            if _id in id_to_type:
                _type = id_to_type[_id]
                args[id_to_name[_id]] = _type(val)
        except Exception:
            pass  # skip tokens that are not 'id=value' pairs
    return OmegaConf.create(args)
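# Illustration only (hypothetical filename): a checkpoint named
# 'distill-m=llama-lk=untied_head-lzi=1-lsc=1_distill.pt' parses to
# learned_kernel='untied_head', lk_zero_init=True, lk_skip_connection=True;
# pairs whose id is not in id_to_type (e.g., 'm=llama') are ignored, and the
# trailing '_distill.pt' suffix is stripped from the last value before casting.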
def update_model_config_from_args(model_config, args):
    """Override default model configs with args recovered from a checkpoint filename"""
    # Attention-level overrides
    for arg in ['learned_kernel', 'tie_qk_kernels', 'train_qk']:
        argval = getattr(args, arg)
        if argval is not None:
            setattr(model_config['attention'], arg, argval)
            args.run_name += f'-{_format_arg(arg)}={argval}'

    # Learned-kernel keyword-argument overrides ('lk_' prefix stripped)
    for arg in ['lk_skip_connection', 'lk_zero_init']:
        argval = getattr(args, arg)
        if argval is not None:
            setattr(model_config['attention']['learned_kernel_kwargs'],
                    arg[len('lk_'):], argval)
            args.run_name += f'-{_format_arg(arg)}={argval}'

    # Pretrained model override
    if args.pretrained_model_name_or_path is not None:
        pmnop = args.pretrained_model_name_or_path
        model_config.model.pretrained_model_name_or_path = pmnop
        args.run_name += f'-pmnop={pmnop.split("/")[-1]}'
    return model_config
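# Note: update_model_config_from_args above also appends a '-<abbreviated arg>=<value>'
# style suffix (via _format_arg) to args.run_name for every override it applies,
# keeping the run name consistent with the effective config.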
def get_lm_eval_model(model_kwargs: dict,
                      path_to_lm_eval_harness: str,
                      hedgehog_model: bool = False,
                      long_model: bool = False,
                      ):
    """
    Load model for evaluation with the LM Evaluation Harness
    """
    # Remap Hugging Face-style loading kwargs to lm-evaluation-harness names
    lm_kwargs = copy.deepcopy(model_kwargs)
    lm_kwargs['pretrained'] = lm_kwargs['pretrained_model_name_or_path']
    lm_kwargs['dtype'] = str(lm_kwargs['torch_dtype']).split('.')[-1]
    del lm_kwargs['torch_dtype']

    lm_kwargs['output_attentions'] = False
    lm_kwargs['output_hidden_states'] = False

    print('-> Loading as lm-evaluation-harness model')
    if hedgehog_model:
        if 'mistral' in lm_kwargs['pretrained']:
            from lm_eval_harness.models import LolcatsMistralForCausalLM as ModelClass
        else:
            from lm_eval_harness.models import LolcatsLlamaForCausalLM as ModelClass
        lm = ModelClass.create_from_arg_string('', lm_kwargs)
    else:
        sys.path.append(path_to_lm_eval_harness)
        from lm_eval.models import get_model
        lm = get_model('hf-causal-experimental').create_from_arg_string('', lm_kwargs)
    return lm
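# Sketch of the expected input (values hypothetical beyond the keys accessed above):
# model_kwargs should at least provide 'pretrained_model_name_or_path' and 'torch_dtype',
# e.g. {'pretrained_model_name_or_path': 'mistralai/Mistral-7B-v0.1',
#       'torch_dtype': torch.bfloat16};
# these are remapped to the harness-style 'pretrained' and 'dtype' ('bfloat16') arguments.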
def load_model_from_config(model_config_name: str,
                           config_dir: str = './configs',
                           lm_eval_model: bool = False,
                           path_to_lm_eval_harness: str = '/juice2/scr2/mzhang/projects/lm-evaluation-harness',
                           ):
    """
    Load model from a config file
    """
    model_config_path = join(config_dir, 'model', f'{model_config_name}.yaml')
    model_config = OmegaConf.load(model_config_path)

    # Load pretrained model and tokenizer; pad with the EOS token and pad on the left
    model_loader = get_pretrained_loader(**model_config.model)
    tokenizer = model_loader.load_tokenizer()
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = 'left'

    if lm_eval_model:
        lm = get_lm_eval_model(model_loader.loading_kwargs, path_to_lm_eval_harness)
        model = lm.model
    else:
        model = model_loader.load()

    model.eval()
    if lm_eval_model:
        lm.model = model
        model = lm
    return model, model_config, tokenizer
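# Example usage (config name hypothetical; it must match a file under
# <config_dir>/model/, i.e. ./configs/model/<model_config_name>.yaml by default):
#   model, model_config, tokenizer = load_model_from_config('llama2_7b')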
def load_model_from_checkpoint(attn_mlp_checkpoint_path: str = None,
                               finetune_checkpoint_path: str = None,
                               config_dir: str = './configs',
                               print_model: bool = False,
                               debug: bool = False,
                               lm_eval_model: bool = False,
                               path_to_lm_eval_harness: str = '/juice2/scr2/mzhang/projects/lm-evaluation-harness',
                               profile_model: bool = False,
                               ):
    """
    Load model architecture from a checkpoint path
    -> attn_mlp_checkpoint_path should point to the checkpoint with the learned MLPs
    -> finetune_checkpoint_path should point to the checkpoint with all other parameters
    -> Assumes the checkpoint path strings contain the model_config and finetune_config names
    """
    # Recover the model config name either from the checkpoint's parent directory
    # (short, 4-component paths) or from the '-m=<config>' token in its filename
    if attn_mlp_checkpoint_path is not None:
        if len(attn_mlp_checkpoint_path.split('/')) == 4:
            model_config = attn_mlp_checkpoint_path.split('/')[2]
        else:
            model_config = attn_mlp_checkpoint_path.split('/')[-1].split('-m=')[-1].split('-')[0]
        model_config_path = join(config_dir, 'model', f'{model_config}.yaml')
        model_config = OmegaConf.load(model_config_path)
        args = get_args_from_checkpoint(attn_mlp_checkpoint_path.split('/')[-1])
        model_config = update_model_config_from_args(model_config, args)
    else:
        if len(finetune_checkpoint_path.split('/')) == 4:
            model_config = finetune_checkpoint_path.split('/')[2]
        else:
            model_config = finetune_checkpoint_path.split('/')[-1].split('-m=')[-1].split('-')[0]
        model_config_path = join(config_dir, 'model', f'{model_config}.yaml')
        model_config = OmegaConf.load(model_config_path)

    if profile_model:
        model_config['attention']['attention_type'] += '_profile'
    if finetune_checkpoint_path is not None:
        # Recover the finetune config name from the '-f=<config>' token in the filename
        finetune_config = finetune_checkpoint_path.split('-f=')[-1].split('-')[0]
        finetune_config_path = join(config_dir, 'experiment', f'{finetune_config}.yaml')
        finetune_config = OmegaConf.load(finetune_config_path)

    if debug:
        print_header('-- Model Config --')
        print_config(model_config)
        try:
            print_header('-- Finetune Config --')
            print_config(finetune_config)
        except NameError:
            pass

    # Load pretrained model and tokenizer; pad with the EOS token and pad on the left
    model_loader = get_pretrained_loader(**model_config.model)
    tokenizer = model_loader.load_tokenizer()
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = 'left'

    if lm_eval_model and attn_mlp_checkpoint_path is not None:
        lm = get_lm_eval_model(model_loader.loading_kwargs, path_to_lm_eval_harness,
                               hedgehog_model=True)
        model = lm.model
    elif lm_eval_model:
        lm = get_lm_eval_model(model_loader.loading_kwargs, path_to_lm_eval_harness)
        model = lm.model
    elif attn_mlp_checkpoint_path is None:
        model = model_loader.load()
    else:
        model = model_loader.load(model_type=model_config['attention']['attention_type'])
    # Optional: set the state chunk length if specified in the attention config
    try:
        model.state_chunk_len = model_config['attention']['state_chunk_len']
    except KeyError:
        pass
    if attn_mlp_checkpoint_path is not None:
        # Swap in the learned linear attentions and load their trained weights
        model = load_and_convert_attns(model, model_config,
                                       checkpoint_path=attn_mlp_checkpoint_path)[0]
        if 'peft' in model_config['attention']:
            model = model.merge_and_unload()

    # Evaluate with the converted attentions (train_attention=False)
    model = toggle_attention(model, False)
    if debug:
        print_header('*** Model after attention conversion ***')
        print(model)
    if finetune_checkpoint_path is not None:
        # Recreate the finetuned parameterization before loading its weights
        if finetune_config.finetune.method == 'lora':
            model, _ = create_peft_config(model, finetune_config.finetune)
        else:
            for p in model.parameters():
                p.requires_grad = True

        # Load finetuned weights; strict=False as the checkpoint may only cover finetuned params
        state_dict = torch.load(finetune_checkpoint_path)['model_state_dict']
        _keys = model.load_state_dict(state_dict, strict=False)
        try:
            assert len(_keys.unexpected_keys) == 0
            print_header('*** All expected keys matched successfully ***')
        except AssertionError:
            print_header('*** Error: unexpected keys in checkpoint ***')
            print('Unexpected keys:')
            for k in _keys.unexpected_keys:
                print(k)
        if debug:
            print_header('Missing keys:')
            for k in _keys.missing_keys:
                print(k)
            print_header('Unexpected keys:')
            for k in _keys.unexpected_keys:
                print(k)

    try:
        print('-> Training attention:', model.model.layers[0].self_attn.train_attention)
    except AttributeError as e:
        print('Error at:', e)
        _train_attn = model.model.model.layers[0].self_attn.train_attention
        print(f"But it's ok, {type(model.model.model)} has attribute 'layers'")
        print('-> Training attention:', _train_attn)
    if print_model or debug:
        print_header('*** Model ***')
        print(model)
        print_header('*** Trainable Parameters ***')
        for n, p in model.named_parameters():
            if p.requires_grad:
                print(f'├── {n}.requires_grad: {p.requires_grad}')
    model.eval()
    if lm_eval_model:
        lm.model = model
        model = lm
    return model, model_config, tokenizer
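# Example usage (paths hypothetical; checkpoint filenames are expected to embed the
# model config name after '-m=' and the finetune config name after '-f='):
#   model, model_config, tokenizer = load_model_from_checkpoint(
#       attn_mlp_checkpoint_path='<path to a ...-m=<model_config>-..._distill.pt checkpoint>',
#       finetune_checkpoint_path='<path to a ...-m=<model_config>-f=<finetune_config>-... checkpoint>',
#   )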