# resolving-scaling-law-discrepancies / evaluating_checkpoint.py
"""Script to evaluate a model on a validation set. Based on scripts/generate.py from open_lm repo.
"""

import argparse
import json
import re

import torch

from open_lm.evaluate import evaluate_loop
from open_lm.data import get_data
from open_lm.model import create_model
from open_lm.distributed import init_distributed_device
from open_lm.params import parse_args
from scripts.generate_without_hf import Generator, GenerationArgs


def generate_model_jsonl(params):
    """Write an open_lm model config JSON for the given model size and return its path."""
    # Map from the model-size key parsed out of the checkpoint path to
    # (hidden width, number of transformer layers).
    params_to_width_depth_dict = {
        5: (96, 3),
        7: (128, 4),
        9: (160, 5),
        15: (224, 6),
        22: (288, 8),
        28: (320, 9),
        37: (384, 10),
        57: (480, 12),
        84: (576, 14),
        108: (640, 15),
        149: (704, 18),
        220: (832, 21),
        347: (1024, 23),
        455: (1120, 26),
        611: (1312, 26),
        901: (1504, 30),
    }
    width, depth = params_to_width_depth_dict[params]
    filepath = f"layers={depth}_hidden-dim={width}.json"
    data = {
        "hidden_dim": width,
        "n_layers": depth,
        "n_heads": 4,
        "seq_len": 2048,
        "vocab_size": 50432,
        "post_embed_norm": False,
        "weight_tying": False,
        "qk_norm": True,
    }
    with open(filepath, "w") as file:
        file.write(json.dumps(data) + "\n")
    return filepath
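
# For example, generate_model_jsonl(15) writes "layers=6_hidden-dim=224.json" containing:
#   {"hidden_dim": 224, "n_layers": 6, "n_heads": 4, "seq_len": 2048,
#    "vocab_size": 50432, "post_embed_norm": false, "weight_tying": false,
#    "qk_norm": true}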


class ModelArgs:
    """open_lm argument namespace: defaults from parse_args, with evaluation overrides."""

    def __init__(self, params, val_data, val_data_key):
        # Start from open_lm's default arguments, then override the fields we need.
        default_params = vars(parse_args(""))
        for k, v in default_params.items():
            setattr(self, k, v)
        self.model = generate_model_jsonl(params)
        self.val_data = [val_data]
        self.val_data_key = [val_data_key]
        self.per_gpu_val_batch_size = 16
        self.vocab_size = 50432
        self.seq_len = 2048
        self.wandb = False


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="path/to/checkpoint")
    parser.add_argument("--val-data", default="", help="Path to validation data. If empty, generate text from --input-text instead.")
    parser.add_argument("--val-data-key", default="json.gz")
    parser.add_argument("--input-text", default="", type=str, help="Input text to generate from. If empty, evaluate on validation data.")
    parser.add_argument("--max-gen-len", default=200, type=int)
    parser.add_argument("--temperature", default=0.8, type=float)
    parser.add_argument("--top-p", default=0.95, type=float)
    args = parser.parse_args()

    # Infer the model size from the checkpoint path (expects a "params=<N>" component).
    params = int(re.search(r"params=(\d+)", args.checkpoint).group(1))

    # Load the checkpoint and strip the DDP "module." prefix from parameter names.
    checkpoint = torch.load(args.checkpoint)
    state_dict = checkpoint["state_dict"]
    state_dict = {x.replace("module.", ""): y for x, y in state_dict.items()}

    # Build the model from the generated config and load the weights.
    model_args = ModelArgs(params=params, val_data=args.val_data, val_data_key=args.val_data_key)
    device = init_distributed_device(model_args)
    model_args.device = device
    model = create_model(model_args)
    model.load_state_dict(state_dict)
    model.eval().cuda()

    if args.val_data != "":
        # Evaluate loss/perplexity on the validation data.
        data = get_data(
            model_args,
            skip_train=True,
        )
        metrics = evaluate_loop(model, data["val_list"], 0, model_args, None)
        print(metrics)
    elif args.input_text != "":
        # Generate a completion for the provided prompt.
        model = model.half()
        generator = Generator(model)
        input_text = [
            args.input_text,
        ]
        output = generator.generate(
            input_text,
            GenerationArgs(args.max_gen_len, args.temperature, args.top_p),
        )
        print("".join(output))
    else:
        print("Please provide either --val-data or --input-text")


if __name__ == "__main__":
    main()