|
import transformers |
|
import torch |
|
import os |
|
import struct |
|
import random |
|
|
|
# Module-level memo for get_context_templates(): stays None until the first
# call has built the template list, after which that list is reused.
CONTEXT_TEMPLATES_CACHE: "list[str] | None" = None
|
|
|
def find_sublist_start_index(list1, list2):
    """Return the first index where ``list2`` appears contiguously in ``list1``.

    Elements are compared pairwise with ``==``. Returns ``None`` when no
    position matches (which includes ``list2`` being longer than ``list1``).
    An empty ``list2`` matches at index 0.
    """
    width = len(list2)
    candidate_starts = range(len(list1) - width + 1)
    return next(
        (
            start
            for start in candidate_starts
            if all(x == y for x, y in zip(list1[start:start + width], list2))
        ),
        None,
    )
|
|
|
def get_inner_params(named_parameters, inner_names):
    """Pair each name in ``inner_names`` with its entry from ``named_parameters``.

    ``named_parameters`` is anything ``dict()`` accepts (e.g. an iterable of
    ``(name, value)`` pairs). The result is a list of ``(name, value)`` tuples
    in the order of ``inner_names``; a name that is not present raises
    ``KeyError``.
    """
    lookup = dict(named_parameters)
    selected = []
    for name in inner_names:
        selected.append((name, lookup[name]))
    return selected
|
|
|
def param_subset(named_parameters, inner_names):
    """Return the values from ``named_parameters`` selected by ``inner_names``.

    ``named_parameters`` is anything ``dict()`` accepts. Values come back as a
    list ordered like ``inner_names``; an unknown name raises ``KeyError``.
    """
    lookup = dict(named_parameters)
    return list(map(lookup.__getitem__, inner_names))
|
|
|
def print_trainable_parameters(model, new_weight, mask_ratio):
    """Print element counts for ``model`` and ``new_weight`` plus a trainable percentage.

    Both arguments must expose ``named_parameters()`` yielding ``(name, param)``
    pairs whose params have ``numel()``. The percentage printed is
    ``100 * new_weight_count * (1 - mask_ratio) / model_count``.
    Nothing is returned.
    """
    new_weight_param = sum(p.numel() for _, p in new_weight.named_parameters())
    original_parameters = sum(p.numel() for _, p in model.named_parameters())
    trainable_pct = 100 * new_weight_param * (1-mask_ratio) / original_parameters
    print(f"Original Model params: {original_parameters} || New Weight params: {new_weight_param} || trainable%: {trainable_pct}")
|
|
|
|
|
def parent_module(model, pname):
    """Walk a dotted path from ``model`` and return the object that owns its last component.

    Every component except the last is resolved first as an attribute; if no
    such attribute exists and the component is all digits, it is used as an
    integer index instead. The last component is not followed, only checked
    with ``hasattr`` on the object reached.

    Raises:
        RuntimeError: if an intermediate component can be resolved neither as
            an attribute nor as an index, or the final component is not an
            attribute of the object reached.
    """
    *ancestors, leaf = pname.split('.')
    node = model

    for part in ancestors:
        if hasattr(node, part):
            node = getattr(node, part)
            continue
        if not part.isdigit():
            raise RuntimeError(f"Couldn't find child module {part}")
        node = node[int(part)]

    if hasattr(node, leaf):
        return node
    raise RuntimeError(f"Couldn't find child module {leaf}")
|
|
|
def uuid(digits=4):
    """Return a random integer below ``10**digits``, fixed after the first call.

    The value is drawn from ``os.urandom`` on the first call and stored on the
    function object itself; every later call returns that stored value, so
    ``digits`` only has an effect on the first call.
    """
    try:
        return uuid.uuid_value
    except AttributeError:
        # First call: draw 4 random bytes, read them as an unsigned int and
        # reduce to the requested number of decimal digits.
        uuid.uuid_value = struct.unpack('I', os.urandom(4))[0] % int(10**digits)
        return uuid.uuid_value
|
|
|
def ckpt_dir():
    """Return the directory in which to store model checkpoints, creating it if needed.

    The path is the relative ``"./ckpts/"``, so it resolves against the
    current working directory at call time.

    ``exist_ok=True`` is used instead of an exists-then-create check so that
    several processes calling this at the same moment cannot fail with
    ``FileExistsError`` between the check and the creation.

    Raises:
        FileExistsError: if ``./ckpts`` exists but is not a directory.
    """
    path = "./ckpts/"
    os.makedirs(path, exist_ok=True)
    return path
|
|
|
def brackets_to_periods(name):
    """Rewrite bracket indexing in ``name`` as dotted form, e.g. ``"h[5].w"`` -> ``"h.5.w"``.

    Every ``[`` becomes ``.`` and every ``]`` is dropped.
    """
    # maketrans(x, y, z): map "[" -> ".", and delete each character in "]".
    bracket_table = str.maketrans("[", ".", "]")
    return name.translate(bracket_table)
|
|
|
def get_params(model):
    """Return whatever ``model.state_dict()`` returns."""
    state = model.state_dict()
    return state
|
|
|
def get_shape(p, model):
    """Return the shape of ``p``, with the first two dimensions swapped for non-GPT-2 models.

    For a ``transformers.GPT2LMHeadModel`` instance ``p.shape`` is returned
    unchanged; for any other model the result is the tuple
    ``(p.shape[1], p.shape[0])``.
    """
    # NOTE(review): presumably GPT-2 stores these weights transposed relative
    # to the other supported models -- confirm against the callers.
    if isinstance(model, transformers.GPT2LMHeadModel):
        return p.shape
    rows, cols = p.shape[0], p.shape[1]
    return (cols, rows)
|
|
|
def get_logits(x):
    """Return ``x.logits`` when ``x`` has that attribute, otherwise ``x`` itself."""
    return getattr(x, "logits", x)
|
|
|
|
|
# Fixed pool of "nq question: <question> <answer>" strings. tokenize() samples
# one of them at random as the batch's 'loc_prompt' and appends it as the last
# row of the tokenized batch.
LOC_PROMPTS = [
    "nq question: who played mr grainger in are you being served Arthur Brough",
    "nq question: who sings the song let's hear it for the boy Deniece Williams",
    "nq question: who wrote all my ex's live in texas Sanger D. Shafer",
    "nq question: when is the america's got talent finale 2018 September 19, 2018",
    "nq question: what is the fifth biggest state in the united states New Mexico",
    "nq question: who plays john black on days of our lives Drake Hogestyn (/ˈhʌdʒstən/; born Donald Drake Hogestyn",
    "nq question: what is the name of the new star wars movie The Last Jedi",
    "nq question: what is the main principle of path-goal theory a leader's behavior is contingent to the satisfaction, motivation and performance of his or her subordinates",
    "nq question: who plays luna's dad in harry potter Ifans",
    "nq question: who has the most grammy nominations as an artist Quincy Jones",
    "nq question: what is the control unit function in the cpu tells the computer's memory, arithmetic/logic unit and input and output devices how to respond to the instructions that have been sent to the processor",
    "nq question: who was the first indian prime minister to visit palestine Narendra Modi",
    "nq question: where did the plane carrying the marshall football team crash into a hill just short of the Tri-State Airport",
    "nq question: what movie is the line lighten up francis from Stripes",
    "nq question: set of rules for solving a mathematical or computational problem in finite number of steps an algorithm",
    "nq question: who changed indian capital from calcutta to delhi George V",
    "nq question: who did bette midler play in the rose Mary Rose Foster (The Rose)",
    "nq question: how much did it cost to make the new star wars movie $200–217 million",
]
|
|
|
def tokenize(batch, tokenizer, device, context_templates=None, hparams=None):
    """Tokenize an edit batch plus one random locality prompt and build training labels.

    Args:
        batch: mapping with ``"prompt"`` and ``"target_new"`` (each a single
            item or a list). It is mutated: a prompt sampled from
            ``LOC_PROMPTS`` is stored under ``batch['loc_prompt']``.
        tokenizer: callable tokenizer returning a mapping with ``"input_ids"``
            when called with ``return_tensors="pt"``; must also provide
            ``encode`` and ``pad_token_id``.
        device: target passed to ``.to(device)`` for every returned tensor.
        context_templates: ``str.format`` templates with one ``{}`` slot.
            ``None`` is treated as the single identity template ``['{}']``.
        hparams: object whose ``objective_optimization`` attribute is compared
            with ``'only_label'``. ``None`` means no prompt-prefix masking.

    Returns:
        ``(tokens, act_mask, deact_mask)``. ``tokens`` is a plain dict of the
        tokenizer outputs plus ``"labels"``, all moved to ``device``; its last
        row is the locality prompt. The two masks cover every row except the
        last and are ``None`` unless ``batch['loc_prompt'] in batch['prompt']``.

    Raises:
        ValueError: if masks are requested but the locality prompt's token ids
            cannot be located in a tokenized row.
    """
    prompt, label = batch["prompt"], batch["target_new"]
    batch['loc_prompt'] = random.choice(LOC_PROMPTS)
    if not isinstance(prompt, list):
        prompt = [prompt]
    if not isinstance(label, list):
        label = [label]
    if context_templates is None:
        # The declared default used to crash on iteration; fall back to the
        # identity template so the prompt is used as-is.
        context_templates = ['{}']
    mask_token = -100

    # One row per (prompt, template) pair, then the locality prompt as the last row.
    full_prompt = [templ.format(p + ' ' + l) for p, l in zip(prompt, label) for templ in context_templates]
    full_prompt += [batch['loc_prompt']]

    prompt_only = [templ.format(p) for p in prompt for templ in context_templates]
    prompt_ids = tokenizer(prompt_only, return_tensors="pt", padding=True, truncation=True)["input_ids"]

    # prompt_ids is a padded tensor, so len(row) is the padded width and is the
    # same for every row.
    # NOTE(review): confirm this is the intended prefix length when prompts
    # tokenize to different lengths.
    num_prompt_toks = [len(i) for i in prompt_ids]
    tokens = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True)
    tokens["labels"] = tokens["input_ids"].clone()
    if hparams is not None and hparams.objective_optimization == 'only_label':
        # Exclude the prompt prefix from the labels; the locality row (last)
        # is not touched because num_prompt_toks has no entry for it.
        for i, n_prompt in enumerate(num_prompt_toks):
            tokens["labels"][i][:n_prompt] = mask_token

    tokens["labels"][tokens["input_ids"] == tokenizer.pad_token_id] = mask_token

    if batch['loc_prompt'] in batch['prompt']:
        # Try the space-prefixed encoding first, then the bare one.
        spaced_ids = tokenizer.encode(' ' + batch['loc_prompt'], add_special_tokens=False)
        plain_ids = tokenizer.encode(batch['loc_prompt'], add_special_tokens=False)
        edit_rows = tokens['input_ids'][:-1]
        act_mask = torch.zeros_like(edit_rows)
        deact_mask = torch.zeros_like(edit_rows)
        for i, row in enumerate(edit_rows):
            row_ids = row.tolist()
            # The span length is chosen per row so a fallback in one row does
            # not change the length used for later rows.
            start_idx, span_len = find_sublist_start_index(row_ids, spaced_ids), len(spaced_ids)
            if start_idx is None:
                start_idx, span_len = find_sublist_start_index(row_ids, plain_ids), len(plain_ids)
            if start_idx is None:
                raise ValueError(
                    f"Could not locate the tokens of loc_prompt in row {i} of the tokenized batch"
                )
            end_idx = start_idx + span_len
            act_mask[i][start_idx:end_idx] = 1
            deact_mask[i][:start_idx] = 1
            deact_mask[i][end_idx:] = 1

        act_mask = act_mask.to(device)
        deact_mask = deact_mask.to(device)
    else:
        act_mask = None
        deact_mask = None

    tokens = {f"{name}": tensor.to(device) for name, tensor in tokens.items()}
    return tokens, act_mask, deact_mask
|
|
|
class EarlyStopMeter:
    """Tracks the latest value, the one before it and a running mean, and signals when to stop."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics; ``val`` starts at 1e9 so the first update shows a large change."""
        self.sum = 0
        self.count = 0
        self.avg = 0
        self.pre = 0
        self.val = 1e9

    def update(self, val):
        """Record ``val`` as the current value, shifting the old one into ``pre``."""
        self.pre, self.val = self.val, val
        self.count += 1
        self.sum += val
        self.avg = self.sum / self.count

    def stop(self, ):
        """Report whether the last two values differ by at most 1e-4 and the latest is at most 0.02."""
        barely_moved = abs(self.val - self.pre) <= 1e-4
        if not barely_moved:
            return barely_moved
        return self.val <= 0.02
|
|
|
class EditingMeanAct:
    """Keeps a running mean and a running minimum of the values passed to ``update``."""

    def __init__(self, min_a=1e9):
        self.reset(min_a=min_a)

    def reset(self, min_a=1e9):
        """Zero the mean statistics and restart the minimum from ``min_a``."""
        self.min_a = min_a
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val):
        """Fold ``val`` into the running mean and lower the minimum if ``val`` is smaller."""
        self.count += 1
        self.sum += val
        self.avg = self.sum / self.count
        if val < self.min_a:
            self.min_a = val

    def mean_act(self):
        """Return the running mean (0 before any update)."""
        return self.avg

    def min_act(self):
        """Return the smallest value seen so far, or the initial ``min_a``."""
        return self.min_a
|
|
|
def get_context_templates(model, tok, length_params, device):
    """Build the list of context templates once and return the cached list afterwards.

    On the first successful call, ``model.generate`` is run from five fixed
    seed prompts for every ``(length, n_gen)`` pair in ``length_params``. Each
    decoded generation becomes a template ``"<text> {}"``, and the bare
    ``'{}'`` template is placed first. Later calls return the cached list and
    ignore all arguments.

    The cache is assigned only after every generation has succeeded. If
    ``generate`` or ``batch_decode`` raises, the cache stays ``None`` and the
    next call retries, instead of returning a partial list that lacks the
    ``'{}'`` entry.

    Args:
        model: object exposing ``generate(...)``.
        tok: tokenizer; called on the seed prompts, and provides
            ``batch_decode`` and ``eos_token_id``.
        length_params: iterable of ``(max_new_tokens, n_gen)`` pairs.
        device: target for the tokenized seed prompts.

    Returns:
        The module-level list of template strings.
    """
    global CONTEXT_TEMPLATES_CACHE

    if CONTEXT_TEMPLATES_CACHE is not None:
        return CONTEXT_TEMPLATES_CACHE

    seed_prompts = ["I", "You", "Because", 'Yes', 'Q: ']
    prompt_tok = tok(
        seed_prompts,
        padding=True,
        return_tensors="pt"
    ).to(device)

    generated = []
    for length, n_gen in length_params:
        # n_gen is split evenly across the seed prompts (five of them).
        per_prompt = n_gen // len(seed_prompts)
        gen_token = model.generate(
            input_ids=prompt_tok['input_ids'],
            attention_mask=prompt_tok['attention_mask'],
            max_new_tokens=length,
            num_beams=per_prompt,
            num_return_sequences=per_prompt,
            pad_token_id=tok.eos_token_id,
        )
        generated += tok.batch_decode(gen_token, skip_special_tokens=True)

    # NOTE(review): generated text is used verbatim inside str.format
    # templates; text containing '{' or '}' would break formatting -- confirm
    # whether escaping is wanted.
    CONTEXT_TEMPLATES_CACHE = ['{}'] + [text + ' {}' for text in generated]
    return CONTEXT_TEMPLATES_CACHE
|
|
|
|