|
import json
import os
import random
import re
import time
from functools import lru_cache
import torch

import numpy as np
import openai
try:
    import transformers
except ImportError:
    import sys
    from ditk import logging
    logging.warning("transformers not found, please install it using: pip install transformers")
    sys.exit(1)
|
|
|
|
|
def sample_logits(out: torch.Tensor, temperature: float = 1.0, top_p: float = 0.8) -> int:
    # Top-p (nucleus) sampling over a logits vector: keep the smallest set of tokens whose
    # cumulative probability exceeds ``top_p``, optionally rescale with ``temperature``,
    # renormalize, and draw one token id.
    probs = torch.softmax(out, dim=-1).cpu().numpy()
    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)
    cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
    probs[probs < cutoff] = 0
    if temperature != 1.0:
        # ``probs`` is a numpy array at this point, so use ``np.power`` rather than ``Tensor.pow``.
        probs = np.power(probs, 1.0 / temperature)
    probs = probs / np.sum(probs)
    token_id = np.random.choice(a=len(probs), p=probs)
    return int(token_id)
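# Illustrative usage of ``sample_logits`` (the variable names below are placeholders, not part of this module):
#   logits = model(**inputs).logits                      # shape (batch, seq_len, vocab_size)
#   next_id = sample_logits(logits[0, -1], temperature=0.8, top_p=0.8)
#   next_token = tokenizer.decode([next_id])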
|
|
|
|
|
def calc_rwkv(
        model: transformers.RwkvForCausalLM,
        tokenizer: transformers.AutoTokenizer,
        prompt: str,
        max_len: int = 10
) -> str:
    # Autoregressively generate up to ``max_len`` tokens with an RWKV model and return only
    # the newly generated text (the original prompt is stripped from the result).
    orig_len = len(prompt)
    inputs = tokenizer(prompt, return_tensors="pt").to('cuda')
    outputs = model(**inputs, labels=inputs["input_ids"])
    out, state = outputs.logits, outputs.state

    with torch.no_grad():
        for i in range(max_len):
            token = sample_logits(out[0, -1])
            tmp = tokenizer.decode([token])
            prompt = prompt + tmp
            # Re-encode the extended prompt and run the model again to get the next logits.
            inputs = tokenizer(prompt, return_tensors="pt").to('cuda')
            outputs = model(**inputs, labels=inputs["input_ids"])
            out, state = outputs.logits, outputs.state
    return prompt[orig_len:]
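# Sketch of how ``calc_rwkv`` might be called; the model id is an illustrative assumption:
#   model = transformers.RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile").cuda()
#   tokenizer = transformers.AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
#   continuation = calc_rwkv(model, tokenizer, "Question: 1 + 1 = ", max_len=10)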
|
|
|
|
|
def calc_internlm(model, tokenizer, prompt: str, args):
    # Generate a completion with an InternLM-style HuggingFace model using sampling.
    inputs = tokenizer(prompt, return_tensors="pt")
    for k, v in inputs.items():
        inputs[k] = v.cuda()
    gen_kwargs = {
        "max_length": args.max_tokens,
        "top_p": args.top_p,
        "temperature": args.temperature,
        "do_sample": True,
        "repetition_penalty": args.frequency_penalty
    }
    output = model.generate(**inputs, **gen_kwargs)
    # ``generate`` returns a batch of token-id sequences; decode the first (and only) one.
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    return output
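# ``calc_internlm`` only relies on a few sampling hyper-parameters of ``args``, e.g. (illustrative values):
#   Namespace(max_tokens=512, top_p=0.8, temperature=1.0, frequency_penalty=1.0)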
|
|
|
|
|
def load_data(args: dict) -> tuple:
    # Download (if needed) and load the TabMWP training problems, then split a random sample
    # of problem ids into prompt-candidate ids and training ids.
    random.seed(args.seed)
    data_root = 'dizoo/tabmwp/data'

    if not os.path.exists(data_root):
        os.makedirs(data_root)

    train_file = os.path.join(data_root, 'problems_train.json')
    if not os.path.exists(train_file):
        os.system(
            'wget https://opendilab.net/download/DI-zoo/tabmwp/problems_train.json -O ' + train_file +
            ' --no-check-certificate'
        )
    with open(train_file) as f:
        problems = json.load(f)

    pids = list(problems.keys())
    samples = random.sample(pids, args.train_number + args.cand_number)
    train_pids = samples[:args.train_number]
    cand_pids = samples[args.train_number:]
    return problems, cand_pids, train_pids
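# Illustrative call (argparse.Namespace is used here as a stand-in for the real config object;
# the numbers are placeholders):
#   from argparse import Namespace
#   problems, cand_pids, train_pids = load_data(Namespace(seed=0, train_number=80, cand_number=20))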
|
|
|
|
|
def get_gpt3_output(prompt: str, args: dict) -> str:
    return call_gpt3(
        args.engine, prompt, args.temperature, args.max_tokens, args.top_p, args.frequency_penalty,
        args.presence_penalty
    )
|
|
|
|
|
@lru_cache(maxsize=10000)
def call_gpt3(
        engine: str, prompt: str, temperature: float, max_tokens: int, top_p: float, frequency_penalty: float,
        presence_penalty: float
) -> str:
    # Query the OpenAI completion API, retrying on transient errors. Results are memoized by
    # ``lru_cache``, so identical calls are only sent once.
    patience = 100
    output = ""
    while True:
        try:
            response = openai.Completion.create(
                engine=engine,
                prompt=prompt,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                stop=["\n"]
            )
            output = response["choices"][0]["text"].strip()
            break
        except Exception:
            patience -= 1
            if not patience:
                # Give up and return an empty completion instead of looping forever.
                print("!!! running out of patience waiting for OpenAI")
                break
            time.sleep(0.1)
    return output
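# Illustrative query (the engine name and prompt are placeholders; requires ``openai.api_key`` to be set):
#   text = call_gpt3('text-davinci-002', 'Question: 1 + 1 =', 0.0, 64, 1.0, 0.0, 0.0)
# Because of ``lru_cache``, repeating the exact same arguments returns the cached completion.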
|
|
|
|
|
def get_table_text(problem: dict) -> str:
    table = problem['table']
    title = problem['table_title']
    if title and len(title) > 0:
        table = f"[TITLE]: {title}\n{table}"
    return table
|
|
|
|
|
def get_question_text(problem: dict, option_inds: list) -> str:
    question = problem['question']

    unit = problem['unit']
    if unit and len(unit) > 0:
        question = f"{question} (Unit: {unit})"

    choices = problem['choices']
    if choices and len(choices) > 0:
        choice_list = []
        for i, c in enumerate(choices):
            choice_list.append("({}) {}".format(option_inds[i], c))
        options = " ".join(choice_list)
        question = f"{question}\nOptions: {options}"

    return question
|
|
|
|
|
def get_answer(problem: dict) -> str:
    return problem['answer']
|
|
|
|
|
def get_solution_text(problem: dict) -> str:
    # Keep the solution on a single line so that each prompt example stays line-oriented.
    solution = problem['solution'].replace("\n", "\\n")
    return solution
|
|
|
|
|
def create_one_example(
        format: str, table: str, question: str, answer: str, solution: str, test_example: bool = True
) -> str:
    # ``format`` has the form "<input labels>-<output label>", e.g. "TQ-A": the input labels are
    # rendered one per line, and the output is either the gold element (few-shot example) or the
    # bare "Answer:" stub (test example).
    input_format, output_format = format.split("-")

    elements = {
        "Q": f"Question: {question}",
        "T": f"Table: {table}",
        "S": f"Solution: {solution}",
        "A": f"Answer: The answer is {answer}.",
        "AS": f"Answer: The answer is {answer}. BECAUSE: {solution}",
        "SA": f"Answer: {solution} The answer is {answer}."
    }

    input = "\n".join(elements[label] for label in input_format)

    if test_example:
        output = "Answer:"
    else:
        output = elements[output_format]

    text = input + "\n" + output
    # Collapse accidental double spaces left over from formatting.
    text = text.replace("  ", " ").strip()

    return text
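# For example, with ``format="TQ-A"`` a few-shot example (``test_example=False``) is rendered as
#   Table: <table>\nQuestion: <question>\nAnswer: The answer is <answer>.
# while a test example ends with the bare "Answer:" stub for the model to complete.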
|
|
|
|
|
def build_prompt(problems: dict, shot_pids: list, test_pid: str, args: dict) -> str:
    # Concatenate the few-shot examples (``shot_pids``) and the test example into one prompt.
    examples = []
    pids = shot_pids + [test_pid]

    for pid in pids:
        problem = problems[pid]
        table = get_table_text(problem)
        question = get_question_text(problem, args.option_inds)
        answer = get_answer(problem)
        solution = get_solution_text(problem)

        if pid == test_pid:
            assert pid not in shot_pids
            example = create_one_example(args.prompt_format, table, question, answer, solution, test_example=True)
        else:
            example = create_one_example(args.prompt_format, table, question, answer, solution, test_example=False)

        examples.append(example)

    prompt_input = '\n\n'.join(examples)

    return prompt_input
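# e.g. ``build_prompt(problems, ['101', '202'], '303', args)`` (pids are placeholders) yields two
# fully worked examples followed by the test example ending in "Answer:", separated by blank lines.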
|
|
|
|
|
def extract_prediction(output: str, options: list, option_inds: list) -> str:
    # Keep only the first line and drop any leading "... =" prefix.
    idx = output.find('\n')
    if idx > 0:
        output = output[:idx]
    idx = output.find('=')
    if idx > 0:
        output = output[idx + 1:].strip()

    # Normalize LaTeX fractions, trailing periods (except "A.M."/"P.M."), "=" spacing and the unicode minus sign.
    output = re.sub(r"\$?\\frac\{([\d\.\,\-]+)\}\{([\d\.\,]+)\}\$?", r"\1/\2", output)
    output = re.sub(r"(?<![AP]\.M)\.$", "", output)
    output = re.sub(r"(?<=\d)[\=](?=[\-\$\d])", " = ", output)
    output = re.sub(r"\u2212", "-", output)

    if options:
        # Multiple-choice question: try to find an explicit option letter first.
        patterns = [
            r'^\(([A-Za-z])\)$',
            r'^([A-Za-z])$',
            r'^([A-Za-z])\. ',
            r'[Tt]he answer is ([A-Z])',
            r'^\(([A-Za-z])\) [\s\S]+$',
            r'[Tt]he answer is \(([A-Za-z])\) [\s\S]+$',
        ]

        for p in patterns:
            pattern = re.compile(p)
            res = pattern.findall(output)
            if len(res) > 0:
                pred = res[0].upper()
                if pred in option_inds:
                    ind = option_inds.index(pred)
                    if ind >= len(options):
                        ind = random.choice(range(len(options)))
                    prediction = options[ind]
                    return prediction

        # No option letter found: fall back to the most similar option text.
        scores = [score_string_similarity(x, output) for x in options]
        max_idx = int(np.argmax(scores))
        prediction = options[max_idx]
        return prediction

    else:
        # Free-form answer: try a series of patterns from most to least specific.
        patterns = [
            r'[Tt]he answer is ([\s\S]+)$',
            r'[Tt]he table shows that ([\d\$\.\,\/\:]+) ',
            r' = ([\d\$\.\,\/\:]+)',
            r'(?<= be| is) ([\-\d\$\.\,\/\:]{0,}[\d]+)',
            r'(?<= are| was) ([\-\d\$\.\,\/\:]{0,}[\d]+)',
            r'(?<= were) ([\-\d\$\.\,\/\:]{0,}[\d]+)',
            r' ([\d\$\.\,\/\:]+ [AP]\.M\.)',
            r'([\-\d\$\.\,\/\:]{0,}[\d]+)',
        ]

        for p in patterns:
            pattern = re.compile(p)
            res = pattern.findall(output)
            if len(res) > 0:
                prediction = res[-1].strip()
                if prediction.endswith(".") and ".M." not in prediction:
                    prediction = prediction[:-1]
                return prediction

    return output
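# Examples of the extraction behaviour (illustrative inputs):
#   extract_prediction("The answer is (B) 12 apples.", ["10 apples", "12 apples"], ["A", "B"])  # -> "12 apples"
#   extract_prediction("So the total is $1,200.", [], [])                                       # -> "$1,200"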
|
|
|
|
|
def normalize_answer(text: str, unit: str) -> str:
    # Strip a leading dollar sign and trailing punctuation, then try to parse the text as a
    # number (integer, fraction or decimal); otherwise just drop the unit suffix.
    text = re.sub(r"^[\$]", "", text)
    text = re.sub(r"[\,\.\,\/]$", "", text)
    result = re.match(r"^[-+]?[\d,./]+$", text)

    if result is not None:
        # Numeric answer: remove thousands separators and normalize the value.
        text = text.replace(",", "")
        result = re.match(r"[-+]?\d+$", text)
        try:
            if result is not None:
                number = int(text)
            elif "/" in text:
                nums = text.split("/")
                number = round(float(nums[0]) / float(nums[1]), 3)
            else:
                number = round(float(text), 3)
            number = str(number)
            number = re.sub(r"\.[0]+$", "", number)
            return number
        except Exception:
            return text
    else:
        # Non-numeric answer: strip the unit string if it is present.
        if unit:
            text = text.replace(unit, "").strip()
        return text
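# e.g. normalize_answer("$1,200.00", "") -> "1200"  and  normalize_answer("3 kilograms", "kilograms") -> "3"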
|
|
|
|
|
def score_string_similarity(str1: str, str2: str) -> float:
    if str1 == str2:
        return 2.0
    if " " in str1 or " " in str2:
        str1_split = str1.split(" ")
        str2_split = str2.split(" ")
        overlap = list(set(str1_split) & set(str2_split))
        return len(overlap) / max(len(str1_split), len(str2_split))
    else:
        if str1 == str2:
            return 1.0
        else:
            return 0.0
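# Examples: score_string_similarity("cat", "cat") -> 2.0 (exact match);
# score_string_similarity("12 apples", "about 12 apples") -> 2/3 (word-overlap ratio).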
|
|
|
|
|
def create_example_from_pid(pid: str, problems: dict, args: dict, test: bool = False) -> str:
    problem = problems[pid]
    table = get_table_text(problem)
    question = get_question_text(problem, args.option_inds)
    answer = get_answer(problem)
    solution = get_solution_text(problem)

    if test:
        example = create_one_example(args.prompt_format, table, question, answer, solution, test_example=True)
    else:
        example = create_one_example(args.prompt_format, table, question, answer, solution, test_example=False)

    return example
|
|