future-xy committed
Commit: 2088911
Parent: a549d9d

support mmlu

requirements.txt CHANGED
@@ -18,7 +18,7 @@ tqdm
 wandb
 transformers>=4.36.0
 tokenizers>=0.15.0
-lm_eval[ifeval] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git
+lm_eval[ifeval] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@0.4.2
 accelerate
 sentencepiece
 langdetect
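
Note: the harness dependency is now pinned to the 0.4.2 tag instead of the moving main branch. A minimal sanity-check sketch, assuming the harness installs under the distribution name "lm_eval":

# Minimal sketch: confirm which harness version was resolved after `pip install -r requirements.txt`.
# Assumption: the lm-evaluation-harness distribution is named "lm_eval".
from importlib.metadata import version

print(version("lm_eval"))  # expected to print 0.4.2 with the pin above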
src/backend/hflm_with_measurement.py CHANGED
@@ -68,6 +68,226 @@ class HFLMWithMeasurement(HFLM):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
+    def _loglikelihood_tokens(
+        self,
+        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
+        disable_tqdm: bool = False,
+        override_bs: int = None,
+    ) -> List[Tuple[float, bool]]:
+        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
+        res = []
+
+        def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
+            """Defines the key for the sorted method"""
+            # the negative sign on len(toks) sorts descending - this has a few advantages:
+            # - time estimates will always be over not underestimates, which is more useful for planning
+            # - to know the size of a batch when going through the list, you know the first one is always the batch
+            #   padded context length. this is useful to simplify the batching logic and more importantly to make
+            #   automatic adaptive batches much much easier to implement
+            # - any OOMs will happen right away rather than near the end
+
+            toks = req[1] + req[2]
+            return -len(toks), tuple(toks)
+
+        def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
+            """Defines the key to group and lookup one-token continuations"""
+            # Use with group_by="contexts" (optional)
+            # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
+            # speeds up some multiple-choice tasks proportionally to the number of choices.
+            # groups requests by context+continuation[:-1] and infer on one request/group.
+            return req[-2] + req[-1][:-1]
+
+        re_ord = Collator(
+            requests,
+            sort_fn=_collate,
+            group_by="contexts"
+            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
+            and self.logits_cache
+            else None,
+            group_fn=_lookup_one_token_cont,
+        )
+
+        # automatic (variable) batch size detection for vectorization
+        # pull longest context sample from request
+        n_reordered_requests = len(re_ord)
+        batch_size = (
+            self.batch_size
+            if self.batch_size != "auto"
+            else override_bs
+            if override_bs is not None
+            else 0
+        )
+        batch_fn = (
+            self._batch_scheduler
+            if self.batch_size == "auto"
+            and n_reordered_requests > 0
+            and not override_bs
+            else None
+        )
+
+        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
+        pbar = tqdm(
+            total=len(requests),
+            disable=(disable_tqdm or (self.rank != 0)),
+            desc="Running loglikelihood requests",
+        )
+        for chunk in chunks:
+            inps = []
+            cont_toks_list = []
+            inplens = []
+
+            conts = []
+            encoder_attns = []
+
+            padding_len_inp = None
+            padding_len_cont = None
+            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
+            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
+            # again because vectorizing is annoying
+
+            for _, context_enc, continuation_enc in chunk:
+                # sanity check
+                assert len(context_enc) > 0
+                assert len(continuation_enc) > 0
+                assert len(continuation_enc) <= self.max_length
+
+                # how this all works (illustrated on a causal decoder-only setup):
+                #          CTX      CONT
+                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
+                # model  \               \
+                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
+                # cont_toks    4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice
+
+                # when too long to fit in context, truncate from the left
+                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+                    inp = torch.tensor(
+                        (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
+                        dtype=torch.long,
+                        device=self.device,
+                    )
+                    (inplen,) = inp.shape
+                elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+                    inp = torch.tensor(
+                        (context_enc)[-self.max_length :],
+                        dtype=torch.long,
+                        device=self.device,
+                    )
+                    (inplen,) = inp.shape
+
+                    # build encoder attn masks
+                    encoder_attns.append(torch.ones_like(inp))
+
+                    cont = torch.tensor(
+                        (continuation_enc)[-self.max_length :],
+                        # TODO: left-shift these?
+                        # TODO: our code assumes we never end up truncating conts for either model type
+                        dtype=torch.long,
+                        device=self.device,
+                    )
+                    (contlen,) = cont.shape
+
+                    conts.append(cont)
+
+                    padding_len_cont = (
+                        max(padding_len_cont, contlen)
+                        if padding_len_cont is not None
+                        else contlen
+                    )
+
+                padding_len_inp = (
+                    max(padding_len_inp, inplen)
+                    if padding_len_inp is not None
+                    else inplen
+                )
+
+                inps.append(inp)  # [1, inp_length]
+                cont_toks_list.append(continuation_enc)
+                inplens.append(inplen)
+
+            # create encoder attn mask and batched conts, if seq2seq
+            call_kwargs = {}
+            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+                batched_inps = pad_and_concat(
+                    padding_len_inp, inps, padding_side="right"
+                )  # [batch, padding_len_inp]
+            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+                # TODO: left-pad encoder inps and mask?
+                batched_inps = pad_and_concat(
+                    padding_len_inp, inps
+                )  # [batch, padding_len_inp]
+                batched_conts = pad_and_concat(
+                    padding_len_cont, conts
+                )  # [batch, padding_len_cont]
+                batched_encoder_mask = pad_and_concat(
+                    padding_len_inp, encoder_attns
+                )  # [batch, padding_len_inp]
+                call_kwargs = {
+                    "attn_mask": batched_encoder_mask,
+                    "labels": batched_conts,
+                }
+
+            start = time()
+            intermediate_res = self._model_call(batched_inps, **call_kwargs)
+            end = time()
+            multi_logits = F.log_softmax(
+                intermediate_res, dim=-1
+            )  # [batch, padding_length (inp or cont), vocab]
+            per_sample_time = (end - start) / len(multi_logits)
+
+            for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
+                chunk, multi_logits, inplens, cont_toks_list
+            ):
+                # Slice to original seq length
+                contlen = len(cont_toks)
+                # take only logits in the continuation
+                # (discard context toks if decoder-only ; discard right-padding)
+                # also discards + checks for "virtual tokens" in the causal LM's input window
+                # from prompt/prefix tuning tokens, if applicable
+                ctx_len = (
+                    inplen + (logits.shape[0] - padding_len_inp)
+                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
+                    else None
+                )
+                logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
+                logits = logits.unsqueeze(0)  # [1, seq, vocab]
+
+                # Check if per-token argmax is exactly equal to continuation
+                greedy_tokens = logits.argmax(dim=-1)
+
+                # check for one-token continuation cache hits.
+                # noop in case group_by != "contexts" or no cache hit and returns the
+                # original args. Otherwise, expands the logits batch dimension and yields each
+                # batch along with matching continuation tokens and prompt strings.
+                # logits -> [1, seq, vocab]
+                for request_str, cont_toks, logits in re_ord.get_cache(
+                    req_str=request_str,
+                    cxt_toks=ctx_tokens,
+                    cont_toks=cont_toks,
+                    logits=logits,
+                ):
+                    cont_toks = torch.tensor(
+                        cont_toks, dtype=torch.long, device=self.device
+                    ).unsqueeze(0)  # [1, seq]
+                    max_equal = (greedy_tokens == cont_toks).all()
+
+                    # Obtain log-probs at the corresponding continuation token indices
+                    # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
+                    logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
+                        -1
+                    )  # [1, seq]
+
+                    # Answer: (log prob, is-exact-match)
+                    answer = (float(logits.sum()), bool(max_equal))
+
+                    res.append((answer, per_sample_time, 0, 0))
+
+                    self.cache_hook.add_partial("loglikelihood", request_str, answer)
+                    pbar.update(1)
+
+        pbar.close()
+
+        return re_ord.get_original(res)
+
     def _model_generate(self, context, max_length, stop, **generation_kwargs):
         # temperature = 0.0 if not set
         # if do_sample is false and temp==0.0:
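
Each entry appended to res above is a 4-tuple (answer, per_sample_time, 0, 0): the usual (log prob, is-exact-match) answer plus the measured end-to-end time, with the prefilling-time and decoding-throughput slots left at zero because loglikelihood requests do no decoding. A minimal sketch, with illustrative values only, of how the measurement decorators downstream split such tuples back apart:

# Illustrative values only; mirrors the unpacking done by process_results_decorator
# in src/backend/run_eval_suite.py and src/backend/tasks/measurement_task_utils.py.
results = [((-1.23, True), 0.031, 0, 0), ((-4.56, False), 0.029, 0, 0)]

scores = [r[0] for r in results]                                  # passed to the task's original process_results
end_to_end_time = sum(r[1] for r in results) / len(results)       # seconds per sample
prefilling_time = sum(r[2] for r in results) / len(results)       # stays 0 for loglikelihood requests
decoding_throughput = sum(r[3] for r in results) / len(results)   # stays 0 for loglikelihood requests

print(scores, end_to_end_time, prefilling_time, decoding_throughput)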
src/backend/run_eval_suite.py CHANGED
@@ -1,13 +1,57 @@
 from lm_eval import evaluator
 from lm_eval.tasks import TaskManager
+from lm_eval.api.metrics import mean
+from lm_eval.api.task import ConfigurableTask
 
 from src.backend.manage_requests import EvalRequest
 
-from src.backend.tasks.xsum.task import XSum
-from src.backend.tasks.xsum.task_v2 import XSumv2
 
-from src.backend.tasks.cnndm.task import CNNDM
-from src.backend.tasks.cnndm.task_v2 import CNNDMv2
+orig_process_results = ConfigurableTask.process_results
+orig_aggregation = ConfigurableTask.aggregation
+orig_higher_is_better = ConfigurableTask.higher_is_better
+
+def process_results_decorator(func):
+    def wrapper(self, doc, results, *args, **kwargs):
+        processed_results = [r[0] for r in results]
+
+        end_to_end_time = sum([r[1] for r in results]) / len(results)
+        prefilling_time = sum([r[2] for r in results]) / len(results)
+        decoding_throughput = sum([r[3] for r in results]) / len(results)
+        # print(f"end_to_end_time: {end_to_end_time}, prefilling_time: {prefilling_time}, decoding_throughput: {decoding_throughput}")
+
+        result_dict = func(self, doc, processed_results, *args, **kwargs)
+        result_dict["end_to_end_time"] = end_to_end_time
+        result_dict["prefilling_time"] = prefilling_time
+        result_dict["decoding_throughput"] = decoding_throughput
+        return result_dict
+    return wrapper
+ConfigurableTask.process_results = process_results_decorator(orig_process_results)
+
+def aggregation_decorator(func):
+    def wrapper(self, *args, **kwargs):
+        aggregation_list = func(self, *args, **kwargs)
+        aggregation_list["end_to_end_time"] = mean
+        aggregation_list["prefilling_time"] = mean
+        aggregation_list["decoding_throughput"] = mean
+        return aggregation_list
+    return wrapper
+ConfigurableTask.aggregation = aggregation_decorator(orig_aggregation)
+
+def higher_is_better_decorator(func):
+    def wrapper(self, *args, **kwargs):
+        higher_is_better_dict = func(self, *args, **kwargs)
+        higher_is_better_dict["end_to_end_time"] = False
+        higher_is_better_dict["prefilling_time"] = False
+        higher_is_better_dict["decoding_throughput"] = True
+        return higher_is_better_dict
+    return wrapper
+ConfigurableTask.higher_is_better = higher_is_better_decorator(orig_higher_is_better)
+
+# from src.backend.tasks.xsum.task import XSum
+# from src.backend.tasks.xsum.task_v2 import XSumv2
+
+# from src.backend.tasks.cnndm.task import CNNDM
+# from src.backend.tasks.cnndm.task_v2 import CNNDMv2
 
 from src.backend.tasks.selfcheckgpt.task import SelfCheckGPT
 
src/backend/tasks/measurement_task_utils.py CHANGED
@@ -12,7 +12,7 @@ def process_results_decorator(func):
         end_to_end_time = sum([r[1] for r in results]) / len(results)
         prefilling_time = sum([r[2] for r in results]) / len(results)
         decoding_throughput = sum([r[3] for r in results]) / len(results)
-        print(f"end_to_end_time: {end_to_end_time}, prefilling_time: {prefilling_time}, decoding_throughput: {decoding_throughput}")
+        # print(f"end_to_end_time: {end_to_end_time}, prefilling_time: {prefilling_time}, decoding_throughput: {decoding_throughput}")
 
         # Now call the original process_results with the processed results
         result_dict = func(self, doc, processed_results, *args, **kwargs)