import copy
import os
from datetime import timedelta
import sys
from time import time
from pathlib import Path
from typing import List, Literal, Optional, Tuple, Union

from calflops import calculate_flops

import torch
import torch.nn.functional as F
import transformers
from accelerate import (
    Accelerator,
    DistributedType,
    InitProcessGroupKwargs,
    find_executable_batch_size,
)
from packaging import version
from peft import PeftModel
from peft import __version__ as PEFT_VERSION
from tqdm import tqdm
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
from transformers import TextStreamer

from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import (
    Collator,
    clear_torch_cache,
    get_dtype,
    pad_and_concat,
    stop_sequences_criteria,
)
from lm_eval.models.huggingface import HFLM

from src.utils import (
    get_gpu_number,
    get_gpu_details,
    get_peak_bw,
    transfer_precision2bytes,
    get_peak_flops,
)
from src.submission.check_validity import get_model_size
from src.envs import API


class StopWatch(TextStreamer):
    """TextStreamer that records prefill latency, decode latency and the number of decode steps."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.start_prefilling = None
        self.prefilling_time = None
        self.start_decoding = None
        self.decoding_time = None
        self.decoding_iterations = 0

    def put(self, value):
        if self.start_prefilling is None:
            self.start_prefilling = time()
            return
        elif self.prefilling_time is None:
            self.prefilling_time = time() - self.start_prefilling
            self.start_decoding = time()
        self.decoding_iterations += 1
        return

    def end(self):
        if self.decoding_time is None and self.start_decoding is not None:
            self.decoding_time = time() - self.start_decoding
        return


class HFLMWithMeasurement(HFLM):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.pretrained = kwargs.get("pretrained", None)
        self.revision = kwargs.get("revision", None)
        self.precision = kwargs.get("dtype", None)
        self.total_flops = 0

    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
        override_bs: int = None,
    ) -> List[Tuple[float, bool]]:
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = req[1] + req[2]
            return -len(toks), tuple(toks)

        def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
            """Defines the key to group and lookup one-token continuations"""
            # Use with group_by="contexts" (optional)
            # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
            # speeds up some multiple-choice tasks proportionally to the number of choices.
            # groups requests by context+continuation[:-1] and infers on one request/group.
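            # Illustrative example (hypothetical token ids, not taken from any task):
            # ctx=[5, 6, 7] with one-token continuations [11] and [12] both produce the
            # key [5, 6, 7], so a single forward pass can score every answer choice.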
            return req[-2] + req[-1][:-1]

        re_ord = Collator(
            requests,
            sort_fn=_collate,
            group_by="contexts"
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM and self.logits_cache
            else None,
            group_fn=_lookup_one_token_cont,
        )

        # automatic (variable) batch size detection for vectorization
        # pull longest context sample from request
        n_reordered_requests = len(re_ord)
        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else override_bs
            if override_bs is not None
            else 0
        )
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto" and n_reordered_requests > 0 and not override_bs
            else None
        )

        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running loglikelihood requests",
        )
        for chunk in chunks:
            inps = []
            cont_toks_list = []
            inplens = []

            conts = []
            encoder_attns = []

            padding_len_inp = None
            padding_len_cont = None
            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works (illustrated on a causal decoder-only setup):
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # model  \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    inp = torch.tensor(
                        (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (inplen,) = inp.shape
                elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                    inp = torch.tensor(
                        (context_enc)[-self.max_length :],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (inplen,) = inp.shape

                    # build encoder attn masks
                    encoder_attns.append(torch.ones_like(inp))

                    cont = torch.tensor(
                        (continuation_enc)[-self.max_length :],
                        # TODO: left-shift these?
                        # TODO: our code assumes we never end up truncating conts for either model type
                        dtype=torch.long,
                        device=self.device,
                    )
                    (contlen,) = cont.shape

                    conts.append(cont)

                    padding_len_cont = (
                        max(padding_len_cont, contlen)
                        if padding_len_cont is not None
                        else contlen
                    )

                padding_len_inp = (
                    max(padding_len_inp, inplen)
                    if padding_len_inp is not None
                    else inplen
                )

                inps.append(inp)  # [1, inp_length]
                cont_toks_list.append(continuation_enc)
                inplens.append(inplen)

            # create encoder attn mask and batched conts, if seq2seq
            call_kwargs = {}
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                batched_inps = pad_and_concat(
                    padding_len_inp, inps, padding_side="right"
                )  # [batch, padding_len_inp]
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # TODO: left-pad encoder inps and mask?
                batched_inps = pad_and_concat(
                    padding_len_inp, inps
                )  # [batch, padding_len_inp]
                batched_conts = pad_and_concat(
                    padding_len_cont, conts
                )  # [batch, padding_len_cont]
                batched_encoder_mask = pad_and_concat(
                    padding_len_inp, encoder_attns
                )  # [batch, padding_len_inp]
                call_kwargs = {
                    "attn_mask": batched_encoder_mask,
                    "labels": batched_conts,
                }

            start = time()
            intermediate_res = self._model_call(batched_inps, **call_kwargs)
            end = time()
            multi_logits = F.log_softmax(
                intermediate_res, dim=-1
            )  # [batch, padding_length (inp or cont), vocab]
            per_sample_time = (end - start) / len(multi_logits)

            for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
                chunk, multi_logits, inplens, cont_toks_list
            ):
                # Slice to original seq length
                contlen = len(cont_toks)
                # take only logits in the continuation
                # (discard context toks if decoder-only; discard right-padding)
                # also discards + checks for "virtual tokens" in the causal LM's input window
                # from prompt/prefix tuning tokens, if applicable
                ctx_len = (
                    inplen + (logits.shape[0] - padding_len_inp)
                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                    else None
                )
                logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
                logits = logits.unsqueeze(0)  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)

                # check for one-token continuation cache hits.
                # noop in case group_by != "contexts" or no cache hit and returns the
                # original args. Otherwise, expands the logits batch dimension and yields each
                # batch along with matching continuation tokens and prompt strings.
                # logits -> [1, seq, vocab]
                for request_str, cont_toks, logits in re_ord.get_cache(
                    req_str=request_str,
                    cxt_toks=ctx_tokens,
                    cont_toks=cont_toks,
                    logits=logits,
                ):
                    cont_toks = torch.tensor(
                        cont_toks, dtype=torch.long, device=self.device
                    ).unsqueeze(0)  # [1, seq]
                    max_equal = (greedy_tokens == cont_toks).all()

                    # Obtain log-probs at the corresponding continuation token indices
                    # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                    logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                        -1
                    )  # [1, seq]

                    # Answer: (log prob, is-exact-match)
                    answer = (float(logits.sum()), bool(max_equal))

                    res.append((answer, per_sample_time, 0, 0))

                    self.cache_hook.add_partial("loglikelihood", request_str, answer)
                    pbar.update(1)

        pbar.close()

        return re_ord.get_original(res)

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        # temperature = 0.0 if not set
        # if do_sample is false and temp==0.0:
        # remove temperature, as do_sample=False takes care of this
        # and we don't want a warning from HF
        generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
        do_sample = generation_kwargs.get("do_sample", None)
        is_gsm8k = generation_kwargs.get("is_gsm8k", False)

        # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
        if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
            generation_kwargs["do_sample"] = do_sample = False

        if do_sample is False and generation_kwargs.get("temperature") == 0.0:
            generation_kwargs.pop("temperature")

        # default avoids a KeyError when a task does not set is_gsm8k in its gen_kwargs
        generation_kwargs.pop("is_gsm8k", None)

        context_length = context.shape[1]

        if not is_gsm8k:
            # build stopping criteria
            stopping_criteria = stop_sequences_criteria(
                self.tokenizer, stop, context.shape[1], context.shape[0]
            )
            stop_watch = StopWatch(self.tokenizer)
            start = time()
            res = self.model.generate(
                input_ids=context,
                max_length=max_length,
                stopping_criteria=stopping_criteria,
                pad_token_id=self.tokenizer.pad_token_id,
                use_cache=True,
                streamer=stop_watch,
                **generation_kwargs,
            )
            end = time()
        else:
            # print("Using GSM8K")
            stop_watch = StopWatch(self.tokenizer)
            start = time()
            res = self.model.generate(
                input_ids=context,
                max_length=max_length,
                eos_token_id=stop,
                pad_token_id=self.tokenizer.pad_token_id,
                use_cache=True,
                streamer=stop_watch,
                **generation_kwargs,
            )
            end = time()

        batch_size = context.shape[0]
        output_length = stop_watch.decoding_iterations

        precision_bytes = transfer_precision2bytes(self.precision)

        model_info = API.model_info(repo_id=self.pretrained, revision=self.revision)
        model_size_param = get_model_size(model_info=model_info, precision=self.precision)

        model_size = model_size_param * precision_bytes

        model_config = self.model.config

        n_layers = (
            model_config.num_hidden_layers
            if hasattr(model_config, "num_hidden_layers")
            else model_config.num_layers
        )
        d_model = (
            model_config.hidden_size
            if hasattr(model_config, "hidden_size")
            else model_config.d_model
        )

        if hasattr(model_config, "num_experts_per_tok"):
            n_experts_per_tok = model_config.num_experts_per_tok
        elif hasattr(model_config, "num_selected_experts"):
            n_experts_per_tok = model_config.num_selected_experts
        else:
            n_experts_per_tok = 1

        if hasattr(model_config, "ffn_dim"):
            d_ff = model_config.ffn_dim
        elif hasattr(model_config, "intermediate_size"):
            d_ff = model_config.intermediate_size
        elif hasattr(model_config, "d_ff"):
            d_ff = model_config.d_ff
        else:
            raise ValueError("Unknown FFN dimension in the model configuration")

        if hasattr(model_config, "num_local_experts"):
            num_experts = model_config.num_local_experts
        elif hasattr(model_config, "num_experts"):
            num_experts = model_config.num_experts
        else:
            num_experts = 1

        ffn_params = n_layers * d_ff * 2 * d_model

        shared_params = model_size_param * 1e9 - num_experts * ffn_params

        # parameters that are active for a single token (shared weights + selected experts)
        model_size = shared_params + n_experts_per_tok * ffn_params

        per_token_kv_size = 2 * n_layers * d_model * precision_bytes

        peak_bw_single = get_peak_bw(get_gpu_details())
        peak_bw = peak_bw_single * get_gpu_number()

        kv_size = (output_length - 1) * per_token_kv_size / 1e9

        end_to_end_time = (end - start) / batch_size
        prefilling_time = stop_watch.prefilling_time / batch_size
        decoding_time = stop_watch.decoding_time / batch_size
        token_per_sec = output_length / decoding_time
        ach_mem_bw = (model_size / 1e9 + kv_size) * token_per_sec

        flops_per_token = 2 * model_size + 2 * n_layers * context_length * d_model
        peak_flops_single = get_peak_flops(get_gpu_details(), self.precision)
        peak_flops = peak_flops_single * get_gpu_number()

        # TODO: currently only supports llama-style decoder-only models and MoE models
        # such as Switch Transformer and Mixtral
        mfu = token_per_sec * flops_per_token / peak_flops
        mbu = ach_mem_bw / peak_bw

        # print(f"mfu: {mfu}, mbu: {mbu}")

        return res, end_to_end_time, prefilling_time, token_per_sec, mfu, mbu

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        res = []

        def _collate(req: Tuple[str, dict]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(req[0])
            return -len(toks), req[0]

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size
        # for each different set of kwargs, we execute all requests, by batch.
        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else adaptive_batch_size
            if adaptive_batch_size is not None
            else 0
        )
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto" and not adaptive_batch_size
            else None
        )

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        # group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
        re_ords = Collator(
            [reg.args for reg in requests],
            sort_fn=_collate,
            group_by="gen_kwargs",
            group_fn=lambda x: x[1],
        )
        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
        for chunk in chunks:
            contexts, all_gen_kwargs = zip(*chunk)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            # unpack our keyword arguments.
            until = None
            if isinstance(gen_kwargs, dict):
                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                if "until" in kwargs.keys():
                    until = kwargs.pop("until")
                    if isinstance(until, str):
                        until = [until]
                    elif not isinstance(until, list):
                        raise ValueError(
                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                        )
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )
            # add EOS token to stop sequences
            eos = self.tok_decode(self.eot_token_id)
            if not until:
                until = [eos]
            else:
                until.append(eos)

            is_gsm8k = kwargs.get("is_gsm8k", False)
            if is_gsm8k:
                until = [
                    self.tokenizer.eos_token_id,
                    self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
                ]

            if "max_gen_toks" in kwargs.keys():
                max_gen_toks = kwargs.pop("max_gen_toks")
            else:
                max_gen_toks = self.max_gen_toks

            # set the max length in tokens of inputs ("context_enc")
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                # max len for inputs = max length, minus room to generate the max new tokens
                max_ctx_len = self.max_length - max_gen_toks
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # max len for inputs = encoder's whole max_length
                max_ctx_len = self.max_length

            # encode, pad, and truncate contexts for this batch
            context_enc, attn_masks = self.tok_batch_encode(
                contexts,
                left_truncate_len=max_ctx_len,
                truncation=self.truncation,
            )

            # print("context: ", self.tok_decode(context_enc[0]))

            context_enc = context_enc.to(self.device)
            attn_masks = attn_masks.to(self.device)

            if "max_length" not in kwargs:
                kwargs["max_length"] = context_enc.shape[1] + max_gen_toks

            # perform batched generation
            cont, end_to_end_time, prefilling_time, token_per_sec, mfu, mbu = self._model_generate(
                context=context_enc,
                attention_mask=attn_masks,
                stop=until,
                **kwargs,
            )

            cont_toks_list = cont.tolist()
            for cont_toks, context in zip(cont_toks_list, contexts):
                # discard context + left-padding toks if using causal decoder-only LM
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    # print("After Generation: ", self.tok_decode(cont_toks))
                    cont_toks = cont_toks[context_enc.shape[1] :]

                s = self.tok_decode(cont_toks)

                # print(s)

                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                if not is_gsm8k:
                    for term in until:
                        if len(term) > 0:
                            # ignore '' separator,
                            # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                            s = s.split(term)[0]

                res.append((s, end_to_end_time, prefilling_time, token_per_sec, mfu, mbu))

                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                pbar.update(1)

        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)

        pbar.close()

        return res
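

# The block below is a minimal, self-contained sketch of the MFU/MBU arithmetic in
# `_model_generate`, with made-up numbers standing in for the values normally read
# from the model config, the Hugging Face Hub, and the GPU lookup helpers
# (`get_peak_flops` is assumed to return FLOP/s and `get_peak_bw` GB/s here).
# It only illustrates the formulas used above and is not exercised by the leaderboard.
if __name__ == "__main__":
    n_layers, d_model = 32, 4096        # hypothetical 7B-class decoder-only config
    context_length = 1024               # prompt length in tokens (assumed)
    precision_bytes = 2                 # fp16
    model_size = 7e9                    # activated parameters (dense model, assumed)
    output_length = 256                 # generated tokens (assumed)
    token_per_sec = 50.0                # measured decode throughput (assumed)
    peak_flops = 312e12                 # assumed peak FLOP/s of the accelerator
    peak_bw = 2039.0                    # assumed peak memory bandwidth in GB/s

    flops_per_token = 2 * model_size + 2 * n_layers * context_length * d_model
    mfu = token_per_sec * flops_per_token / peak_flops

    per_token_kv_size = 2 * n_layers * d_model * precision_bytes    # bytes per generated token
    kv_size = (output_length - 1) * per_token_kv_size / 1e9         # GB, as in `_model_generate`
    ach_mem_bw = (model_size / 1e9 + kv_size) * token_per_sec       # GB/s, mirroring the code above
    mbu = ach_mem_bw / peak_bw

    print(f"flops_per_token={flops_per_token:.3e}, mfu={mfu:.4f}, mbu={mbu:.4f}")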