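# Gradio demo: apply GRACE / WISE / ROME knowledge edits from EasyEdit to a local
# GPT-2 checkpoint and compare the pre-edit and post-edit model outputs.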
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
from transformers import GPT2TokenizerFast, GPT2Tokenizer
from easyeditor import (
    apply_grace_to_model, GraceHyperParams,
    apply_wise_to_model, WISEHyperParams,
    apply_rome_to_model, ROMEHyperParams,
    nethook,
)
import torch
import gradio as gr
import json
import numpy as np
import random
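# Fix all RNG seeds so edits and generations are reproducible.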
seed=0
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed_all(seed)
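# Base GPT-2 checkpoint kept on CPU; the single-edit functions below edit it and
# store the edited model in the global edit_model.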
model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')
def clear():
    global model
    model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')
    return '', ''
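# Single-edit helpers: each loads the algorithm's hyperparameters, builds the edit
# request, and applies one edit to the base model.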
def grace_edit(prompt, target_new, num_steps, edit_lr):
    request = {"prompt": prompt, "target_new": target_new}
    hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2.yaml")
    tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    tok.pad_token_id = tok.eos_token_id
    global edit_model
    edit_model = apply_grace_to_model(model, tok, request, hparams, num_steps, edit_lr)
    return prompt, target_new
def wise_edit(prompt, target_new, num_steps, edit_lr):
    request = {"prompt": prompt, "target_new": target_new}
    hparams = WISEHyperParams.from_hparams("./hparams/WISE/gpt2.yaml")
    tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    tok.pad_token_id = tok.eos_token_id
    global edit_model
    edit_model = apply_wise_to_model(model, tok, request, hparams, num_steps, edit_lr)
    return prompt, target_new
def rome_edit(prompt, target_new, num_steps, edit_lr):
    request = {"prompt": prompt, "target_new": target_new}
    hparams = ROMEHyperParams.from_hparams("./hparams/ROME/gpt2.yaml")
    tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    tok.pad_token_id = tok.eos_token_id
    global edit_model
    edit_model = apply_rome_to_model(model, tok, request, hparams, num_steps, edit_lr)
    return prompt, target_new
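# Dispatch a single edit request to the selected editing algorithm.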
def edit(edit_alg, prompt, target_new, num_steps, edit_lr):
    if edit_alg == 'GRACE':
        return grace_edit(prompt, target_new, num_steps, edit_lr)
    elif edit_alg == 'WISE':
        return wise_edit(prompt, target_new, num_steps, edit_lr)
    elif edit_alg == 'ROME':
        return rome_edit(prompt, target_new, num_steps, edit_lr)
    else:
        raise NotImplementedError
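# generate() contrasts the edited model with a freshly loaded, unedited GPT-2 and
# returns two lists of (character, label) tuples; characters after the prompt are
# labelled 'output' so the Gradio UI can highlight the generated span.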
def generate(input_text, target_new=None, edit_alg=None):
    # Reference answers for the locality (NQ) probes shown in the demo.
    loc_output = {
        "nq question: where does the phrase good bye felicia come from": "intended as a dismissive kiss-off",
        "nq question: which best describes timbuktu under the mali empire": "a place of trade, entertainment, and education",
        "nq question: where do the question marks go in spanish": "before the first letter of an interrogative sentence",
        "nq question: who replaces the vice president in the senate": "Speaker of the House of Representatives",
        "nq question: active transport performs which function in a cell": "uses cellular energy to move them against a gradient, polar repulsion, or other resistance"
    }
    tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    tok.pad_token_id = tok.eos_token_id
    global edit_model
    if edit_alg == 'GRACE' and target_new is not None:
        # GRACE: sample from the edited model and compare against the unedited base model.
        max_new_tokens = len(tok.encode(' ' + target_new))
        prompt_len = len(input_text)
        input_ids = tok.encode(input_text, return_tensors='pt').to('cpu')
        edit_output = edit_model.generate(input_ids, max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id, use_cache=False)
        edit_reply = tok.decode(edit_output[0], skip_special_tokens=False)
        torch.cuda.empty_cache()
        ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
        ori_output = ori_model.generate(input_ids, max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id)
        ori_reply = tok.decode(ori_output[0], skip_special_tokens=False)
        # Label the characters after the prompt as 'output' for highlighting in the UI.
        ori_reply = [(c, 'output') if i > prompt_len else (c, None) for i, c in enumerate(ori_reply)]
        edit_reply = [(c, 'output') if i > prompt_len else (c, None) for i, c in enumerate(edit_reply)]
        return ori_reply, edit_reply
    else:
        # Teacher-forced decoding: feed prompt + target and read off the argmax predictions.
        if target_new is None:
            target_new = loc_output[input_text]
        max_new_tokens = len(tok.encode(target_new))
        input_ids = tok.encode(input_text + ' ' + target_new, return_tensors='pt').to('cpu')
        prompt_len = len(tok.encode(input_text))
        edit_output = edit_model(input_ids=input_ids).logits
        edit_output = torch.argmax(edit_output, dim=-1)
        edit_reply = input_text + ' ' + tok.decode(edit_output[0][prompt_len-1:-1], skip_special_tokens=True)
        torch.cuda.empty_cache()
        ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
        ori_output = ori_model(input_ids=input_ids).logits
        ori_output = torch.argmax(ori_output, dim=-1)
        ori_reply = input_text + ' ' + tok.decode(ori_output[0][prompt_len-1:-1], skip_special_tokens=True)
        torch.cuda.empty_cache()
        ori_reply = [(c, 'output') if i > len(input_text) else (c, None) for i, c in enumerate(ori_reply)]
        edit_reply = [(c, 'output') if i > len(input_text) else (c, None) for i, c in enumerate(edit_reply)]
        return ori_reply, edit_reply
def union_generate(input_text, para_input_text, target_new=None, edit_alg=None):
    res1, res2 = generate(input_text, target_new=target_new, edit_alg=edit_alg)
    res3, res4 = generate(para_input_text, target_new=target_new, edit_alg=edit_alg)
    return res1, res2, res3, res4
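# Continuous (sequential) editing demo: the example edits below are pre-applied to
# dedicated GRACE and WISE model copies at startup; continuous_edit() then stacks
# further user edits on top of them.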
continuous_examples = [
    ["Who is the architect for Toodyay Fire Station?", "Wong Tung & Sons"],
    ["What company makes Springfield Armory XDM?", "Messerschmitt"],
    ["Which fictional universe is Chlorophyll Kid part of?", "Image Universe"],
    ["What year did Sunnyside Hospital cease to exist?", "1962"],
    ["Which designer was responsible for Holmenkollen Chapel?", "Inigo Jones"],
    ["What piece of fiction does Jack Harkness appear in?", "Lost"]
]
grace_hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2.yaml")
wise_hparams = WISEHyperParams.from_hparams("./hparams/WISE/gpt2.yaml")
tokenizer = GPT2Tokenizer.from_pretrained("./models/gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
grace_continuous_model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')
wise_continuous_model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')
for prompt, target_new in continuous_examples:
    request = {"prompt": prompt, "target_new": target_new}
    apply_grace_to_model(grace_continuous_model, tokenizer, request, grace_hparams, 40, 1.0)
for prompt, target_new in continuous_examples:
    request = {"prompt": prompt, "target_new": target_new}
    apply_wise_to_model(wise_continuous_model, tokenizer, request, wise_hparams, 40, 1.0)
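# Apply one more edit on top of the already-edited continuous models (GRACE or WISE only).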
def continuous_edit(edit_alg, prompt, target_new, num_steps, edit_lr):
    global tokenizer
    if edit_alg == 'GRACE':
        request = {"prompt": prompt, "target_new": target_new}
        global grace_hparams
        global grace_continuous_model
        apply_grace_to_model(grace_continuous_model, tokenizer, request, grace_hparams, num_steps, edit_lr)
        return prompt, target_new
    elif edit_alg == 'WISE':
        request = {"prompt": prompt, "target_new": target_new}
        global wise_hparams
        global wise_continuous_model
        apply_wise_to_model(wise_continuous_model, tokenizer, request, wise_hparams, num_steps, edit_lr)
    else:
        raise NotImplementedError
    return prompt, target_new
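# Same comparison as generate(), but against the continuously edited GRACE/WISE model.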
def continuous_generate(input_text, edit_alg=None, target_new=None):
    if edit_alg == 'GRACE':
        global grace_continuous_model
        cur_model = grace_continuous_model
    elif edit_alg == 'WISE':
        global wise_continuous_model
        cur_model = wise_continuous_model
    else:
        raise NotImplementedError
    # Reference answers for the locality (NQ) probes shown in the demo.
    loc_output = {
        "nq question: where does the phrase good bye felicia come from": "intended as a dismissive kiss-off",
        "nq question: which best describes timbuktu under the mali empire": "a place of trade, entertainment, and education",
        "nq question: where do the question marks go in spanish": "before the first letter of an interrogative sentence",
        "nq question: who replaces the vice president in the senate": "Speaker of the House of Representatives",
        "nq question: active transport performs which function in a cell": "uses cellular energy to move them against a gradient, polar repulsion, or other resistance"
    }
    tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    tok.pad_token_id = tok.eos_token_id
    if edit_alg == 'GRACE' and target_new is not None:
        # GRACE: sample from the continuously edited model and compare with the unedited base model.
        max_new_tokens = len(tok.encode(' ' + target_new))
        prompt_len = len(input_text)
        input_ids = tok.encode(input_text, return_tensors='pt').to('cpu')
        edit_output = cur_model.generate(input_ids, max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id, use_cache=False)
        edit_reply = tok.decode(edit_output[0], skip_special_tokens=False)
        torch.cuda.empty_cache()
        ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
        ori_output = ori_model.generate(input_ids, max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id)
        ori_reply = tok.decode(ori_output[0], skip_special_tokens=False)
        # Label the characters after the prompt as 'output' for highlighting in the UI.
        ori_reply = [(c, 'output') if i > prompt_len else (c, None) for i, c in enumerate(ori_reply)]
        edit_reply = [(c, 'output') if i > prompt_len else (c, None) for i, c in enumerate(edit_reply)]
        return ori_reply, edit_reply
    else:
        # Teacher-forced decoding: feed prompt + target and read off the argmax predictions.
        if target_new is None:
            target_new = loc_output[input_text]
        max_new_tokens = len(tok.encode(target_new))
        input_ids = tok.encode(input_text + ' ' + target_new, return_tensors='pt').to('cpu')
        prompt_len = len(tok.encode(input_text))
        edit_output = cur_model(input_ids=input_ids).logits
        edit_output = torch.argmax(edit_output, dim=-1)
        edit_reply = input_text + ' ' + tok.decode(edit_output[0][prompt_len-1:-1], skip_special_tokens=True)
        torch.cuda.empty_cache()
        ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
        ori_output = ori_model(input_ids=input_ids).logits
        ori_output = torch.argmax(ori_output, dim=-1)
        ori_reply = input_text + ' ' + tok.decode(ori_output[0][prompt_len-1:-1], skip_special_tokens=True)
        torch.cuda.empty_cache()
        ori_reply = [(c, 'output') if i > len(input_text) else (c, None) for i, c in enumerate(ori_reply)]
        edit_reply = [(c, 'output') if i > len(input_text) else (c, None) for i, c in enumerate(edit_reply)]
        return ori_reply, edit_reply
def continuous_union_generate(input_text, para_input_text, target_new=None, edit_alg=None):
    res1, res2 = continuous_generate(input_text, target_new=target_new, edit_alg=edit_alg)
    res3, res4 = continuous_generate(para_input_text, target_new=target_new, edit_alg=edit_alg)
    return res1, res2, res3, res4