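# EasyEdit / utils.py
# NOTE: the next three lines are repository-viewer residue captured with the
# source (avatar caption, what looks like a commit message, and a short hash).
# ZJUPeng's picture
# add continuous
# d6682b6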
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
from transformers import GPT2TokenizerFast, GPT2Tokenizer
from easyeditor import apply_grace_to_model, GraceHyperParams,nethook, apply_wise_to_model, WISEHyperParams, ROMEHyperParams, apply_rome_to_model
import torch
import gradio as gr
import json
import numpy as np
import random
# Pin every random number generator the demo touches to one fixed seed:
# PyTorch (default and all CUDA devices), NumPy, and Python's `random`.
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

# Base GPT-2 checkpoint kept at module level. The single-edit functions pass
# this object to the editing routines, and `clear()` rebinds it to a fresh copy.
model = AutoModelForCausalLM.from_pretrained(
    "./models/gpt2",
    device_map='cpu',
)
def clear():
    """Reload the base GPT-2 model and return two empty strings.

    The module-level ``model`` is rebound to a checkpoint freshly loaded from
    ``./models/gpt2`` on the CPU.

    Returns:
        tuple: ``('', '')`` -- presumably used to blank two UI text fields;
        confirm against the caller.
    """
    global model
    reloaded = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')
    model = reloaded
    blank = ''
    return blank, blank
def grace_edit(prompt, target_new, num_steps, edit_lr):
    """Run one GRACE edit against the module-level ``model``.

    Whatever ``apply_grace_to_model`` returns is stored in the module-level
    ``edit_model`` so that ``generate`` can use it afterwards.

    Args:
        prompt: Text placed under the ``"prompt"`` key of the edit request.
        target_new: Text placed under the ``"target_new"`` key of the request.
        num_steps: Forwarded unchanged to ``apply_grace_to_model``.
        edit_lr: Forwarded unchanged to ``apply_grace_to_model``.

    Returns:
        tuple: ``(prompt, target_new)`` exactly as received.
    """
    global edit_model
    grace_config = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2.yaml")
    gpt2_tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    # Reuse the end-of-sequence id as the padding id.
    gpt2_tok.pad_token_id = gpt2_tok.eos_token_id
    edit_request = {"prompt": prompt, "target_new": target_new}
    edit_model = apply_grace_to_model(
        model, gpt2_tok, edit_request, grace_config, num_steps, edit_lr
    )
    return prompt, target_new
def wise_edit(prompt, target_new, num_steps, edit_lr):
    """Run one WISE edit against the module-level ``model``.

    Whatever ``apply_wise_to_model`` returns is stored in the module-level
    ``edit_model`` so that ``generate`` can use it afterwards.

    Args:
        prompt: Text placed under the ``"prompt"`` key of the edit request.
        target_new: Text placed under the ``"target_new"`` key of the request.
        num_steps: Forwarded unchanged to ``apply_wise_to_model``.
        edit_lr: Forwarded unchanged to ``apply_wise_to_model``.

    Returns:
        tuple: ``(prompt, target_new)`` exactly as received.
    """
    global edit_model
    wise_config = WISEHyperParams.from_hparams("./hparams/WISE/gpt2.yaml")
    gpt2_tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    # Reuse the end-of-sequence id as the padding id.
    gpt2_tok.pad_token_id = gpt2_tok.eos_token_id
    edit_request = {"prompt": prompt, "target_new": target_new}
    edit_model = apply_wise_to_model(
        model, gpt2_tok, edit_request, wise_config, num_steps, edit_lr
    )
    return prompt, target_new
def rome_edit(prompt, target_new, num_steps, edit_lr):
    """Run one ROME edit against the module-level ``model``.

    Whatever ``apply_rome_to_model`` returns is stored in the module-level
    ``edit_model`` so that ``generate`` can use it afterwards.

    Args:
        prompt: Text placed under the ``"prompt"`` key of the edit request.
        target_new: Text placed under the ``"target_new"`` key of the request.
        num_steps: Forwarded unchanged to ``apply_rome_to_model``.
        edit_lr: Forwarded unchanged to ``apply_rome_to_model``.

    Returns:
        tuple: ``(prompt, target_new)`` exactly as received.
    """
    global edit_model
    rome_config = ROMEHyperParams.from_hparams("./hparams/ROME/gpt2.yaml")
    gpt2_tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    # Reuse the end-of-sequence id as the padding id.
    gpt2_tok.pad_token_id = gpt2_tok.eos_token_id
    edit_request = {"prompt": prompt, "target_new": target_new}
    edit_model = apply_rome_to_model(
        model, gpt2_tok, edit_request, rome_config, num_steps, edit_lr
    )
    return prompt, target_new
def edit(edit_alg, prompt, target_new, num_steps, edit_lr):
    """Dispatch a single edit to the function for the chosen algorithm.

    Args:
        edit_alg: ``'GRACE'``, ``'WISE'`` or ``'ROME'``.
        prompt: Prompt text of the edit, forwarded unchanged.
        target_new: New target text of the edit, forwarded unchanged.
        num_steps: Forwarded unchanged to the selected edit function.
        edit_lr: Forwarded unchanged to the selected edit function.

    Returns:
        Whatever the selected edit function returns, i.e. the tuple
        ``(prompt, target_new)``.

    Raises:
        NotImplementedError: If ``edit_alg`` is not one of the three
            supported names; the message names the rejected value.
    """
    # The edit functions are looked up only inside the matching branch so an
    # unsupported name is rejected before anything else is touched.
    if edit_alg == 'GRACE':
        run_edit = grace_edit
    elif edit_alg == 'WISE':
        run_edit = wise_edit
    elif edit_alg == 'ROME':
        run_edit = rome_edit
    else:
        raise NotImplementedError(
            "Unsupported editing algorithm {!r}; expected 'GRACE', 'WISE' or 'ROME'".format(edit_alg)
        )
    return run_edit(prompt, target_new, num_steps, edit_lr)
def generate(input_text, target_new=None, edit_alg=None):
    """Compare the original GPT-2 with the latest edited model on one prompt.

    Two modes are used:

    * ``edit_alg == 'GRACE'`` with a ``target_new``: both models call
      ``.generate`` on ``input_text`` for as many new tokens as
      ``' ' + target_new`` encodes to.
    * Any other case: both models run one forward pass over
      ``input_text + ' ' + target_new`` and the reply is built from the
      arg-max token predicted at each target position. If ``target_new`` is
      ``None`` it is taken from the built-in table of locality prompts.

    Args:
        input_text: Prompt text.
        target_new: Expected answer text, or ``None`` to use the table.
        edit_alg: Name of the editing algorithm; only ``'GRACE'`` changes
            the decoding mode.

    Returns:
        tuple: ``(ori_reply, edit_reply)``. Each is a list of
        ``(character, label)`` pairs over the reply string, where the label
        is ``'output'`` for character indices greater than
        ``len(input_text)`` and ``None`` otherwise (presumably consumed by a
        highlighted-text UI component; confirm against the caller).

    Raises:
        RuntimeError: If no edit has been applied yet, so the module-level
            ``edit_model`` does not exist.
        KeyError: If ``target_new`` is ``None`` and ``input_text`` is not one
            of the built-in locality prompts.
    """
    # Fail early with an actionable message instead of a bare NameError when
    # generate() is called before any of the *_edit functions.
    try:
        edited = edit_model
    except NameError:
        raise RuntimeError(
            "No edited model is available yet; apply an edit before generating."
        ) from None

    loc_output = {
        "nq question: where does the phrase good bye felicia come from": "intended as a dismissive kiss-off",
        "nq question: which best describes timbuktu under the mali empire": "a place of trade, entertainment, and education",
        "nq question: where do the question marks go in spanish": "before the first letter of an interrogative sentence",
        "nq question: who replaces the vice president in the senate": "Speaker of the House of Representatives",
        "nq question: active transport performs which function in a cell": "uses cellular energy to move them against a gradient, polar repulsion, or other resistance"
    }
    tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    tok.pad_token_id = tok.eos_token_id

    if edit_alg == 'GRACE' and target_new is not None:
        max_new_tokens = len(tok.encode(' ' + target_new))
        input_ids = tok.encode(input_text, return_tensors='pt').to('cpu')
        edit_output = edited.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            pad_token_id=tok.eos_token_id,
            use_cache=False,
        )
        edit_reply = tok.decode(edit_output[0], skip_special_tokens=False)
        torch.cuda.empty_cache()
        # A fresh checkpoint is loaded for the baseline rather than reusing
        # the module-level model that the edit functions were given.
        ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
        ori_output = ori_model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            pad_token_id=tok.eos_token_id,
        )
        ori_reply = tok.decode(ori_output[0], skip_special_tokens=False)
    else:
        if target_new is None:
            try:
                target_new = loc_output[input_text]
            except KeyError:
                raise KeyError(
                    "No stored locality answer for prompt {!r}; pass target_new explicitly".format(input_text)
                ) from None
        input_ids = tok.encode(input_text + ' ' + target_new, return_tensors='pt').to('cpu')
        prompt_len = len(tok.encode(input_text))
        # Logits at position p predict token p + 1, so the slice
        # [prompt_len - 1 : -1] covers exactly the target positions.
        with torch.no_grad():  # inference only; no autograd graph needed
            edit_logits = edited(input_ids=input_ids).logits
        edit_tokens = torch.argmax(edit_logits, dim=-1)
        edit_reply = input_text + ' ' + tok.decode(edit_tokens[0][prompt_len-1:-1], skip_special_tokens=True)
        torch.cuda.empty_cache()
        ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
        with torch.no_grad():
            ori_logits = ori_model(input_ids=input_ids).logits
        ori_tokens = torch.argmax(ori_logits, dim=-1)
        ori_reply = input_text + ' ' + tok.decode(ori_tokens[0][prompt_len-1:-1], skip_special_tokens=True)
        torch.cuda.empty_cache()

    # Label every character past the prompt (and the character right after
    # it) as model output; the threshold is a character count in both modes.
    highlight_from = len(input_text)
    ori_reply = [(ch, 'output' if idx > highlight_from else None) for idx, ch in enumerate(ori_reply)]
    edit_reply = [(ch, 'output' if idx > highlight_from else None) for idx, ch in enumerate(edit_reply)]
    return ori_reply, edit_reply
def union_generate(input_text, para_input_text, target_new=None, edit_alg=None):
    """Run ``generate`` on two prompts with the same target and algorithm.

    Args:
        input_text: First prompt.
        para_input_text: Second prompt (by its name, presumably a paraphrase
            of the first; confirm against the caller).
        target_new: Forwarded to both ``generate`` calls.
        edit_alg: Forwarded to both ``generate`` calls.

    Returns:
        tuple: The two results for ``input_text`` followed by the two results
        for ``para_input_text``.
    """
    shared_kwargs = {"target_new": target_new, "edit_alg": edit_alg}
    first_ori, first_edit = generate(input_text, **shared_kwargs)
    second_ori, second_edit = generate(para_input_text, **shared_kwargs)
    return first_ori, first_edit, second_ori, second_edit
# (prompt, new target) pairs that are applied to both continuous-editing
# models while this module is being imported.
continuous_examples = [
    ["Who is the architect for Toodyay Fire Station?", "Wong Tung & Sons"],
    ["What company makes Springfield Armory XDM?", "Messerschmitt"],
    ["Which fictional universe is Chlorophyll Kid part of?", "Image Universe"],
    ["What year did Sunnyside Hospital cease to exist?", "1962"],
    ["Which designer was responsible for Holmenkollen Chapel?", "Inigo Jones"],
    ["What piece of fiction does Jack Harkness appear in?", "Lost"]
]

# Shared state for the continuous-editing functions: hyper-parameters per
# algorithm, one tokenizer, and one dedicated GPT-2 copy per algorithm.
# (Plain module-level assignments; a `global` statement at module scope has
# no effect, so none is used here.)
grace_hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2.yaml")
wise_hparams = WISEHyperParams.from_hparams("./hparams/WISE/gpt2.yaml")
tokenizer = GPT2Tokenizer.from_pretrained("./models/gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
grace_continuous_model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')
wise_continuous_model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')

# Pre-apply every example, all GRACE edits first and then all WISE edits.
# The trailing 40 and 1.0 sit in the positional slots that continuous_edit()
# fills with num_steps and edit_lr.
# NOTE(review): the return values are discarded, so this relies on the
# apply_* functions modifying the model they are given in place -- confirm.
for prompt, target_new in continuous_examples:
    request = {"prompt": prompt, "target_new": target_new}
    apply_grace_to_model(grace_continuous_model, tokenizer, request, grace_hparams, 40, 1.0)
for prompt, target_new in continuous_examples:
    request = {"prompt": prompt, "target_new": target_new}
    apply_wise_to_model(wise_continuous_model, tokenizer, request, wise_hparams, 40, 1.0)
def continuous_edit(edit_alg, prompt, target_new, num_steps, edit_lr):
    """Apply one more edit to the continuous-editing model of an algorithm.

    The edit is applied to the module-level ``grace_continuous_model`` or
    ``wise_continuous_model`` using the shared module-level ``tokenizer`` and
    the matching hyper-parameters.

    Args:
        edit_alg: ``'GRACE'`` or ``'WISE'``.
        prompt: Text placed under the ``"prompt"`` key of the edit request.
        target_new: Text placed under the ``"target_new"`` key of the request.
        num_steps: Forwarded unchanged to the apply function.
        edit_lr: Forwarded unchanged to the apply function.

    Returns:
        tuple: ``(prompt, target_new)`` exactly as received.

    Raises:
        NotImplementedError: If ``edit_alg`` is neither ``'GRACE'`` nor
            ``'WISE'``; the message names the rejected value.
    """
    # Reject unsupported algorithms before building anything.
    if edit_alg not in ('GRACE', 'WISE'):
        raise NotImplementedError(
            "Continuous editing supports only 'GRACE' and 'WISE', got {!r}".format(edit_alg)
        )

    request = {"prompt": prompt, "target_new": target_new}
    if edit_alg == 'GRACE':
        apply_grace_to_model(grace_continuous_model, tokenizer, request, grace_hparams, num_steps, edit_lr)
    else:
        apply_wise_to_model(wise_continuous_model, tokenizer, request, wise_hparams, num_steps, edit_lr)
    return prompt, target_new
def continuous_generate(input_text, edit_alg=None, target_new=None):
    """Compare the original GPT-2 with a continuous-editing model on a prompt.

    The edited model is the module-level ``grace_continuous_model`` or
    ``wise_continuous_model``, chosen by ``edit_alg``. Two modes are used:

    * ``edit_alg == 'GRACE'`` with a ``target_new``: both models call
      ``.generate`` on ``input_text`` for as many new tokens as
      ``' ' + target_new`` encodes to.
    * Any other supported case: both models run one forward pass over
      ``input_text + ' ' + target_new`` and the reply is built from the
      arg-max token predicted at each target position. If ``target_new`` is
      ``None`` it is taken from the built-in table of locality prompts.

    Args:
        input_text: Prompt text.
        edit_alg: ``'GRACE'`` or ``'WISE'``.
        target_new: Expected answer text, or ``None`` to use the table.

    Returns:
        tuple: ``(ori_reply, edit_reply)``. Each is a list of
        ``(character, label)`` pairs over the reply string, where the label
        is ``'output'`` for character indices greater than
        ``len(input_text)`` and ``None`` otherwise (presumably consumed by a
        highlighted-text UI component; confirm against the caller).

    Raises:
        NotImplementedError: If ``edit_alg`` is neither ``'GRACE'`` nor
            ``'WISE'`` (including the default ``None``).
        KeyError: If ``target_new`` is ``None`` and ``input_text`` is not one
            of the built-in locality prompts.
    """
    if edit_alg == 'GRACE':
        cur_model = grace_continuous_model
    elif edit_alg == 'WISE':
        cur_model = wise_continuous_model
    else:
        raise NotImplementedError(
            "Continuous generation supports only 'GRACE' and 'WISE', got {!r}".format(edit_alg)
        )

    loc_output = {
        "nq question: where does the phrase good bye felicia come from": "intended as a dismissive kiss-off",
        "nq question: which best describes timbuktu under the mali empire": "a place of trade, entertainment, and education",
        "nq question: where do the question marks go in spanish": "before the first letter of an interrogative sentence",
        "nq question: who replaces the vice president in the senate": "Speaker of the House of Representatives",
        "nq question: active transport performs which function in a cell": "uses cellular energy to move them against a gradient, polar repulsion, or other resistance"
    }
    tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    tok.pad_token_id = tok.eos_token_id

    if edit_alg == 'GRACE' and target_new is not None:
        max_new_tokens = len(tok.encode(' ' + target_new))
        input_ids = tok.encode(input_text, return_tensors='pt').to('cpu')
        edit_output = cur_model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            pad_token_id=tok.eos_token_id,
            use_cache=False,
        )
        edit_reply = tok.decode(edit_output[0], skip_special_tokens=False)
        torch.cuda.empty_cache()
        # The baseline always comes from a freshly loaded checkpoint.
        ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
        ori_output = ori_model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            pad_token_id=tok.eos_token_id,
        )
        ori_reply = tok.decode(ori_output[0], skip_special_tokens=False)
    else:
        if target_new is None:
            try:
                target_new = loc_output[input_text]
            except KeyError:
                raise KeyError(
                    "No stored locality answer for prompt {!r}; pass target_new explicitly".format(input_text)
                ) from None
        input_ids = tok.encode(input_text + ' ' + target_new, return_tensors='pt').to('cpu')
        prompt_len = len(tok.encode(input_text))
        # Logits at position p predict token p + 1, so the slice
        # [prompt_len - 1 : -1] covers exactly the target positions.
        with torch.no_grad():  # inference only; no autograd graph needed
            edit_logits = cur_model(input_ids=input_ids).logits
        edit_tokens = torch.argmax(edit_logits, dim=-1)
        edit_reply = input_text + ' ' + tok.decode(edit_tokens[0][prompt_len-1:-1], skip_special_tokens=True)
        torch.cuda.empty_cache()
        ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
        with torch.no_grad():
            ori_logits = ori_model(input_ids=input_ids).logits
        ori_tokens = torch.argmax(ori_logits, dim=-1)
        ori_reply = input_text + ' ' + tok.decode(ori_tokens[0][prompt_len-1:-1], skip_special_tokens=True)
        torch.cuda.empty_cache()

    # Label every character past the prompt (and the character right after
    # it) as model output; the threshold is a character count in both modes.
    highlight_from = len(input_text)
    ori_reply = [(ch, 'output' if idx > highlight_from else None) for idx, ch in enumerate(ori_reply)]
    edit_reply = [(ch, 'output' if idx > highlight_from else None) for idx, ch in enumerate(edit_reply)]
    return ori_reply, edit_reply
def continuous_union_generate(input_text, para_input_text, target_new=None, edit_alg=None):
    """Run ``continuous_generate`` on two prompts with the same settings.

    Args:
        input_text: First prompt.
        para_input_text: Second prompt (by its name, presumably a paraphrase
            of the first; confirm against the caller).
        target_new: Forwarded by keyword to both calls.
        edit_alg: Forwarded by keyword to both calls.

    Returns:
        tuple: The two results for ``input_text`` followed by the two results
        for ``para_input_text``.
    """
    shared_kwargs = {"target_new": target_new, "edit_alg": edit_alg}
    first_ori, first_edit = continuous_generate(input_text, **shared_kwargs)
    second_ori, second_edit = continuous_generate(para_input_text, **shared_kwargs)
    return first_ori, first_edit, second_ori, second_edit