# TIGERScore/utils.py
from transformers import AutoTokenizer, AutoModelForCausalLM
from string import Template
import torch
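# Prompt templates matching the TIGERScore fine-tuning format. They are kept
# verbatim (including the literal "a(an)" phrasing) so that inference-time
# prompts line up with what the model saw during training.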
FINETUNE_INST = "You are evaluating errors in a model-generated output for a(an) ${task} task."
FINETUNE_INPUT = """\
Task instruction: ${generation_instruction}
Source: ${input_context}
Model-generated Output: ${hypothesis_output}
Based on the given task instruction and source, identify errors in this model-generated output.
For each error you give in the response, please also elaborate the following information:
- error location (the words that are wrong in the output)
- error aspect it belongs to.
- explanation why it's an error, and the correction suggestions.
- severity of the error ("Major" or "Minor").
- reduction of score (between 0.5 and 5 given the severity of the error)
Your evaluation output:
"""
TIGERScore_model_map = {
    "7b": "TIGER-Lab/TIGERScore-7B-V1.0",
    "13b": "TIGER-Lab/TIGERScore-13B-V1.0",
}
tigerscore_model = None
tigerscore_tokenizer = None
tasks = [
    "translation",
    "summarization",
    "data2text",
    "mathQA",
    "long-form QA",
    "instruction-following",
]
def load_tigerscore(model_size):
    """Load the TIGERScore model and tokenizer into the module-level globals."""
    assert model_size in TIGERScore_model_map
    model_name = TIGERScore_model_map[model_size]
    global tigerscore_model, tigerscore_tokenizer
    # device_map="auto" requires the `accelerate` package and places the
    # model weights across the available devices automatically.
    tigerscore_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    tigerscore_tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        use_fast=True,
    )
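# A minimal sketch of a dtype fallback for GPUs without bfloat16 support
# (pre-Ampere cards), loading in float16 instead. `load_tigerscore_auto_dtype`
# is a hypothetical variant, not part of the original file.
def load_tigerscore_auto_dtype(model_size):
    assert model_size in TIGERScore_model_map
    global tigerscore_model, tigerscore_tokenizer
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    tigerscore_model = AutoModelForCausalLM.from_pretrained(
        TIGERScore_model_map[model_size],
        torch_dtype=torch.bfloat16 if use_bf16 else torch.float16,
        device_map="auto",
    )
    tigerscore_tokenizer = AutoTokenizer.from_pretrained(
        TIGERScore_model_map[model_size], use_fast=True
    )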
def generate(task, input_context, generation_instruction, hypo_output, **generate_kwargs):
    """Build the fine-tuning-style prompt and return the decoded error analysis."""
    inst_part = Template(FINETUNE_INST).substitute(task=task)
    input_part = Template(FINETUNE_INPUT).substitute(
        generation_instruction=generation_instruction,
        input_context=input_context,
        hypothesis_output=hypo_output,
    )
    prompt = (inst_part + "\n" + input_part).strip("\n ") + "\n"
    encodings = tigerscore_tokenizer(prompt, return_tensors="pt")
    input_ids = encodings["input_ids"].to(tigerscore_model.device)
    attention_mask = encodings["attention_mask"].to(tigerscore_model.device)
    # top_k=1 with sampling is effectively greedy decoding; callers can
    # override any of these defaults via **generate_kwargs.
    gen_params = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "max_new_tokens": 512,
        "do_sample": True,
        "top_k": 1,
        "num_return_sequences": 1,
    }
    gen_params.update(generate_kwargs)
    output = tigerscore_model.generate(**gen_params)
    # Strip the prompt tokens and decode only the newly generated text.
    output = tigerscore_tokenizer.decode(output[0][len(input_ids[0]):], skip_special_tokens=True)
    return output
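# A minimal sketch of turning the free-text error analysis into a numeric
# penalty by summing the "reduction of score" values the template asks for.
# `parse_penalty` and its regex are assumptions about the output wording,
# not part of the original file; real outputs may phrase the reduction
# differently, in which case the pattern needs adjusting.
import re

def parse_penalty(analysis: str) -> float:
    """Sum every 'reduction of score' style number found in the analysis."""
    # Matches e.g. "Score reduction: 2.0" or "reduction of score: 1".
    matches = re.findall(r"reduction[^:]*:\s*(-?\d+(?:\.\d+)?)", analysis, flags=re.IGNORECASE)
    return sum(float(m) for m in matches)

# Hypothetical usage: parse_penalty("... Severity: Major\nScore reduction: 2.0") -> 2.0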
if __name__ == "__main__":
    # Load the 7B variant before generating ("13b" also works).
    load_tigerscore("7b")
    task = "translation"
    input_context = "Der künftige EM-Cheforganisator Philipp Lahm soll laut Grindel im DFB-Präsidium mitarbeiten."
    generation_instruction = "Translate the following text from German to English."
    hypo_output = "According to Grindel, the future head of the European Championships, Philipp Lahm, is to participate in the DFB Presidency."
    output = generate(task, input_context, generation_instruction, hypo_output)
    print(output)