Spaces:
Sleeping
Sleeping
from tenacity import retry, stop_after_attempt, wait_random_exponential | |
from tqdm import tqdm | |
import time | |
import sys | |
# MODEL_NAME = str(sys.argv[1]) | |
# num_shots = int(sys.argv[2]) | |
# method = str(sys.argv[3]) #['fixed', 'random', 'bm25'] | |
# ADDED K-SHOT SETTING, WHERE K IS VARIABLE | |
# import openai | |
import time | |
# import pandas as pd | |
import random | |
random.seed(1) | |
import csv | |
import os | |
import pickle | |
import json | |
import nltk | |
nltk.download('punkt') | |
nltk.download('stopwords') | |
from nltk.tokenize import sent_tokenize | |
from nltk.corpus import stopwords | |
import string | |
from langchain.chat_models import AzureChatOpenAI | |
from langchain.schema import HumanMessage, SystemMessage | |
from langchain.callbacks import get_openai_callback | |
from langchain.llms import OpenAI | |
import tiktoken | |
import re | |
from nltk.tokenize import sent_tokenize | |
from collections import defaultdict | |
import nltk | |
from nltk.tokenize import sent_tokenize | |
from nltk.tokenize import word_tokenize | |
import numpy as np | |
# Get the parent directory | |
# parent_dir = "/home/abnandy/sensei-fs-link"#os.path.abspath(os.path.join(os.getcwd(), os.pardir)) | |
# Add the parent directory to the system path | |
# sys.path.append(parent_dir) | |
from utils import AzureModels, write_to_file, read_from_file | |
# from utils_open import OpenModels | |
def remove_stopwords_and_punctuation(text):
    """Return ``text`` with punctuation and English stopwords removed.

    All characters in ``string.punctuation`` are deleted except ``_`` and
    ``@``, which are kept so that tokens such as ``@cite_3`` stay intact.
    The remaining text is split on whitespace, words whose lowercase form is
    in NLTK's English stopword list are dropped, and the survivors are joined
    with single spaces.
    """
    english_stopwords = set(stopwords.words('english'))

    # Punctuation to delete: everything except the two characters that make
    # up citation tags.
    removable = string.punctuation.replace('_', '').replace('@', '')
    stripped = text.translate(str.maketrans('', '', removable))

    kept = []
    for token in stripped.split():
        if token.lower() in english_stopwords:
            continue
        kept.append(token)
    return ' '.join(kept)
def get_key(list_):
    """Fuse several citation tags into one composite key.

    Every occurrence of ``'@cite'`` is removed from each item and the
    remainders are concatenated after a single leading ``'@cite'``; e.g.
    ``['@cite_1', '@cite_2']`` becomes ``'@cite_1_2'``.
    """
    suffixes = [tag.replace('@cite', '') for tag in list_]
    return '@cite' + ''.join(suffixes)
def group_citations(key):
    """Expand a composite key into a comma-separated list of citation tags.

    ``'@cite_'`` is stripped from ``key``, the rest is split on ``'_'`` and
    each piece is prefixed with ``'@cite_'`` again; e.g. ``'@cite_1_2'``
    becomes ``'@cite_1, @cite_2'``.
    """
    bare_ids = key.replace("@cite_", "").split("_")
    tags = []
    for ident in bare_ids:
        tags.append("@cite_" + ident)
    return ", ".join(tags)
def code_to_extra_info(code_str):
    """Turn clustering pseudo-code into a natural-language "NOTE THAT" hint.

    ``code_str`` is scanned line by line. For every line whose left-hand side
    (the text before the first ``=``) assigns to ``citation_bracket["..."]``
    or ``sentence["..."]``, the quoted key is collected. Keys containing more
    than one ``_`` (i.e. combining at least two citations) are rendered as
    sentences stating which citations share a bracket or a sentence.

    Returns the empty string when no such key exists, otherwise a block that
    starts with ``"\\n\\nNOTE THAT -\\n\\n"``.
    """
    bracket_keys = []
    sentence_keys = []
    for code_line in code_str.split("\n"):
        lhs = code_line.split("=")[0]
        if "citation_bracket[" in lhs:
            bracket_keys.append(lhs.split('citation_bracket["')[-1].split('"]')[0])
        if "sentence[" in lhs:
            sentence_keys.append(lhs.split('sentence["')[-1].split('"]')[0])

    bracket_note = "{} are in the same citation bracket (i.e., they are right next to each other) within the section of the Wikipedia Article."
    sentence_note = "{} are in the same sentence within the section of the Wikipedia Article."

    # Only keys that merge two or more citations are worth mentioning.
    bracket_notes = []
    for key in bracket_keys:
        if key.count("_") > 1:
            bracket_notes.append(bracket_note.format(group_citations(key)))
    sentence_notes = []
    for key in sentence_keys:
        if key.count("_") > 1:
            sentence_notes.append(sentence_note.format(group_citations(key)))

    if not bracket_notes and not sentence_notes:
        return ""
    return "\n\nNOTE THAT -\n\n" + "\n".join(bracket_notes) + "\n\n" + "\n".join(sentence_notes)
def get_code_str(related_work, reference_dict):
    """Describe how citations are grouped in ``related_work`` as pseudo-code.

    The text is split into sentences; each sentence is stripped of stopwords
    and punctuation and split on single spaces. Consecutive tokens that are
    keys of ``reference_dict`` form one "citation bracket". For every bracket
    a line ``citation_bracket["<key>"] = ["@cite_a", "@cite_b"]`` is emitted,
    and for every sentence containing at least one bracket a line
    ``sentence["<key>"] = [citation_bracket["..."], ...]`` is emitted, where
    the keys are built with ``get_key``.

    Returns the bracket lines, a blank line, then the sentence lines.
    """
    bracket_lines = []
    sentence_lines = []

    def emit_bracket(tags, keys_in_sentence):
        # Record one run of adjacent citation tags as a bracket assignment.
        composite = get_key(tags)
        quoted = ['"' + tag + '"' for tag in tags]
        bracket_lines.append('citation_bracket["{}"] = {}'.format(composite, str(quoted)))
        keys_in_sentence.append(composite)

    for sent in sent_tokenize(related_work):
        keys_in_sentence = []
        run = []  # citation tags seen back-to-back so far
        for token in remove_stopwords_and_punctuation(sent).split(' '):
            if token in reference_dict:
                run.append(token)
            elif run:
                # A non-citation token ends the current bracket.
                emit_bracket(run, keys_in_sentence)
                run = []
        if run:
            # The sentence ended while a bracket was still open.
            emit_bracket(run, keys_in_sentence)

        if keys_in_sentence:
            members = ['citation_bracket["{}"]'.format(k) for k in keys_in_sentence]
            sentence_lines.append('sentence["{}"] = {}'.format(get_key(keys_in_sentence), str(members)))

    # str(list) wraps each element in single quotes; dropping them leaves the
    # double-quoted tags / bare citation_bracket[...] references.
    # NOTE(review): the indent strings below are single spaces as found in this
    # copy of the file - confirm they were not collapsed from a wider indent.
    bracket_part = "\n ".join(bracket_lines).replace("'", "")
    sentence_part = "\n ".join(sentence_lines).replace("'", "")
    return " " + bracket_part + "\n\n " + sentence_part
def get_prompt(list_, i, prompt_template):
    """Build the code-style prompt and reference clustering for sample ``i``.

    Reads ``'related_work'``, ``'abstract'`` and ``'ref_abstract'`` from
    ``list_[i]``; each ``ref_abstract`` value is expected to expose an
    ``'abstract'`` entry. One ``reference_articles["<tag>"] = "<abstract>"``
    line is produced per reference, followed by ``intent = "<abstract>"``,
    and the result is substituted into ``prompt_template``.

    Returns a tuple ``(gt_summary, prompt, code_str)`` where ``gt_summary`` is
    the stripped ``'related_work'`` text and ``code_str`` is
    ``get_code_str(gt_summary, ref_abstract)``.
    """
    gt_summary = list_[i]['related_work'].strip()
    inp_intent = list_[i]['abstract'].strip()
    references = list_[i]['ref_abstract']

    stripped_abstracts = {tag: references[tag]['abstract'].strip() for tag in references}
    assignments = [
        'reference_articles["{}"] = "{}"'.format(tag, body)
        for tag, body in stripped_abstracts.items()
    ]
    assignments.append('intent = "{}"'.format(inp_intent))
    # NOTE(review): single-space indent kept as found in this copy of the file.
    input_code_str = " " + "\n ".join(assignments)

    code_str = get_code_str(gt_summary, references)
    prompt = prompt_template.format(input_code_str)
    return gt_summary, prompt, code_str
def preprocess_retrieved_out(tmp_keys, out):
    """Parse per-citation summaries out of a model completion.

    For each key in ``tmp_keys`` the first line of ``out`` that mentions the
    key is taken; the text after that line's first ``":"`` (stripped) becomes
    the summary. Keys that appear on no line are omitted from the result.

    The key must not be immediately followed by another word character, so
    ``@cite_1`` is not confused with ``@cite_12`` (a plain substring test
    would hand ``@cite_1`` the summary of ``@cite_12`` whenever that line
    comes first).

    Args:
        tmp_keys: iterable of citation tags (a dict works; its keys are used).
        out: raw completion text, one ``"<tag> : <summary>"`` entry per line.

    Returns:
        ``{tag: {"abstract": summary}}`` for every tag that was found.
    """
    lines = out.split("\n")  # split once instead of once per key
    new_dict = {}
    for key in tmp_keys:
        # (?!\w) rejects matches where the tag is only a prefix of a longer tag.
        key_pattern = re.compile(re.escape(key) + r"(?!\w)")
        for line in lines:
            if key_pattern.search(line):
                summ_doc = line.split(":", 1)[-1].strip()
                new_dict[key] = {"abstract": summ_doc}
                # Progress output kept from the original implementation.
                print(key)
                print(summ_doc)
                print()
                break
    return new_dict
def get_slide(topic, text):
    """Ask the ``gpt4o`` Azure model to restructure ``text`` as slide markdown.

    The prompt embeds ``topic`` and ``text``; the completion is requested via
    ``AzureModels("gpt4o").get_completion(prompt, 100)`` and returned as is.
    The function sleeps for two seconds after the call before returning.
    """
    template = '''Convert this text into more structured text (in markdown) that can be put into the content of a slide in a presentation (e.g. use bullet points, numbered points, proper layout, etc.). Also, the include the topic "{}" of the slide. -
{}'''
    client = AzureModels("gpt4o")
    filled_prompt = template.format(topic, text)
    # NOTE(review): 100 is presumably the completion token limit - confirm
    # against AzureModels.get_completion.
    completion = client.get_completion(filled_prompt, 100)
    time.sleep(2)
    return completion
def get_retrieved_results(MODEL_NAME, num_shots, method, train_list, test_list, code=False, organize_out=None):
    """Run one LLM stage of the pipeline over every item of ``test_list``.

    Only items whose ``'ref_abstract'`` has more than one entry are processed.
    The stage is selected by ``code`` and ``organize_out``:

    * ``code=True`` - citation clustering. A code-style prompt is built with
      ``get_prompt`` and the completion is stored per item.
      Returns ``{index: completion}``.
    * ``code=False`` and ``organize_out is None`` - reference summarisation.
      The completion is parsed with ``preprocess_retrieved_out`` and written
      back into ``test_list[i]["ref_abstract"]`` (``test_list`` is mutated).
      Returns ``test_list``.
    * ``code=False`` and ``organize_out`` given - section writing.
      ``organize_out[str(i)]`` (clustering pseudo-code) is converted into a
      hint with ``code_to_extra_info`` and added to the prompt.
      Returns ``{index: completion}``.

    Args:
        MODEL_NAME: a name containing ``'gpt4'`` selects ``AzureModels``;
            otherwise one of the open-model short names listed below.
        num_shots: number of in-context examples taken from ``train_list``.
        method: how examples are chosen - ``'fixed'``, ``'random'``,
            ``'bm25'`` or ``'gat'`` (the last two read a retrieval file).
        train_list, test_list: lists of samples with ``'abstract'``,
            ``'related_work'`` and ``'ref_abstract'`` entries.
        code: see above.
        organize_out: see above.

    Raises:
        ValueError: ``MODEL_NAME`` is neither a ``gpt4`` model nor a known
            open-model short name.

    Changes relative to the previous implementation:

    * ``method == "gat"`` is now honoured in the plain-text branch as well
      (before, only the code branch used the retrieved indices).
    * The section-writing template has four placeholders but was formatted
      with the five-argument list of the summarisation template, which put
      the citation-tag string where the hint belongs and dropped
      ``response_template``; arguments now line up with the placeholders.
    * Unknown model names fail with a clear ``ValueError``.

    NOTE(review): this copy of the file had lost its indentation and possibly
    whitespace inside string literals; block nesting (in particular that the
    "##Example N" header is only added when ``num_shots > 0``) and literal
    spacing should be confirmed against the canonical source.
    """
    response_template = ''
    instruction_template = ''
    final_dict = {}
    pred_dict = {}
    start_idx = 0
    icl_extra_info = ""
    test_extra_info = ""
    organizing = organize_out is not None
    use_azure = 'gpt4' in MODEL_NAME

    if use_azure:
        azure_models = AzureModels(MODEL_NAME)
    else:
        open_model_ids = {
            'gemma2b': "google/gemma-2b-it",
            'gemma7b': "google/gemma-7b-it",
            'mistral7b': "mistralai/Mistral-7B-Instruct-v0.3",
            'llama7b': "meta-llama/Llama-2-7b-chat-hf",
            'llama13b': "meta-llama/Llama-2-13b-chat-hf",
            'llama3': "meta-llama/Meta-Llama-3-8B-Instruct",
            'galactica7b': "facebook/galactica-6.7b",
        }
        if MODEL_NAME not in open_model_ids:
            raise ValueError(
                "Unknown MODEL_NAME {!r}; expected a name containing 'gpt4' or one of {}".format(
                    MODEL_NAME, sorted(open_model_ids)
                )
            )
        if code:
            instruction_template = '''Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
'''
            response_template = '### Response:\n'
        else:
            response_template = '### Assistant: '
        # NOTE(review): OpenModels is not imported in the visible part of the
        # file (its import from utils_open is commented out) - confirm it is
        # available before using a non-gpt4 model.
        open_models = OpenModels(open_model_ids[MODEL_NAME])

    # Default: summarise every reference article (5 placeholders).
    prompt_template = '''Given are a set of articles referenced in a Wikipedia Article, and the intent -
Reference Articles:
{}
Intent:
{}
Summarize each reference article (generate in the format "@cite_K : <SUMMARIZED CONTENT CORREPONDING TO @cite_K>", each in a new line, where @cite_K represents each of the following citation/reference tags - {}, given in Reference Articles), given the reference articles as documents, and the intent.{}
{}Answer: '''
    if organizing:
        # Section writing (4 placeholders: articles, intent, hint, response prefix).
        prompt_template = '''Given are a set of articles referenced in a Wikipedia Article, and the intent -
Reference Articles:
{}
Intent:
{}
Generate the wikipedia article section in 100-200 words based on the intent as an intent-based multi-document summary, given the reference articles as documents, and the intent.{}
{}Answer: '''
    if code:
        # Citation clustering (1 placeholder, filled by get_prompt).
        prompt_template = '''def main():
# Given is a dictionary of articles that are referenced in a section of the Wikipedia Article, and the intent -
reference_articles = dict()
{}'''

    retrieve_dict = None
    if method == 'bm25':
        retrieve_dict = read_from_file("bm25_10_icl_samples_50_holdout_samples.json")
    elif method == "gat":
        retrieve_dict = read_from_file("gat_20_icl_samples_50_holdout_samples.json")

    icl_train_indices = [0, 1]

    def complete(prompt_text, token_budget, **open_kwargs):
        # Single place that talks to whichever backend was set up above.
        if use_azure:
            result = azure_models.get_completion(prompt_text, token_budget)
            time.sleep(2)
            return result
        return open_models.open_completion(prompt_text, token_budget, **open_kwargs)

    def fill_prompt(articles, intent, tags, extra_info):
        # The section-writing template has no slot for the citation tags.
        if organizing:
            return prompt_template.format(articles, intent, extra_info, response_template)
        return prompt_template.format(articles, intent, tags, extra_info, response_template)

    if code:
        hier_cluster_prompt = "\n def hierarchical_clustering():\n # Hierarchical Clustering of references within a section of the Wikipedia Article, based on the reference articles and the intent\n citation_bracket = {} # This dictionary contains lists as values that shows how references are grouped within the same citation bracket in the section of the Wikipedia Article\n sentence = {} # This dictionary contains lists, where each list contains references in a sentence in the section of the Wikipedia Article\n\n"
        max_tokens = 500
        for i in tqdm(range(start_idx, len(test_list))):
            if len(test_list[i]['ref_abstract']) <= 1:
                continue
            full_icl_prompt = ""
            if num_shots > 0:
                icl_train_indices = _select_icl_indices(method, num_shots, i, retrieve_dict, icl_train_indices)
                for enum_idx, icl_train_idx in enumerate(icl_train_indices):
                    icl_gt_summary, icl_prompt, icl_code_str = get_prompt(train_list, icl_train_idx, prompt_template)
                    full_icl_prompt += (
                        "##Example {}:\n\n".format(enum_idx + 1)
                        + instruction_template + icl_prompt + hier_cluster_prompt + icl_code_str + "\n\n"
                    )
                full_icl_prompt += "##Example {}:\n\n".format(num_shots + 1)
            gt_summary, prompt, code_str = get_prompt(test_list, i, prompt_template)
            final_prompt = (
                full_icl_prompt + instruction_template + prompt + hier_cluster_prompt
                + " # only generate the code that comes after this, as if you are on autocomplete mode\n"
                + response_template
            )
            # Open models are stopped before they start inventing another example.
            final_dict[i] = complete(
                final_prompt, max_tokens, stop_token="##Example {}".format(num_shots + 2)
            )
        return final_dict

    tmp_max_tok_len = 300 if organizing else 1000
    for i in tqdm(range(start_idx, len(test_list))):
        if len(test_list[i]['ref_abstract']) <= 1:
            continue
        icl_prompt = ""
        if num_shots > 0:
            icl_train_indices = _select_icl_indices(method, num_shots, i, retrieve_dict, icl_train_indices)
            for enum_idx, icl_train_idx in enumerate(icl_train_indices):
                icl_tmp_list = train_list[icl_train_idx]['ref_abstract']
                icl_inp_intent = train_list[icl_train_idx]['abstract']
                icl_gt_summary = train_list[icl_train_idx]['related_work']
                if organizing:
                    icl_code_str = get_code_str(icl_gt_summary, icl_tmp_list)
                    icl_extra_info = code_to_extra_info(icl_code_str)
                icl_paper_abstracts = _format_references(icl_tmp_list, organizing)
                icl_prompt += (
                    "##Example {}:\n\n".format(enum_idx + 1)
                    + fill_prompt(icl_paper_abstracts, icl_inp_intent, " ".join(icl_tmp_list.keys()), icl_extra_info)
                    + icl_gt_summary.strip() + "\n\n"
                )
            icl_prompt += "##Example {}:\n\n".format(num_shots + 1)

        inp_intent = test_list[i]['abstract']
        if organizing:
            test_extra_info = code_to_extra_info(organize_out[str(i)])
        tmp_list = test_list[i]['ref_abstract']
        paper_abstracts = _format_references(tmp_list, organizing)
        prompt = icl_prompt + fill_prompt(
            paper_abstracts, inp_intent, " ".join(tmp_list.keys()), test_extra_info
        )

        out_ = complete(prompt, tmp_max_tok_len, temperature=0.7)
        if organizing:
            pred_dict[i] = out_
        else:
            # Replace the raw references with their parsed summaries, in place.
            test_list[i]["ref_abstract"] = preprocess_retrieved_out(tmp_list, out_)

    if organizing:
        return pred_dict
    return test_list


def _select_icl_indices(method, num_shots, test_idx, retrieve_dict, current_indices):
    """Choose which training samples serve as in-context examples for one test item."""
    if method == "random":
        # NOTE(review): holdout_indices is not defined in the visible part of
        # the file - confirm it exists at module level before using 'random'.
        return random.sample(holdout_indices, num_shots)
    if method in ("bm25", "gat"):
        return [int(retrieve_dict[str(test_idx)][j]) for j in range(num_shots)]
    if method == 'fixed':
        return current_indices[:num_shots]
    # Any other method keeps whatever indices are currently in effect.
    return current_indices


def _format_references(ref_abstracts, nested):
    """Render references as "<tag> : <text>" lines; ``nested`` means each value is a dict with an 'abstract' entry."""
    rendered = []
    for tag in ref_abstracts:
        body = ref_abstracts[tag]['abstract'] if nested else ref_abstracts[tag]
        rendered.append(tag + " : " + body)
    return "\n".join(rendered)