from megatron.utils import is_local_main, print_rank_0 |
|
|
|
import copy |
|
import os |
|
import sys |
|
import dataclasses |
|
from functools import partial |
|
|
|
sys.path.append( |
|
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)) |
|
) |
|
from tqdm import tqdm |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from lm_eval.models.huggingface import HFLM |
|
from lm_eval import tasks, evaluator, utils, api |
|
from megatron.text_generation_utils import generate_samples_from_prompt |
|
from megatron import mpu |
|
|
|
|
|
class EvalHarnessAdapter(HFLM): |
|
""" |
|
An adapter to run NeoX models on LM Evaluation Harness (https://github.com/EleutherAI/lm-evaluation-harness) tasks. |
|
|
|
Args: |
|
model: A NeoX Model |
|
forward_step_fn: A function that runs a forward pass through the model, returning `tuple(loss, logits)`. |
|
neox_args: a NeoXArgs object containing the model configuration. |
|
batch_size (optional): An argument to override the batch size, which defaults to batch size per gpu * dp world size. |
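
    Example (usage sketch; assumes an already-initialized NeoX model, forward-step function, and NeoXArgs):

        >>> adapter = EvalHarnessAdapter(model, forward_step_fn, neox_args)
        >>> results = adapter.run_eval(eval_tasks=["lambada", "piqa"], num_fewshot=0)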
|
""" |
|
|
|
def __init__(self, model, forward_step_fn, neox_args, batch_size=None): |
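        # NOTE: HFLM.__init__ is not called here; the attributes the harness relies on
        # (device, tokenizer, max length, batch size, etc.) are set up manually below.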
|
self.cache_hook = api.model.CacheHook(None) |
|
self._model = model |
|
self.neox_args = neox_args |
|
self.tokenizer = neox_args.tokenizer |
|
self._device = torch.device(f"cuda:{neox_args.local_rank}") |
|
self._eot_token_id = neox_args.tokenizer.eod_id |
|
self._max_length = neox_args.max_position_embeddings |
|
self._max_gen_toks = 128 |
|
self._vocab_size = neox_args.padded_vocab_size |
|
|
|
|
|
self.is_main = neox_args.rank == 0 |
|
self.is_local_main = neox_args.local_rank == 0 |
|
self.is_model_parallel = neox_args.model_parallel_size > 1 |
|
self.is_pipe_parallel = self.model.is_pipe_parallel |
|
self.is_data_parallel = self.model.is_data_parallel |
|
self.is_last_stage = ( |
|
True if not self.is_pipe_parallel else model.is_last_stage() |
|
) |
|
self.dp_world_size = mpu.get_data_parallel_world_size() |
|
self.dp_rank = mpu.get_data_parallel_rank() |
|
self.dp_group = mpu.get_data_parallel_group() |
|
self.is_mp_rank_0 = mpu.get_model_parallel_rank() == 0 |
|
|
|
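        # Default to per-GPU batch size * data-parallel world size unless a batch size is passed in explicitly.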
self._batch_size = batch_size or ( |
|
neox_args.batch_size * self.dp_world_size |
|
) |
|
|
|
|
|
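        # Alias the NeoX tokenizer's tokenize/detokenize methods to the encode/decode
        # interface the eval harness expects.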
self.tokenizer.encode = self.tokenizer.tokenize |
|
self.tokenizer.decode = self.tokenizer.detokenize |
|
self._forward_step_fn = partial( |
|
forward_step_fn, neox_args=neox_args, timers=None, return_logits=True |
|
) |
|
self.generate = partial( |
|
generate_samples_from_prompt, |
|
neox_args=neox_args, |
|
model=model, |
|
) |
|
|
|
@property |
|
def vocab_size(self): |
|
return self._vocab_size |
|
|
|
@property |
|
def eot_token_id(self): |
|
|
|
return self._eot_token_id |
|
|
|
@property |
|
def max_length(self): |
|
return self._max_length |
|
|
|
@property |
|
def max_gen_toks(self): |
|
return self._max_gen_toks |
|
|
|
@property |
|
def batch_size(self): |
|
return self._batch_size |
|
|
|
@property |
|
def device(self): |
|
return self._device |
|
|
|
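    # Report rank 0 / world size 1 to the eval harness so it does not try to shard
    # requests itself; data and model parallelism are handled inside this adapter.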
@property |
|
def rank(self): |
|
return 0 |
|
|
|
@property |
|
def world_size(self): |
|
return 1 |
|
|
|
def tok_encode(self, string: str, **kwargs): |
|
return self.tokenizer.encode(string) |
|
|
|
def tok_decode(self, tokens, **kwargs): |
|
return self.tokenizer.decode(tokens) |
|
|
|
def generate_until(self, requests): |
|
""" |
|
Generate until is lm_eval harness' way to say "do greedy generation" - necessary for some tasks. |
|
the eval harness dispatches requests to the model, and the model does argmax generation, the results of which |
|
are returned to the eval harness to evaluate. |
|
|
|
TODO: batched / data parallel generation |
|
|
|
:param requests: Dictionary of requests containing the context (prompt) and 'until' - a token or |
|
list of stop tokens. |
|
""" |
|
self.model.module.inference_mode(use_cache=True) |
|
res = [] |
|
|
|
|
|
reqs = [req.args for req in requests] |
|
|
|
def _collate(x): |
|
toks = self.tokenizer.encode(x[0]) |
|
return (len(toks), x[0]) |
|
|
|
reord = utils.Reorderer(reqs, _collate) |
|
for context, gen_kwargs in tqdm( |
|
reord.get_reordered(), "Running greedy generation" |
|
): |
|
            until = None
            if isinstance(gen_kwargs, dict):
                kwargs = copy.deepcopy(gen_kwargs)
                if "until" in kwargs.keys():
                    until = kwargs.pop("until")
                    if isinstance(until, str):
                        until = [until]
                    elif not isinstance(until, list):
                        raise ValueError(
                            f"Expected `kwargs['until']` to be of type Union[str, list] but got {until}"
                        )
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
                )
            if not until:
                until = [self.tok_decode(self.eot_token_id)]
|
if "max_gen_toks" in kwargs.keys(): |
|
max_gen_toks = kwargs.pop("max_gen_toks") |
|
else: |
|
max_gen_toks = self.max_gen_toks |
|
|
|
if "do_sample" in kwargs.keys(): |
|
kwargs.pop("do_sample") |
|
|
|
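            # Encode each stop string into token ids for the NeoX generation utility.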
stop_tokens = [self.tokenizer.encode(i) for i in until] |
|
cont = self.generate( |
|
text=context, |
|
stop_tokens=stop_tokens, |
|
recompute=self.neox_args.recompute, |
|
maximum_tokens=max_gen_toks, |
|
**kwargs, |
|
) |
|
if cont: |
|
s = cont[0]["text"] or "" |
|
else: |
|
s = "" |
|
|
|
for term in until: |
|
s = s.split(term)[0] |
|
|
|
|
|
self.cache_hook.add_partial("generate_until", (context, until), s) |
|
|
|
res.append(s) |
|
|
|
self.model.module.train_mode() |
|
return reord.get_original(res) |
|
|
|
def _loglikelihood_tokens(self, requests, disable_tqdm=False): |
|
""" |
|
In this method, the model doesn't do any generation, but just returns log likelihoods |
|
for the next token, which eval harness uses to evaluate. |
|
|
|
:param requests: Dictionary of requests containing the context and the expected continuation. |
|
:param disable_tqdm: If True, disable tqdm progress bar. |
|
""" |
|
self.model.module.inference_mode( |
|
use_cache=False |
|
) |
|
|
|
disable_tqdm = disable_tqdm if self.is_main else True |
|
res = [] |
|
res_len = 0 |
|
with torch.no_grad(): |
|
|
|
def _collate(x): |
|
toks = x[1] + x[2] |
|
return (-len(toks), tuple(toks)) |
|
|
|
reord = utils.Reorderer(requests, _collate) |
|
for chunk in utils.chunks( |
|
tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size |
|
): |
|
inps, contlens, inplens, padding_length = [], [], [], None |
|
for cache_key, context_enc, continuation_enc in chunk: |
|
|
|
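                    # Concatenate context and continuation, keep at most `max_length + 1` tokens
                    # from the right, and drop the final token (it is only ever a prediction target).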
inp = torch.tensor( |
|
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1], |
|
dtype=torch.long, |
|
).to(self.device) |
|
(inplen,) = inp.shape |
|
|
|
cont = continuation_enc |
|
|
|
|
|
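                    # Requests are sorted longest-first, so the first example in each batch
                    # sets the padding length for the whole batch.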
padding_length = ( |
|
padding_length if padding_length is not None else inplen |
|
) |
|
|
|
|
|
inp = torch.cat( |
|
[ |
|
inp, |
|
torch.zeros(padding_length - inplen, dtype=torch.long).to( |
|
inp.device |
|
), |
|
], |
|
dim=0, |
|
) |
|
|
|
inps.append(inp.unsqueeze(0)) |
|
contlens.append(cont) |
|
inplens.append(inplen) |
|
|
|
logits = self._model_call(torch.cat(inps, dim=0)) |
|
res_len += len(chunk) |
|
|
|
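                # Only ranks holding the last pipeline stage get logits back; the others
                # receive None and skip scoring.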
if logits is not None: |
|
multi_logits = F.log_softmax(logits, dim=-1) |
|
for (cache_key, _, _), logits, inp, inplen, cont_toks in zip( |
|
chunk, multi_logits, inps, inplens, contlens |
|
): |
|
contlen = len(cont_toks) |
|
logits = logits[inplen - contlen : inplen].unsqueeze( |
|
0 |
|
) |
|
greedy_tokens = logits.argmax(dim=-1) |
|
|
|
cont_toks = ( |
|
torch.tensor(cont_toks, dtype=torch.long) |
|
.unsqueeze(0) |
|
.to(multi_logits.device) |
|
) |
|
max_equal = (greedy_tokens == cont_toks).all() |
|
logits = torch.gather( |
|
logits, 2, cont_toks.unsqueeze(-1) |
|
).squeeze( |
|
-1 |
|
) |
|
answer = (float(logits.sum()), bool(max_equal)) |
|
|
|
|
|
if cache_key is not None: |
|
self.cache_hook.add_partial( |
|
"loglikelihood", cache_key, answer |
|
) |
|
|
|
res.append(answer) |
|
|
|
|
|
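            # The scored results only exist on the last pipeline stage, so broadcast them
            # from that stage to the rest of the pipeline group.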
if self.is_pipe_parallel: |
|
src_rank = self.model.grid.stage_to_global(self.model.num_stages - 1) |
|
if res: |
|
logits_sums, max_equals = list(zip(*res)) |
|
logits_sums = torch.FloatTensor(logits_sums).cuda() |
|
max_equals = torch.LongTensor(max_equals).cuda() |
|
else: |
|
logits_sums = torch.zeros(res_len, dtype=torch.float32).cuda() |
|
max_equals = torch.zeros(res_len, dtype=torch.int64).cuda() |
|
torch.distributed.broadcast( |
|
tensor=logits_sums, |
|
src=src_rank, |
|
group=mpu.get_pipe_parallel_group(), |
|
) |
|
torch.distributed.broadcast( |
|
tensor=max_equals, src=src_rank, group=mpu.get_pipe_parallel_group() |
|
) |
|
max_equals = [bool(i) for i in max_equals.tolist()] |
|
logits_sums = logits_sums.tolist() |
|
res = list(zip(logits_sums, max_equals)) |
|
|
|
self.model.module.train_mode() |
|
return reord.get_original(res) |
|
|
|
def _dp_scatter(self, inps): |
|
""" |
|
Scatters the inputs to all data parallel ranks. |
|
""" |
|
|
|
batch_size = inps.shape[0] |
|
padded = False |
|
if batch_size % self.dp_world_size != 0: |
|
|
|
|
|
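            # Pad the batch by replicating the first row until it divides evenly across
            # data-parallel ranks; the extra rows are stripped again in `_model_call`.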
padded_size = self.dp_world_size - (batch_size % self.dp_world_size) |
|
|
|
            print_rank_0(
                f"WARNING: Batch size ({batch_size}) must be divisible by dp world size ({self.dp_world_size}). Padding the batch with {padded_size} extra rows."
            )
|
|
|
inps = torch.cat( |
|
[inps] + [inps[0:1, ...] for _ in range(padded_size)], dim=0 |
|
) |
|
padded = True |
|
|
|
assert ( |
|
inps.shape[0] % self.dp_world_size == 0 |
|
), f"batch size ({inps.shape[0]}) must be divisible by dp world size ({self.dp_world_size})" |
|
|
|
|
|
chunk_size = inps.shape[0] // self.dp_world_size |
|
inps = inps[self.dp_rank * chunk_size : (self.dp_rank + 1) * chunk_size] |
|
|
|
|
|
|
|
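        # Wrap the tensor in the single-item {"text": ...} iterator format the NeoX forward
        # step expects; pad one extra position so the forward step's input/label shift does
        # not drop the final token.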
return iter([{"text": F.pad(inps, pad=(0, 1))}]), padded |
|
|
|
def _dp_gather(self, logits): |
|
""" |
|
Gather logits from all data parallel ranks |
|
""" |
|
if logits is not None: |
|
tensor_list = [torch.zeros_like(logits) for _ in range(self.dp_world_size)] |
|
torch.distributed.all_gather( |
|
tensor_list, logits, group=mpu.get_data_parallel_group() |
|
) |
|
logits = torch.cat(tensor_list, dim=0) |
|
return logits |
|
|
|
def _model_call(self, inps): |
|
batch_size = inps.shape[0] |
|
|
|
|
|
inps, padded = self._dp_scatter(inps) |
|
|
|
if self.neox_args.is_pipe_parallel: |
|
|
|
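            # Reset the pipeline engine's communication state so receive buffers are
            # re-allocated for this (possibly differently shaped) evaluation batch.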
self.model.first_output_send = True |
|
self.model.pipe_recv_buf = None |
|
|
|
_, logits = self._forward_step_fn(model=self.model, data_iterator=inps) |
|
|
|
|
|
logits = self._dp_gather(logits) |
|
|
|
|
|
|
|
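        # Drop the rows that were only added to make the batch divisible across ranks.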
if padded and logits is not None: |
|
logits = logits[:batch_size, ...] |
|
return logits |
|
|
|
def _model_generate(self, context, max_length, eos_token_id): |
|
|
|
raise NotImplementedError() |
|
|
|
@torch.no_grad() |
|
def run_eval( |
|
self, |
|
eval_tasks=None, |
|
num_fewshot=0, |
|
bootstrap_iters=2, |
|
use_cache=True, |
|
name="neox", |
|
limit=None, |
|
): |
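        """
        Run the LM Evaluation Harness on `eval_tasks` and return its results dictionary.
        `num_fewshot`, `bootstrap_iters`, and `limit` are passed through to the harness;
        `use_cache` enables an on-disk request cache.
        """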
|
was_training = self.model.training |
|
self.model.eval() |
|
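        # Force a single micro batch during evaluation; the original setting is restored before returning.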
        in_micro_batches = self.model.micro_batches
        self.model.micro_batches = 1
|
if eval_tasks is None: |
|
eval_tasks = [ |
|
"lambada", |
|
"piqa", |
|
"hellaswag", |
|
"winogrande", |
|
"mathqa", |
|
"pubmedqa", |
|
"triviaqa", |
|
] |
|
|
|
|
|
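        # Register the harness' built-in tasks so they can be looked up by name below.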
tasks.initialize_tasks() |
|
|
|
|
|
|
|
import fnmatch |
|
|
|
def pattern_match(patterns, source_list): |
|
task_names = set() |
|
for pattern in patterns: |
|
for matching in fnmatch.filter(source_list, pattern): |
|
task_names.add(matching) |
|
return list(task_names) |
|
|
|
eval_tasks = pattern_match(eval_tasks, tasks.ALL_TASKS) |
|
print(f"Found tasks: {eval_tasks}") |
|
|
|
assert len(eval_tasks) > 0, "Must run at least one task" |
|
|
|
|
|
|
|
|
|
|
|
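        # Build the task dict on the local main process first so datasets are downloaded
        # once per node; the other ranks wait at the barrier and then load from cache.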
if self.is_local_main: |
|
task_dict = tasks.get_task_dict(eval_tasks) |
|
|
|
if torch.distributed.is_initialized(): |
|
torch.distributed.barrier() |
|
task_dict = tasks.get_task_dict(eval_tasks) |
|
|
|
lm = self |
|
|
|
        if use_cache:
            use_cache = (
                "lm_cache/neox"
                + "_dp_rank"
                + str(self.dp_rank)
                + "_dp_group"
                + str(self.dp_group)
                + ".db"
            )
            print(f"Using cache at {use_cache}...")
            lm = api.model.CachingLM(
                lm,
                use_cache,
            )
|
|
|
|
|
|
|
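        # Apply the requested few-shot setting to each task's config.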
for task_name in task_dict.keys(): |
|
task_obj = task_dict[task_name] |
|
if type(task_obj) == tuple: |
|
group, task_obj = task_obj |
|
if task_obj is None: |
|
continue |
|
|
|
config = task_obj._config |
|
|
|
if num_fewshot is not None: |
|
if config["num_fewshot"] == 0: |
|
utils.eval_logger.info( |
|
f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored." |
|
) |
|
                else:
                    default_num_fewshot = config["num_fewshot"]
                    utils.eval_logger.warning(
                        f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
                    )

                    task_obj._config["num_fewshot"] = num_fewshot
|
|
|
results = evaluator.evaluate( |
|
lm=lm, |
|
task_dict=task_dict, |
|
            limit=limit,
|
bootstrap_iters=bootstrap_iters, |
|
log_samples=False, |
|
) |
|
|
|
results["config"] = { |
|
"model": name, |
|
"model_args": dataclasses.asdict(self.neox_args), |
|
"batch_size": self.batch_size, |
|
"device": str(self.device), |
|
"use_cache": use_cache, |
|
"limit": limit, |
|
"bootstrap_iters": bootstrap_iters, |
|
} |
|
results["git_hash"] = utils.get_git_commit_hash() |
|
|
|
print(results.keys()) |
|
for task_name in task_dict.keys(): |
|
if "alias" in results["results"][task_name]: |
|
results["results"][task_name].pop("alias") |
|
|
|
if was_training: |
|
self.model.train() |
|
self.model.micro_batches = in_micro_batches |
|
return results |
|
|
|
|
|
def run_eval_harness( |
|
model, |
|
forward_step_fn, |
|
neox_args, |
|
batch_size=None, |
|
eval_tasks=None, |
|
num_fewshot=0, |
|
bootstrap_iters=2, |
|
): |
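    """
    Convenience entry point: wraps `model` in an EvalHarnessAdapter and runs the
    requested LM Evaluation Harness tasks, returning the results dictionary.
    """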
|
print_rank_0("Running evaluation harness...") |
|
adapter = EvalHarnessAdapter( |
|
model, forward_step_fn, neox_args, batch_size=batch_size |
|
) |
|
return adapter.run_eval( |
|
eval_tasks=eval_tasks, |
|
num_fewshot=num_fewshot, |
|
bootstrap_iters=bootstrap_iters, |
|
use_cache=False, |
|
) |
|
|