open_lm
TomerPorian commited on
Commit
e08f01f
1 Parent(s): b2fa501

Add files using large-upload tool

Browse files
Files changed (1) hide show
  1. evaluating_checkpoint.py +115 -0
evaluating_checkpoint.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Script to evaluate a model on a validation set. Based on scripts/generate.py from open_lm repo.
2
+ """
3
+ import argparse
4
+ import json
5
+ import re
6
+
7
+ import torch
8
+
9
+ from open_lm.evaluate import evaluate_loop
10
+ from open_lm.data import get_data
11
+ from open_lm.model import create_model
12
+ from open_lm.distributed import init_distributed_device
13
+ from open_lm.params import parse_args
14
+
15
+ from scripts.generate_without_hf import Generator, GenerationArgs
16
+
17
+ def generate_model_jsonl(params):
18
+ params_to_width_depth_dict = {5: (96, 3),
19
+ 7: (128, 4),
20
+ 9: (160, 5),
21
+ 15: (224, 6),
22
+ 22: (288, 8),
23
+ 28: (320, 9),
24
+ 37: (384, 10),
25
+ 57: (480, 12),
26
+ 84: (576, 14),
27
+ 108: (640, 15),
28
+ 149: (704, 18),
29
+ 220: (832, 21),
30
+ 347: (1024, 23),
31
+ 455: (1120, 26),
32
+ 611: (1312, 26),
33
+ 901: (1504, 30)
34
+ }
35
+
36
+ width, depth = params_to_width_depth_dict[params]
37
+ filepath = f"layers={depth}_hidden-dim={width}.json"
38
+ data = {
39
+ "hidden_dim": width,
40
+ "n_layers": depth,
41
+ "n_heads": 4,
42
+ "seq_len": 2048,
43
+ "vocab_size": 50432,
44
+ "post_embed_norm": False,
45
+ "weight_tying": False,
46
+ "qk_norm": True
47
+ }
48
+
49
+ with open(filepath, 'w') as file:
50
+ file.write(json.dumps(data) + '\n')
51
+ return filepath
52
+
53
+
54
+ class ModelArgs:
55
+ def __init__(self, params, val_data, val_data_key):
56
+ default_params = vars(parse_args(""))
57
+ for k, v in default_params.items():
58
+ setattr(self, k, v)
59
+ self.model = generate_model_jsonl(params)
60
+ self.val_data = [val_data]
61
+ self.val_data_key = [val_data_key]
62
+ self.per_gpu_val_batch_size = 16
63
+ self.vocab_size = 50432
64
+ self.seq_len = 2048
65
+ self.wandb = False
66
+
67
+
68
+ def main():
69
+ parser = argparse.ArgumentParser()
70
+ parser.add_argument("--checkpoint", default="path/to/checkpoint")
71
+
72
+ parser.add_argument("--val-data", default="", help="Path to validation data. If empty, generate text.")
73
+ parser.add_argument("--val-data-key", default="json.gz")
74
+
75
+ parser.add_argument("--input-text", default="", type=str, help="Input text to generate from. If empty, evaluate on validation data.")
76
+ parser.add_argument("--max-gen-len", default=200, type=int)
77
+ parser.add_argument("--temperature", default=0.8, type=float)
78
+ parser.add_argument("--top-p", default=0.95, type=float)
79
+
80
+ args = parser.parse_args()
81
+ params = int(re.search(r"params=(\d+)", args.checkpoint).group(1))
82
+
83
+ checkpoint = torch.load(args.checkpoint)
84
+ state_dict = checkpoint["state_dict"]
85
+ state_dict = {x.replace("module.", ""): y for x, y in state_dict.items()}
86
+ model_args = ModelArgs(params=params, val_data=args.val_data, val_data_key=args.val_data_key)
87
+ device = init_distributed_device(model_args)
88
+ model_args.device = device
89
+ model = create_model(model_args)
90
+ model.load_state_dict(state_dict)
91
+ model.eval().cuda()
92
+ if args.val_data != "":
93
+ data = get_data(
94
+ model_args,
95
+ skip_train=True,
96
+ )
97
+ metrics = evaluate_loop(model, data["val_list"], 0, model_args, None)
98
+ print(metrics)
99
+ elif args.input_text != "":
100
+ model = model.half()
101
+ generator = Generator(model)
102
+ input_text = [
103
+ args.input_text,
104
+ ]
105
+ output = generator.generate(
106
+ input_text,
107
+ GenerationArgs(args.max_gen_len, args.temperature, args.top_p),
108
+ )
109
+ print("".join(output))
110
+
111
+ else:
112
+ print("Please provide either --val-data or --input-text")
113
+
114
+ if __name__ == "__main__":
115
+ main()