File size: 9,825 Bytes
f0ad559 1ae96c8 f0ad559 1ae96c8 f0ad559 1ae96c8 f0ad559 1ae96c8 f0ad559 1ae96c8 f0ad559 1ae96c8 f0ad559 1ae96c8 f0ad559 1ae96c8 f0ad559 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 |
import copy
import os
from datetime import timedelta
from time import time
from pathlib import Path
from typing import List, Literal, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import transformers
from accelerate import (
Accelerator,
DistributedType,
InitProcessGroupKwargs,
find_executable_batch_size,
)
from packaging import version
from peft import PeftModel
from peft import __version__ as PEFT_VERSION
from tqdm import tqdm
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
from transformers import TextStreamer
from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import (
Collator,
clear_torch_cache,
get_dtype,
pad_and_concat,
stop_sequences_criteria,
)
from lm_eval.models.huggingface import HFLM
class StopWatch(TextStreamer):
    """Streamer that records generation timings instead of emitting text.

    Both ``put`` and ``end`` are overridden without delegating to the parent,
    so this object only keeps timestamps and a call counter.

    NOTE(review): the attribute names assume the usual streamer protocol in
    which the first ``put`` call carries the prompt and every later call
    carries newly generated tokens -- confirm against the generation API used.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``TextStreamer`` and reset the timers."""
        super().__init__(*args, **kwargs)
        # Timestamps/durations stay None until the matching event is observed.
        self.start_prefilling = self.prefilling_time = None
        self.start_decoding = self.decoding_time = None
        # Number of ``put`` calls seen after the very first one.
        self.decoding_iterations = 0

    def put(self, value):
        """Record timing for one streamer callback; ``value`` is ignored."""
        if self.start_prefilling is None:
            # Very first call: only start the prefill clock.
            self.start_prefilling = time()
            return

        if self.prefilling_time is None:
            # Second call: prefill is over, so close its interval and start
            # the decoding clock.
            self.prefilling_time = time() - self.start_prefilling
            self.start_decoding = time()

        # Every call except the first counts as one decoding iteration.
        self.decoding_iterations += 1

    def end(self):
        """Close the decoding interval once, if decoding ever started."""
        already_closed = self.decoding_time is not None
        never_started = self.start_decoding is None
        if already_closed or never_started:
            return
        self.decoding_time = time() - self.start_decoding
class HFLMWithMeasurement(HFLM):
    """``HFLM`` variant whose generation path also reports timing figures.

    ``_model_generate`` and ``generate_until`` are overridden so that each
    generated string is returned together with the end-to-end time, the
    prefilling time and the decoding throughput measured for its batch.
    """

    def __init__(self, **kwargs):
        """Forward every keyword argument unchanged to ``HFLM.__init__``."""
        super().__init__(**kwargs)

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        """Call ``self.model.generate`` on one batch and time the call.

        Args:
            context: Batch of input ids. ``context.shape[0]`` is used as the
                batch size and ``context.shape[1]`` as the prompt length
                handed to the stopping criteria.
            max_length: Forwarded to ``generate`` as ``max_length``.
            stop: Stop sequences used to build the stopping criteria.
            **generation_kwargs: Extra keyword arguments forwarded to
                ``generate`` (after the temperature/do_sample clean-up below).

        Returns:
            A 4-tuple ``(res, end_to_end_time, prefilling_time,
            token_per_sec)``. ``res`` is whatever ``generate`` returned; the
            two times are wall-clock seconds divided by the batch size, and
            ``token_per_sec`` is the stopwatch's iteration count divided by
            the per-sample decoding time. Timings the stopwatch could not
            measure are reported as ``0.0``.
        """
        # temperature = 0.0 if not set
        # if do_sample is false and temp==0.0:
        # remove temperature, as do_sample=False takes care of this
        # and we don't want a warning from HF
        generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
        do_sample = generation_kwargs.get("do_sample", None)

        # The temperature has to be a strictly positive float -- if it is 0.0,
        # use greedy decoding strategies
        if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
            generation_kwargs["do_sample"] = do_sample = False

        if do_sample is False and generation_kwargs.get("temperature") == 0.0:
            generation_kwargs.pop("temperature")

        # build stopping criteria
        stopping_criteria = stop_sequences_criteria(
            self.tokenizer, stop, context.shape[1], context.shape[0]
        )

        stop_watch = StopWatch(self.tokenizer)
        start = time()
        res = self.model.generate(
            input_ids=context,
            max_length=max_length,
            stopping_criteria=stopping_criteria,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=True,
            streamer=stop_watch,
            **generation_kwargs,
        )
        end = time()

        batch_size = context.shape[0]
        output_length = stop_watch.decoding_iterations
        end_to_end_time = (end - start) / batch_size

        # The stopwatch leaves its durations at None when it never saw a
        # second ``put`` call, and the decoding interval can be exactly zero
        # on a coarse clock. Report 0.0 in those cases instead of failing with
        # TypeError / ZeroDivisionError after generation already succeeded.
        raw_prefilling = stop_watch.prefilling_time
        raw_decoding = stop_watch.decoding_time
        prefilling_time = (
            raw_prefilling / batch_size if raw_prefilling is not None else 0.0
        )
        decoding_time = raw_decoding / batch_size if raw_decoding is not None else 0.0
        token_per_sec = output_length / decoding_time if decoding_time > 0 else 0.0

        return res, end_to_end_time, prefilling_time, token_per_sec

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[Tuple[str, float, float, float]]:
        """Generate a continuation for every request and attach timings.

        Requests are grouped by their generation kwargs, sorted by descending
        context length, batched, generated, and finally restored to the
        original request order.

        Args:
            requests: Instances whose ``args`` are ``(context, gen_kwargs)``
                pairs; ``gen_kwargs`` must be a dict and may contain
                ``until`` (str or list of str) and ``max_gen_toks``.
            disable_tqdm: Disable the progress bar (it is also disabled on
                every rank other than 0).

        Returns:
            One ``(text, end_to_end_time, prefilling_time, token_per_sec)``
            tuple per request, in the original request order. All requests
            that ran in the same batch share that batch's timing values.

        Raises:
            ValueError: If ``gen_kwargs`` is not a dict, or ``until`` is
                neither a string nor a list.
        """
        res = []

        def _collate(req: Tuple[str, dict]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(req[0])
            return -len(toks), req[0]

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size

        # for each different set of kwargs, we execute all requests, by batch.
        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else adaptive_batch_size
            if adaptive_batch_size is not None
            else 0
        )
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto" and not adaptive_batch_size
            else None
        )

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        # group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
        re_ords = Collator(
            [reg.args for reg in requests],
            sort_fn=_collate,
            group_by="gen_kwargs",
            group_fn=lambda x: x[1],
        )
        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
        for chunk in chunks:
            contexts, all_gen_kwargs = zip(*chunk)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            # unpack our keyword arguments.
            until = None
            if isinstance(gen_kwargs, dict):
                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                if "until" in kwargs:
                    until = kwargs.pop("until")
                    if isinstance(until, str):
                        # A single stop string becomes a one-element list.
                        # (Wrapping `kwargs` here would put a dict where
                        # every consumer expects a stop string.)
                        until = [until]
                    elif not isinstance(until, list):
                        raise ValueError(
                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                        )
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )
            # add EOS token to stop sequences
            eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
            if not until:
                until = [eos]
            else:
                until.append(eos)
            if "max_gen_toks" in kwargs:
                max_gen_toks = kwargs.pop("max_gen_toks")
            else:
                max_gen_toks = self.max_gen_toks

            # set the max length in tokens of inputs ("context_enc")
            # NOTE(review): any other AUTO_MODEL_CLASS leaves `max_ctx_len`
            # unassigned and the encode call below would fail.
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                # max len for inputs = max length, minus room to generate the max new tokens
                max_ctx_len = self.max_length - max_gen_toks
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # max len for inputs = encoder's whole max_length
                max_ctx_len = self.max_length

            # encode, pad, and truncate contexts for this batch
            context_enc, attn_masks = self.tok_batch_encode(
                contexts,
                left_truncate_len=max_ctx_len,
                truncation=self.truncation,
            )
            context_enc = context_enc.to(self.device)
            attn_masks = attn_masks.to(self.device)

            if "max_length" not in kwargs:
                kwargs["max_length"] = context_enc.shape[1] + max_gen_toks

            # perform batched generation
            cont, end_to_end_time, prefilling_time, token_per_sec = self._model_generate(
                context=context_enc,
                attention_mask=attn_masks,
                stop=until,
                **kwargs,
            )

            cont_toks_list = cont.tolist()
            for cont_toks, context in zip(cont_toks_list, contexts):
                # discard context + left-padding toks if using causal decoder-only LM
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    cont_toks = cont_toks[context_enc.shape[1] :]

                s = self.tok_decode(cont_toks)

                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                for term in until:
                    if len(term) > 0:
                        # ignore '' separator,
                        # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                        s = s.split(term)[0]

                # Every request of this batch carries the batch-level timings.
                res.append((s, end_to_end_time, prefilling_time, token_per_sec))

                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                pbar.update(1)
        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)

        pbar.close()

        return res
|