demo / retrieve.py
AnonymousSub's picture
Upload 4 files
1de9c91 verified
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tqdm import tqdm
import time
import sys
# MODEL_NAME = str(sys.argv[1])
# num_shots = int(sys.argv[2])
# method = str(sys.argv[3]) #['fixed', 'random', 'bm25']
# ADDED K-SHOT SETTING, WHERE K IS VARIABLE
# import openai
import time
# import pandas as pd
import random
random.seed(1)
import csv
import os
import pickle
import json
import nltk
nltk.download('punkt')
nltk.download('stopwords')
from nltk.tokenize import sent_tokenize
from nltk.corpus import stopwords
import string
from langchain.chat_models import AzureChatOpenAI
from langchain.schema import HumanMessage, SystemMessage
from langchain.callbacks import get_openai_callback
from langchain.llms import OpenAI
import tiktoken
import re
from nltk.tokenize import sent_tokenize
from collections import defaultdict
import nltk
from nltk.tokenize import sent_tokenize
from nltk.tokenize import word_tokenize
import numpy as np
# Get the parent directory
# parent_dir = "/home/abnandy/sensei-fs-link"#os.path.abspath(os.path.join(os.getcwd(), os.pardir))
# Add the parent directory to the system path
# sys.path.append(parent_dir)
from utils import AzureModels, write_to_file, read_from_file
# from utils_open import OpenModels
# Translation table deleting every ASCII punctuation mark except '_' and '@'.
# Those two are kept so citation tags of the form "@cite_3" stay intact.
_PUNCTUATION_TABLE = str.maketrans(
    '', '', string.punctuation.replace('_', '').replace('@', '')
)

# English stopword set, filled on the first call so the NLTK corpus is only
# consulted once instead of once per processed sentence.
_ENGLISH_STOPWORDS = None


def remove_stopwords_and_punctuation(text):
    """Return *text* without punctuation and without English stopwords.

    Punctuation is removed first (except ``_`` and ``@``), the result is split
    on whitespace, and every word whose lowercase form is an NLTK English
    stopword is dropped. The surviving words are joined with single spaces.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string; empty if no word survives.
    """
    global _ENGLISH_STOPWORDS
    if _ENGLISH_STOPWORDS is None:
        # Deferred to first use: the corpus must already be downloaded.
        _ENGLISH_STOPWORDS = frozenset(stopwords.words('english'))

    without_punctuation = text.translate(_PUNCTUATION_TABLE)
    return ' '.join(
        word
        for word in without_punctuation.split()
        if word.lower() not in _ENGLISH_STOPWORDS
    )
def get_key(list_):
    """Merge several citation tags into one composite key.

    Every ``"@cite"`` substring is removed from each tag and the remainders
    are appended, in order, to a single leading ``"@cite"``. For example
    ``["@cite_1", "@cite_2"]`` becomes ``"@cite_1_2"``.

    Args:
        list_: Iterable of tag strings.

    Returns:
        The composite key; ``"@cite"`` when *list_* is empty.
    """
    # One join instead of repeated string concatenation in a loop.
    return '@cite' + ''.join(tag.replace('@cite', '') for tag in list_)
def group_citations(key):
    """Expand a composite citation key into a comma-separated tag list.

    All ``"@cite_"`` prefixes are removed from *key*, the remainder is split
    on ``"_"``, and each piece is prefixed with ``"@cite_"`` again, so
    ``"@cite_1_2"`` becomes ``"@cite_1, @cite_2"``.
    """
    bare_ids = key.replace("@cite_", "")
    return ", ".join("@cite_" + piece for piece in bare_ids.split("_"))
def code_to_extra_info(code_str):
    """Turn clustering code into a plain-language "NOTE THAT" paragraph.

    Each line of *code_str* is inspected left of its first ``=``. Keys found
    in ``citation_bracket["..."]`` and ``sentence["..."]`` assignments are
    collected, and every key containing more than one underscore (i.e. one
    that merges several citations) produces one explanatory sentence.

    Returns:
        The note block, or ``""`` when no key groups several citations.
    """
    bracket_keys = []
    sent_keys = []
    markers = (("citation_bracket[", bracket_keys), ("sentence[", sent_keys))

    for raw_line in code_str.split("\n"):
        target = raw_line.split("=")[0]
        for marker, bucket in markers:
            if marker in target:
                # Text between marker + '"' and the closing '"]' is the key.
                bucket.append(target.split(marker + '"')[-1].split('"]')[0])

    def render(keys, template):
        # Only keys that combine at least two citations are worth a note.
        return [template.format(group_citations(k)) for k in keys if k.count("_") > 1]

    bracket_notes = render(
        bracket_keys,
        "{} are in the same citation bracket (i.e., they are right next to each other) within the section of the Wikipedia Article.",
    )
    sentence_notes = render(
        sent_keys,
        "{} are in the same sentence within the section of the Wikipedia Article.",
    )

    if not bracket_notes and not sentence_notes:
        return ""
    return "\n\nNOTE THAT -\n\n" + "\n".join(bracket_notes) + "\n\n" + "\n".join(sentence_notes)
def get_code_str(related_work, reference_dict):
    """Encode the citation layout of *related_work* as pseudo-code lines.

    The text is split into sentences; each sentence is stripped of stopwords
    and punctuation and split on single spaces. Every maximal run of
    consecutive words that are keys of *reference_dict* is treated as one
    citation bracket and yields a ``citation_bracket["<key>"] = [...]`` line.
    Every sentence with at least one bracket yields a
    ``sentence["<key>"] = [...]`` line listing its brackets.

    Args:
        related_work: Text whose citations should be organised.
        reference_dict: Container whose keys are the citation tags to look for.

    Returns:
        The bracket lines, a blank line, then the sentence lines, with all
        single quotes removed.
    """
    # Local import: itertools is not among the file's top-level imports.
    from itertools import groupby

    bracket_lines = []
    sentence_lines = []

    for sentence in sent_tokenize(related_work):
        words = remove_stopwords_and_punctuation(sentence).split(' ')
        bracket_keys = []

        # groupby yields maximal runs of consecutive citation tags, which is
        # exactly one citation bracket per run.
        for is_citation, run in groupby(words, key=lambda word: word in reference_dict):
            if not is_citation:
                continue
            tags = list(run)
            bracket_key = get_key(tags)
            quoted_tags = ['"' + tag + '"' for tag in tags]
            bracket_lines.append(
                'citation_bracket["{}"] = {}'.format(bracket_key, str(quoted_tags))
            )
            bracket_keys.append(bracket_key)

        if bracket_keys:
            members = ['citation_bracket["{}"]'.format(key) for key in bracket_keys]
            sentence_lines.append(
                'sentence["{}"] = {}'.format(get_key(bracket_keys), str(members))
            )

    # str(list) wraps each element in single quotes; stripping them leaves
    # the double-quoted form wanted in the generated code.
    return (
        " " + "\n ".join(bracket_lines).replace("'", "")
        + "\n\n " + "\n ".join(sentence_lines).replace("'", "")
    )
def get_prompt(list_, i, prompt_template):
    """Build the code-style prompt and gold clustering code for one sample.

    Args:
        list_: Sequence of samples; each is read via the keys
            ``'related_work'``, ``'abstract'`` and ``'ref_abstract'``.
        i: Index of the sample to use.
        prompt_template: Format string with one placeholder that receives
            the generated assignment lines.

    Returns:
        A tuple ``(gt_summary, prompt, code_str)``: the stripped related-work
        text, the filled-in prompt, and the output of ``get_code_str`` for
        this sample.
    """
    sample = list_[i]
    gt_summary = sample['related_work'].strip()
    inp_intent = sample['abstract'].strip()
    references = sample['ref_abstract']

    # One assignment line per reference, followed by the intent line.
    statements = [
        'reference_articles["{}"] = "{}"'.format(tag, references[tag]['abstract'].strip())
        for tag in references
    ]
    statements.append('intent = "{}"'.format(inp_intent))
    input_code_str = " " + "\n ".join(statements)

    code_str = get_code_str(gt_summary, references)
    prompt = prompt_template.format(input_code_str)
    return gt_summary, prompt, code_str
def preprocess_retrieved_out(tmp_keys, out):
    """Extract one summary per citation tag from line-oriented model output.

    For each key, the first line of *out* that mentions the key as a whole
    token is used. Everything after that line's first ``":"`` (the whole line
    if there is none) is stripped and stored as the summary. Keys with no
    matching line are omitted from the result.

    The tag must be matched as a whole token: a plain substring test would
    let ``"@cite_1"`` match a line written for ``"@cite_10"`` and pick up the
    wrong summary.

    Args:
        tmp_keys: Iterable of citation tags to look for.
        out: Raw model output, one ``"<tag> : <summary>"`` entry per line.

    Returns:
        Dict mapping each found key to ``{"abstract": <summary>}``.
    """
    lines = out.split("\n")  # split once, not once per key
    new_dict = {}
    for key in tmp_keys:
        # The lookarounds reject matches embedded in a longer identifier.
        tag_pattern = re.compile(r"(?<!\w)" + re.escape(key) + r"(?!\w)")
        for line in lines:
            if tag_pattern.search(line):
                summ_doc = line.split(":", 1)[-1].strip()
                new_dict[key] = {"abstract": summ_doc}
                print(key)
                print(summ_doc)
                print()
                break
    return new_dict
def get_slide(topic, text):
    """Ask the "gpt4o" Azure model to restructure *text* as slide content.

    The topic and text are inserted into a fixed instruction, sent through
    ``AzureModels.get_completion`` with a second argument of 100, and the
    reply is returned after a two-second pause.

    Args:
        topic: Slide topic inserted into the instruction.
        text: Source text to be restructured.

    Returns:
        Whatever ``get_completion`` returns for the request.
    """
    template = '''Convert this text into more structured text (in markdown) that can be put into the content of a slide in a presentation (e.g. use bullet points, numbered points, proper layout, etc.). Also, the include the topic "{}" of the slide. -
{}'''
    client = AzureModels("gpt4o")
    request = template.format(topic, text)
    reply = client.get_completion(request, 100)
    time.sleep(2)  # fixed pause after every request
    return reply
# Model identifier handed to OpenModels for each supported open-model name.
_OPEN_MODEL_IDS = {
    "gemma2b": "google/gemma-2b-it",
    "gemma7b": "google/gemma-7b-it",
    "mistral7b": "mistralai/Mistral-7B-Instruct-v0.3",
    "llama7b": "meta-llama/Llama-2-7b-chat-hf",
    "llama13b": "meta-llama/Llama-2-13b-chat-hf",
    "llama3": "meta-llama/Meta-Llama-3-8B-Instruct",
    "galactica7b": "facebook/galactica-6.7b",
}

# NOTE(review): the whitespace inside the literals below is reproduced exactly
# as found; leading indentation inside them may have been lost upstream.

# Preamble placed before every code-style prompt for non-GPT-4 models.
_INSTRUCTION_PREAMBLE = '''Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
'''

# Placeholders: articles, intent, citation tags, extra info, response prefix.
_REFERENCE_SUMMARY_TEMPLATE = '''Given are a set of articles referenced in a Wikipedia Article, and the intent -
Reference Articles:
{}
Intent:
{}
Summarize each reference article (generate in the format "@cite_K : <SUMMARIZED CONTENT CORREPONDING TO @cite_K>", each in a new line, where @cite_K represents each of the following citation/reference tags - {}, given in Reference Articles), given the reference articles as documents, and the intent.{}
{}Answer: '''

# NOTE(review): this template has four placeholders but is formatted with the
# same five arguments as the one above. The citation tags therefore land
# after "the intent.", the extra info lands before "Answer: ", and the
# response prefix is dropped. Left as is; confirm whether that is intended.
_SECTION_TEMPLATE = '''Given are a set of articles referenced in a Wikipedia Article, and the intent -
Reference Articles:
{}
Intent:
{}
Generate the wikipedia article section in 100-200 words based on the intent as an intent-based multi-document summary, given the reference articles as documents, and the intent.{}
{}Answer: '''

# Single placeholder: the assignment lines produced by get_prompt.
_CODE_TEMPLATE = '''def main():
# Given is a dictionary of articles that are referenced in a section of the Wikipedia Article, and the intent -
reference_articles = dict()
{}'''

# Function header the model is asked to complete in code mode.
_HIER_CLUSTER_PROMPT = "\n def hierarchical_clustering():\n # Hierarchical Clustering of references within a section of the Wikipedia Article, based on the reference articles and the intent\n citation_bracket = {} # This dictionary contains lists as values that shows how references are grouped within the same citation bracket in the section of the Wikipedia Article\n sentence = {} # This dictionary contains lists, where each list contains references in a sentence in the section of the Wikipedia Article\n\n"


def _select_icl_indices(method, num_shots, query_idx, train_list, retrieve_dict,
                        fixed_indices, holdout_indices):
    """Return the train_list indices used as in-context examples for one query."""
    if method == "random":
        pool = holdout_indices if holdout_indices is not None else range(len(train_list))
        return random.sample(pool, num_shots)
    if method in ("bm25", "gat"):
        ranked = retrieve_dict[str(query_idx)]
        return [int(ranked[j]) for j in range(num_shots)]
    if method == "fixed":
        return fixed_indices[:num_shots]
    # Unrecognised method: fall back to the untruncated fixed examples.
    return fixed_indices


def _render_references(ref_abstracts, nested):
    """Render references as newline-separated "<tag> : <abstract>" lines."""
    return "\n".join(
        tag + " : " + (ref_abstracts[tag]['abstract'] if nested else ref_abstracts[tag])
        for tag in ref_abstracts
    )


def get_retrieved_results(MODEL_NAME, num_shots, method, train_list, test_list,
                          code=False, organize_out=None, holdout_indices=None):
    """Run one stage of the prompting pipeline over every item of *test_list*.

    Items whose ``'ref_abstract'`` holds at most one reference are skipped.
    There are three modes:

    * ``code=True``: build a code-completion prompt with ``get_prompt`` and
      return ``{index: model_output}``.
    * ``code=False, organize_out=None``: ask for one summary per reference,
      parse the reply with ``preprocess_retrieved_out``, write the result
      back into ``test_list[i]["ref_abstract"]`` (in place) and return
      *test_list*. Reference values are used as plain strings in this mode.
    * ``code=False, organize_out`` given: ask for the article section and
      return ``{index: model_output}``. Reference values are read via
      ``['abstract']`` and the clustering code stored in *organize_out* is
      turned into an extra note with ``code_to_extra_info``.

    Args:
        MODEL_NAME: A name containing ``"gpt4"`` selects ``AzureModels``;
            any other name must be a key of ``_OPEN_MODEL_IDS``.
        num_shots: Number of in-context examples; 0 disables them.
        method: How examples are chosen: ``"random"``, ``"bm25"``, ``"gat"``
            or ``"fixed"`` (indices 0 and 1, truncated to *num_shots*).
        train_list: Samples that in-context examples are drawn from.
        test_list: Samples to process.
        code: Selects the code-completion mode.
        organize_out: Mapping from item index to clustering code. The index
            is looked up as ``str(i)`` first, then as ``i``.
        holdout_indices: Optional pool of train indices for
            ``method="random"``; defaults to every index of *train_list*.

    Returns:
        *test_list* in the reference-summary mode, otherwise a dict keyed by
        item index.

    Raises:
        ValueError: If *MODEL_NAME* is neither a GPT-4 name nor a supported
            open-model name.
    """
    response_template = ''
    instruction_template = ''
    icl_extra_info = ""
    test_extra_info = ""
    start_idx = 0
    fixed_icl_indices = [0, 1]
    use_azure = 'gpt4' in MODEL_NAME

    if use_azure:
        azure_models = AzureModels(MODEL_NAME)
    else:
        if code:
            instruction_template = _INSTRUCTION_PREAMBLE
            response_template = '### Response:\n'
        else:
            response_template = '### Assistant: '
        if MODEL_NAME not in _OPEN_MODEL_IDS:
            raise ValueError(
                "Unsupported MODEL_NAME {!r}; expected a name containing 'gpt4' or one of {}".format(
                    MODEL_NAME, sorted(_OPEN_MODEL_IDS)
                )
            )
        # The file's top-level OpenModels import is commented out, so the
        # class is resolved here: use a module-level OpenModels if one exists,
        # otherwise import it only when an open model is actually requested.
        try:
            open_model_cls = OpenModels
        except NameError:
            from utils_open import OpenModels as open_model_cls
        open_models = open_model_cls(_OPEN_MODEL_IDS[MODEL_NAME])

    if code:
        prompt_template = _CODE_TEMPLATE
    elif organize_out is not None:
        prompt_template = _SECTION_TEMPLATE
    else:
        prompt_template = _REFERENCE_SUMMARY_TEMPLATE

    retrieve_dict = None
    if method == 'bm25':
        retrieve_dict = read_from_file("bm25_10_icl_samples_50_holdout_samples.json")
    elif method == "gat":
        retrieve_dict = read_from_file("gat_20_icl_samples_50_holdout_samples.json")

    if code:
        final_dict = {}
        max_tokens = 500
        for i in tqdm(range(start_idx, len(test_list))):
            if len(test_list[i]['ref_abstract']) <= 1:
                continue

            full_icl_prompt = ""
            if num_shots > 0:
                icl_train_indices = _select_icl_indices(
                    method, num_shots, i, train_list, retrieve_dict,
                    fixed_icl_indices, holdout_indices,
                )
                for enum_idx, icl_train_idx in enumerate(icl_train_indices):
                    _, icl_prompt, icl_code_str = get_prompt(train_list, icl_train_idx, prompt_template)
                    full_icl_prompt += (
                        "##Example {}:\n\n".format(enum_idx + 1)
                        + instruction_template + icl_prompt
                        + _HIER_CLUSTER_PROMPT + icl_code_str + "\n\n"
                    )
                # NOTE(review): assumed to apply only when examples are
                # present; the original indentation was lost.
                full_icl_prompt += "##Example {}:\n\n".format(num_shots + 1)

            _, prompt, _ = get_prompt(test_list, i, prompt_template)
            final_prompt = (
                full_icl_prompt + instruction_template + prompt + _HIER_CLUSTER_PROMPT
                + " # only generate the code that comes after this, as if you are on autocomplete mode\n"
                + response_template
            )

            if use_azure:
                out_ = azure_models.get_completion(final_prompt, max_tokens)
                time.sleep(2)
            else:
                out_ = open_models.open_completion(
                    final_prompt, max_tokens,
                    stop_token="##Example {}".format(num_shots + 2),
                )
            final_dict[i] = out_
        return final_dict

    summarizing_references = organize_out is None
    # In the section-writing mode references are {"abstract": ...} dicts.
    nested = not summarizing_references
    tmp_max_tok_len = 1000 if summarizing_references else 300
    pred_dict = {}

    for i in tqdm(range(start_idx, len(test_list))):
        if len(test_list[i]['ref_abstract']) <= 1:
            continue

        icl_prompt = ""
        if num_shots > 0:
            icl_train_indices = _select_icl_indices(
                method, num_shots, i, train_list, retrieve_dict,
                fixed_icl_indices, holdout_indices,
            )
            for enum_idx, icl_train_idx in enumerate(icl_train_indices):
                icl_tmp_list = train_list[icl_train_idx]['ref_abstract']
                icl_inp_intent = train_list[icl_train_idx]['abstract']
                icl_gt_summary = train_list[icl_train_idx]['related_work']
                if not summarizing_references:
                    icl_code_str = get_code_str(icl_gt_summary, icl_tmp_list)
                    icl_extra_info = code_to_extra_info(icl_code_str)
                icl_paper_abstracts = _render_references(icl_tmp_list, nested)
                icl_prompt += (
                    "##Example {}:\n\n".format(enum_idx + 1)
                    + prompt_template.format(
                        icl_paper_abstracts, icl_inp_intent,
                        " ".join(icl_tmp_list.keys()),
                        icl_extra_info, response_template,
                    )
                    + icl_gt_summary.strip() + "\n\n"
                )
            # NOTE(review): assumed to apply only when examples are present;
            # the original indentation was lost.
            icl_prompt += "##Example {}:\n\n".format(num_shots + 1)

        inp_intent = test_list[i]['abstract']
        if not summarizing_references:
            # Accept both JSON-style string keys and the int keys produced by
            # the code mode of this function.
            test_code_str = organize_out[str(i)] if str(i) in organize_out else organize_out[i]
            test_extra_info = code_to_extra_info(test_code_str)

        tmp_list = test_list[i]['ref_abstract']
        paper_abstracts = _render_references(tmp_list, nested)
        prompt = icl_prompt + prompt_template.format(
            paper_abstracts, inp_intent, " ".join(tmp_list.keys()),
            test_extra_info, response_template,
        )

        if use_azure:
            out_ = azure_models.get_completion(prompt, tmp_max_tok_len)
            time.sleep(2)
        else:
            out_ = open_models.open_completion(prompt, tmp_max_tok_len, temperature=0.7)

        if summarizing_references:
            # Replaces the references in place with the parsed summaries.
            test_list[i]["ref_abstract"] = preprocess_retrieved_out(tmp_list, out_)
        else:
            pred_dict[i] = out_

    if summarizing_references:
        return test_list
    return pred_dict