import json |
import os |
import random |
import re |
import time |
from functools import lru_cache |
import torch |
import numpy as np |
import openai |
try: |
import transformers |
except ImportError: |
import sys |
from ditk import logging |
logging.warning("not found transformer, please install it using: pip install transformers") |
sys.exit(1) |
def sample_logits(out: torch.Tensor, temperature: float = 1.0, top_p: float = 0.8) -> int:
    """Sample one index from ``out`` using nucleus (top-p) sampling.

    Args:
        out: Unnormalised logits. Softmax is taken over the last dim, and the result is treated as a
            single 1-D distribution (``len(probs)`` is used as the number of candidates).
        temperature: After the top-p cut, probabilities are raised to ``1 / temperature``.
            A value of 1.0 skips this step.
        top_p: Cumulative-probability threshold. Every candidate whose probability is at least that
            of the first sorted candidate pushing the cumulative mass above ``top_p`` is kept.

    Returns:
        The sampled index as a Python ``int``.
    """
    # detach(): the logits may come from a forward pass that tracks gradients, and numpy() refuses
    # such tensors. float64 keeps the renormalised distribution well inside np.random.choice's
    # sum-to-one tolerance. astype() also makes a copy, so the in-place zeroing below never
    # touches tensor memory.
    probs = torch.softmax(out.detach().float(), dim=-1).cpu().numpy().astype(np.float64)
    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)
    exceeds = cumulative_probs > top_p
    # If no prefix exceeds top_p (e.g. top_p >= 1.0), keep the whole distribution instead of letting
    # argmax over an all-False mask fall back to index 0, which would make this greedy decoding.
    cutoff_idx = int(np.argmax(exceeds)) if exceeds.any() else len(sorted_probs) - 1
    cutoff = sorted_probs[cutoff_idx]
    probs[probs < cutoff] = 0
    if temperature != 1.0:
        # numpy arrays have no .pow(); np.power is the element-wise equivalent.
        probs = np.power(probs, 1.0 / temperature)
    probs = probs / np.sum(probs)
    return int(np.random.choice(a=len(probs), p=probs))
def calc_rwkv(
        model: transformers.RwkvForCausalLM,
        tokenizer: transformers.AutoTokenizer,
        prompt: str,
        max_len: int = 10
) -> str:
    """Generate ``max_len`` tokens after ``prompt`` with an RWKV model, one token at a time.

    Each step re-tokenizes the whole running text and runs a forward pass. It then samples the next
    token from the last position's logits via ``sample_logits`` and appends the decoded text.

    Args:
        model: The RWKV causal LM; inputs are moved to ``model.device``.
        tokenizer: Tokenizer matching ``model``.
        prompt: Text to continue.
        max_len: Number of tokens to sample.

    Returns:
        Only the newly generated text; the original ``prompt`` is stripped off.
    """
    orig_len = len(prompt)
    device = model.device
    # Pure inference: build no autograd graph. Pass no ``labels``, which would only add an unused
    # loss computation.
    with torch.no_grad():
        for _ in range(max_len):
            inputs = tokenizer(prompt, return_tensors="pt").to(device)
            logits = model(**inputs).logits
            token = sample_logits(logits[0, -1])
            prompt = prompt + tokenizer.decode([token])
    return prompt[orig_len:]
def calc_internlm(model, tokenizer, prompt: str, args):
    """Generate a completion for ``prompt`` with an InternLM model via ``model.generate``.

    Args:
        model: A model exposing ``generate``; presumably a Hugging Face causal LM (TODO confirm).
        tokenizer: Tokenizer matching ``model``.
        prompt: Text to continue.
        args: Namespace providing ``max_tokens``, ``top_p``, ``temperature`` and
            ``frequency_penalty``.

    Returns:
        The decoded newly generated text, without the prompt. This matches ``calc_rwkv``, which also
        returns only the continuation.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    for k, v in inputs.items():
        inputs[k] = v.cuda()
    gen_kwargs = {
        # max_new_tokens bounds only the completion, matching how max_tokens is used for the OpenAI
        # path. max_length would also count the (long, few-shot) prompt and could leave no budget.
        "max_new_tokens": args.max_tokens,
        "top_p": args.top_p,
        "temperature": args.temperature,
        "do_sample": True,
        # NOTE(review): HF expects repetition_penalty > 0 (1.0 = no penalty), whereas OpenAI's
        # frequency_penalty is commonly 0; confirm the configured value.
        "repetition_penalty": args.frequency_penalty
    }
    output = model.generate(**inputs, **gen_kwargs)
    # generate() presumably returns a (batch, seq) tensor that begins with the prompt ids.
    # decode() takes a single sequence, so take row 0 and drop the prompt prefix.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
def load_data(args: dict) -> tuple:
    """Load the TabMWP training problems and split a random sample into train and candidate pids.

    Downloads ``problems_train.json`` into ``dizoo/tabmwp/data`` with ``wget`` if it is not already
    there.

    Args:
        args: Accessed by attribute: ``seed``, ``train_number`` and ``cand_number``.

    Returns:
        ``(problems, cand_pids, train_pids)``. ``problems`` is the parsed JSON mapping. The two pid
        lists are disjoint samples of its keys, of sizes ``cand_number`` and ``train_number``.

    Raises:
        RuntimeError: If the download fails.
    """
    import subprocess

    random.seed(args.seed)
    data_root = 'dizoo/tabmwp/data'
    data_file = os.path.join(data_root, 'problems_train.json')
    # makedirs also creates missing parents; os.mkdir would fail on a fresh checkout path.
    os.makedirs(data_root, exist_ok=True)
    if not os.path.exists(data_file):
        result = subprocess.run(
            [
                'wget', 'https://opendilab.net/download/DI-zoo/tabmwp/problems_train.json', '-O', data_file,
                '--no-check-certificate'
            ]
        )
        if result.returncode != 0:
            # wget -O leaves an empty or partial file on failure. Remove it so the next run retries
            # instead of treating the broken file as cached data.
            if os.path.exists(data_file):
                os.remove(data_file)
            raise RuntimeError(f'failed to download TabMWP data to {data_file} (wget exit code {result.returncode})')
    with open(data_file, encoding='utf-8') as f:
        problems = json.load(f)
    pids = list(problems.keys())
    samples = random.sample(pids, args.train_number + args.cand_number)
    train_pids = samples[:args.train_number]
    cand_pids = samples[args.train_number:]
    return problems, cand_pids, train_pids
def get_gpt3_output(prompt: str, args: dict) -> str:
    """Return the OpenAI completion for ``prompt`` using the sampling settings stored on ``args``.

    ``args`` is read by attribute (``engine``, ``temperature``, ``max_tokens``, ``top_p``,
    ``frequency_penalty``, ``presence_penalty``). Everything is forwarded positionally to the
    ``lru_cache``-wrapped ``call_gpt3``, so repeated calls produce the same cache key shape.
    """
    engine = args.engine
    sampling = (args.temperature, args.max_tokens, args.top_p, args.frequency_penalty, args.presence_penalty)
    return call_gpt3(engine, prompt, *sampling)
@lru_cache(maxsize=10000)
def call_gpt3(
        engine: str,
        prompt: str,
        temperature: float,
        max_tokens: int,
        top_p: float,
        frequency_penalty: float,
        presence_penalty: float,
        patience: int = 100,
        retry_delay: float = 0.1,
) -> str:
    """Request a single-line completion from ``openai.Completion.create``, retrying on any error.

    Results are memoised by ``lru_cache``, so identical arguments reuse the previous completion.
    Failures raise instead of returning and are therefore never cached.

    Args:
        engine, prompt, temperature, max_tokens, top_p, frequency_penalty, presence_penalty:
            Forwarded to ``openai.Completion.create``. Generation stops at the first newline.
        patience: Maximum number of attempts before giving up.
        retry_delay: Seconds to sleep between failed attempts.

    Returns:
        The stripped text of the first choice.

    Raises:
        RuntimeError: After ``patience`` consecutive failures; the last error is chained as the cause.
    """
    last_error = None
    for attempt in range(1, patience + 1):
        try:
            response = openai.Completion.create(
                engine=engine,
                prompt=prompt,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                stop=["\n"]
            )
            return response["choices"][0]["text"].strip()
        except Exception as err:  # API/network errors come in many types; retry all of them
            last_error = err
            if attempt < patience:
                time.sleep(retry_delay)
    # Previously the loop kept retrying forever after this message; now we actually give up.
    print("!!! running out of patience waiting for OpenAI")
    raise RuntimeError(f"OpenAI completion failed after {patience} attempts") from last_error
def get_table_text(problem: dict) -> str:
    """Return ``problem['table']``, prefixed by a ``[TITLE]: <title>`` line when the title is non-empty."""
    table_body, table_title = problem['table'], problem['table_title']
    if not table_title:
        return table_body
    return f"[TITLE]: {table_title}\n{table_body}"
def get_question_text(problem: dict, option_inds: list) -> str:
    """Render the question text, with an optional unit suffix and an optional options line.

    The question gets ``" (Unit: <unit>)"`` when ``problem['unit']`` is non-empty. When
    ``problem['choices']`` is non-empty, a second line ``"Options: (A) x (B) y ..."`` is appended,
    with labels taken positionally from ``option_inds``.
    """
    text = problem['question']
    unit = problem['unit']
    if unit:
        text = f"{text} (Unit: {unit})"
    choices = problem['choices']
    if not choices:
        return text
    labelled = " ".join(f"({option_inds[idx]}) {choice}" for idx, choice in enumerate(choices))
    return f"{text}\nOptions: {labelled}"
def get_answer(problem: dict) -> str:
    """Return the ground-truth answer stored under the ``'answer'`` key."""
    answer = problem['answer']
    return answer
def get_solution_text(problem: dict) -> str:
    """Return ``problem['solution']`` with real newlines replaced by the two characters ``\\n``.

    This keeps each solution on a single prompt line.
    """
    raw_lines = problem['solution'].split("\n")
    return "\\n".join(raw_lines)
def create_one_example(
        format: str, table: str, question: str, answer: str, solution: str, test_example: bool = True
) -> str:
    """Render one few-shot example, or the test query, as prompt text.

    Args:
        format: ``"<input>-<output>"``, e.g. ``"TQ-A"``. Each character of ``<input>`` is one of
            ``Q``/``T``/``S``/``A`` and selects an input line in order. ``<output>`` is one of
            ``A``, ``S``, ``AS``, ``SA``.
        table, question, answer, solution: Field texts substituted into the element templates.
        test_example: If true, the output is the bare cue ``"Answer:"`` for the model to complete.
            Otherwise the full ``<output>`` element is rendered.

    Returns:
        The input lines and the output line joined by newlines, with doubled spaces collapsed and
        surrounding whitespace stripped.
    """
    input_format, output_format = format.split("-")
    elements = {
        "Q": f"Question: {question}",
        "T": f"Table: {table}",
        "S": f"Solution: {solution}",
        "A": f"Answer: The answer is {answer}.",
        "AS": f"Answer: The answer is {answer}. BECAUSE: {solution}",
        "SA": f"Answer: {solution} The answer is {answer}."
    }
    prompt_input = "\n".join(elements[label] for label in input_format)
    prompt_output = "Answer:" if test_example else elements[output_format]
    text = prompt_input + "\n" + prompt_output
    # Collapse doubled spaces. The previous replace(" ", " ") was a no-op.
    return text.replace("  ", " ").strip()
def build_prompt(problems: dict, shot_pids: list, test_pid: int, args: dict) -> str:
    """Build a few-shot prompt: one solved example per shot pid, then the unsolved test problem.

    Args:
        problems: Mapping from pid to problem record, as returned by ``load_data``.
        shot_pids: Pids rendered as solved in-context examples, in order.
        test_pid: Pid of the query problem, rendered with a bare ``"Answer:"`` cue.
        args: Provides ``prompt_format`` and ``option_inds`` to ``create_example_from_pid``.

    Returns:
        The rendered examples separated by blank lines.

    Raises:
        ValueError: If ``test_pid`` also appears in ``shot_pids``, since that would leak the answer.
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if test_pid in shot_pids:
        raise ValueError(f"test pid {test_pid!r} must not also be used as a shot example")
    # Reuse the single-example helper rather than duplicating its field extraction here.
    examples = [create_example_from_pid(pid, problems, args, test=False) for pid in shot_pids]
    examples.append(create_example_from_pid(test_pid, problems, args, test=True))
    return '\n\n'.join(examples)
def extract_prediction(output: str, options: list, option_inds: list) -> str:
    """Extract the final answer from a model completion.

    Only the first line is considered. If it contains ``=`` after position 0, only the text after
    the first ``=`` is kept. LaTeX ``\\frac{a}{b}`` becomes ``a/b``, a trailing period is dropped
    (unless it ends ``A.M.``/``P.M.``), and the Unicode minus is normalised.

    Args:
        output: Raw completion text.
        options: Answer choices for multiple-choice problems; empty/falsy for free-form answers.
        option_inds: Option labels (e.g. ``["A", "B", ...]``), aligned positionally with ``options``.

    Returns:
        For multiple choice, the selected element of ``options``. The selection comes from an
        explicit letter when one is found; otherwise it is the most similar option according to
        ``score_string_similarity``. For free-form answers, the last match of the first matching
        answer pattern, or the cleaned ``output`` when no pattern matches.
    """
    idx = output.find('\n')
    if idx > 0:
        output = output[:idx]
    idx = output.find('=')
    if idx > 0:
        output = output[idx + 1:].strip()
    output = re.sub(r"\$?\\frac\{([\d\.\,\-]+)\}\{([\d\.\,]+)\}\$?", r"\1/\2", output)
    output = re.sub(r"(?<![AP]\.M)\.$", "", output)
    output = re.sub(r"(?<=\d)[\=](?=[\-\$\d])", " = ", output)
    output = re.sub(r"\u2212", "-", output)

    if options:
        # "[Tt]he" (not "[Th]he", which matched only "The"/"hhe") so lowercase "the answer is" works.
        patterns = [
            r'^\(([A-Za-z])\)$',
            r'^([A-Za-z])$',
            r'^([A-Za-z]). ',
            r'[Tt]he answer is ([A-Z])',
            r'^\(([A-Za-z])\) [\s\S]+$',
            r'[Tt]he answer is \(([A-Za-z])\) [\s\S]+$',
        ]
        for p in patterns:
            res = re.findall(p, output)
            if res:
                pred = res[0].upper()
                if pred in option_inds:
                    ind = option_inds.index(pred)
                    if ind >= len(options):
                        # The letter names a slot beyond the available choices: pick one at random.
                        ind = random.choice(range(len(options)))
                    return options[ind]
        # No usable letter: fall back to the option most similar to the text.
        scores = [score_string_similarity(x, output) for x in options]
        return options[int(np.argmax(scores))]

    patterns = [
        r'[Tt]he answer is ([\s\S]+)$',
        r'[Tt]he table shows that ([\d\$\.\,\/\:]+) ',
        r' = ([\d\$\.\,\/\:]+)',
        r'(?<= be| is) ([\-\d\$\.\,\/\:]{0,}[\d]+)',
        r'(?<= are| was) ([\-\d\$\.\,\/\:]{0,}[\d]+)',
        r'(?<= were) ([\-\d\$\.\,\/\:]{0,}[\d]+)',
        r' ([\d\$\.\,\/\:]+ [AP]\.M\.)',
        r'([\-\d\$\.\,\/\:]{0,}[\d]+)',
    ]
    for p in patterns:
        res = re.findall(p, output)
        if res:
            prediction = res[-1].strip()
            if prediction.endswith(".") and ".M." not in prediction:
                prediction = prediction[:-1]
            return prediction
    return output
def normalize_answer(text: str, unit: str) -> str:
    """Normalise an answer string so predictions and references can be compared.

    A leading ``$`` and one trailing ``,``/``.``/``/`` are removed first. Numeric-looking text
    (digits with ``,``, ``.``, ``/`` and an optional sign) has its commas dropped. It is then
    rendered as an int, or as a float/fraction rounded to 3 decimals with any ``.0...`` suffix
    removed. Text that cannot be parsed is returned with commas dropped. Non-numeric text has
    ``unit`` removed (when given) and is stripped.

    Args:
        text: The answer text.
        unit: Unit string to strip from non-numeric answers; may be empty or ``None``.

    Returns:
        The normalised answer string.
    """
    # Raw strings: the previous plain literals relied on invalid escapes like "\$" and "\d",
    # which emit SyntaxWarning on Python 3.12+. The compiled patterns are unchanged.
    text = re.sub(r"^[\$]", "", text)
    text = re.sub(r"[\,\.\,\/]$", "", text)
    if re.match(r"^[-+]?[\d,./]+$", text) is None:
        if unit:
            text = text.replace(unit, "").strip()
        return text

    text = text.replace(",", "")
    is_integer = re.match(r"[-+]?\d+$", text) is not None
    try:
        if is_integer:
            number = int(text)
        elif "/" in text:
            nums = text.split("/")
            number = round(float(nums[0]) / float(nums[1]), 3)
        else:
            number = round(float(text), 3)
    except (ValueError, ZeroDivisionError, OverflowError):
        # Malformed numerics such as "1/0", "/5" or ".": keep the comma-stripped text.
        return text
    return re.sub(r"\.[0]+$", "", str(number))
def score_string_similarity(str1: str, str2: str) -> float:
    """Score how similar two strings are.

    Returns:
        ``2.0`` for identical strings. ``0.0`` for distinct strings that both lack a space.
        Otherwise, the number of distinct shared space-separated words divided by the word count
        of the longer string.
    """
    if str1 == str2:
        return 2.0
    # The old ``else`` re-tested ``str1 == str2`` (already handled above), so it always returned 0.0.
    if " " not in str1 and " " not in str2:
        return 0.0
    words1 = str1.split(" ")
    words2 = str2.split(" ")
    overlap = set(words1) & set(words2)
    return len(overlap) / max(len(words1), len(words2))
def create_example_from_pid(pid: int, problems: list, args: dict, test: bool = False) -> str:
    """Render the problem stored at ``problems[pid]`` as one prompt example.

    Uses ``args.prompt_format`` and ``args.option_inds``. When ``test`` is true, the example ends
    with the bare ``"Answer:"`` cue instead of the solved output.
    """
    record = problems[pid]
    table_text = get_table_text(record)
    question_text = get_question_text(record, args.option_inds)
    answer_text = get_answer(record)
    solution_text = get_solution_text(record)
    is_test = True if test else False
    return create_one_example(
        args.prompt_format, table_text, question_text, answer_text, solution_text, test_example=is_test
    )