# TravelPlannerLeaderboard / tools/planner/planner_with_human_annotated_info.py
# Uploaded by hsaest using huggingface_hub — commit 9be4956 (verified).
# Hugging Face page residue: raw | history | blame | No virus | 5.93 kB
import os
import re
import sys
# Extend the import path, presumably so the project packages imported below
# (`agents`, `tools`) resolve: first the parent of the directory the script was
# *launched* from (os.getcwd() is evaluated before the os.chdir further down),
# then a hard-coded checkout location.
# NOTE(review): '/home/xj/...' is a developer-machine-specific path.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
sys.path.append('/home/xj/toolAugEnv/code/toolConstraint')
# print(sys.path)
# Import-time side effect: make this file's own directory the working
# directory, so any later relative paths resolve against this file's location.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
from agents.prompts import planner_agent_prompt, cot_planner_agent_prompt, react_planner_agent_prompt,react_reflect_planner_agent_prompt,reflect_prompt
# from annotation.src.utils import get_valid_name_city,extract_before_parenthesis, extract_numbers_from_filenames
import json
import time
from langchain.callbacks import get_openai_callback
from tqdm import tqdm
from tools.planner.apis import Planner, ReactPlanner, ReactReflectPlanner
import openai
# Route HTTP(S) traffic through a local proxy on 127.0.0.1:7890. These are set
# unconditionally at import time, overwriting any values already present in
# the environment, and apply to any HTTP client in this process that honors
# the standard proxy variables. NOTE(review): developer-specific setting.
os.environ["http_proxy"] = "http://127.0.0.1:7890"
os.environ["https_proxy"] = "http://127.0.0.1:7890"
def load_line_json_data(filename):
    """Load a JSON Lines file into a list of decoded records.

    Each non-blank line of *filename* is decoded with ``json.loads``. Blank or
    whitespace-only lines are skipped, including a trailing newline or an
    entirely empty file, instead of raising ``json.JSONDecodeError``.

    Args:
        filename: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        A list with one decoded object per non-blank line, in file order.
        An empty file yields ``[]``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        # Stream the file line by line. Skipping blank lines avoids handing an
        # empty string to json.loads, which would raise.
        return [json.loads(line) for line in f if line.strip()]
def extract_numbers_from_filenames(directory):
    """Return the integer ids of ``annotation_<N>.json`` files in *directory*.

    A filename is counted only if it matches the whole pattern. The dot
    before ``json`` is literal, so entries such as ``annotation_3.json.bak``
    or ``annotation_3xjson`` are ignored.

    Args:
        directory: Path of the directory to scan (not recursive).

    Returns:
        A list of ints in ``os.listdir`` order, which is arbitrary. Callers
        that need a stable order should sort the result.
    """
    # Compile once; the escaped dot plus fullmatch anchor both ends of the name.
    pattern = re.compile(r'annotation_(\d+)\.json')
    matches = map(pattern.fullmatch, os.listdir(directory))
    return [int(m.group(1)) for m in matches if m]
def catch_openai_api_error():
    """Report the OpenAI exception currently being handled.

    Intended to be called from inside an ``except`` block, because it reads
    ``sys.exc_info()``. Exceptions from the ``openai.error`` module print a
    short tag. A ``RateLimitError`` also sleeps 60 seconds before returning,
    giving the caller a back-off before retrying. Anything else prints
    ``API error: <exception type>``; with no active exception that prints
    ``API error: None``.

    Types are matched with ``isinstance``, so subclasses of the listed
    exceptions are recognised too.

    NOTE(review): ``openai.error`` exists only in the pre-1.0 openai SDK;
    confirm the pinned openai version.
    """
    exc_type, exc_value, _ = sys.exc_info()
    # isinstance(None, ...) is False, so "no active exception" falls through
    # to the generic branch just as before.
    if isinstance(exc_value, openai.error.APIConnectionError):
        print("APIConnectionError")
    elif isinstance(exc_value, openai.error.RateLimitError):
        print("RateLimitError")
        # Back off before the caller retries a rate-limited request.
        time.sleep(60)
    elif isinstance(exc_value, openai.error.APIError):
        print("APIError")
    elif isinstance(exc_value, openai.error.AuthenticationError):
        print("AuthenticationError")
    else:
        print("API error:", exc_type)
# if __name__ == "__main__":
# user_name = 'zk'
# directory = '../../data/annotation/{}'.format(user_name)
# query_data_list = load_line_json_data('../../data/query/{}.jsonl'.format(user_name))
# numbers = extract_numbers_from_filenames(directory)
# with get_openai_callback() as cb:
# for number in tqdm(numbers[:10]):
# print(number)
# json_data = json.load(open(os.path.join(directory, 'annotation_{}.json'.format(number))))
# human_collected_info_data = json.load(open(os.path.join(directory, 'human_collected_info_{}.json'.format(number))))
# query_data = query_data_list[number-1]
# planner_results = planner.run(human_collected_info_data, query_data['query'])
# org_result = json.load(open(os.path.join('../../results/turbo16k-turbo16k/{}/plan_{}.json'.format(user_name,number))))
# # org_result.append({})
# org_result[-1]['chatgpt_human_collected_info_results'] = planner_results
# # write to json file
# # with open(os.path.join('../../results/turbo16k-turbo16k/{}/plan_{}.json'.format(user_name,number)), 'w') as f:
# # json.dump(org_result, f, indent=4)
# print(cb)
if __name__ == "__main__":
    # Experiment configuration: choose one option from each list by index.
    model_name = ['gpt-3.5-turbo-1106', 'gpt-4-1106-preview', 'gemini', 'mixtral'][1]
    set_type = ['dev', 'test'][0]
    method = ['direct', 'cot', 'react', 'reflexion'][0]
    directory = f'/home/xj/toolAugEnv/code/toolConstraint/data/final_data/{set_type}'
    results_dir = f'/home/xj/toolAugEnv/code/toolConstraint/results/{set_type}'

    query_data_list = load_line_json_data(os.path.join(directory, 'query/query.jsonl'))
    # Query ids are 1-based; id N corresponds to line N of query.jsonl.
    numbers = list(range(1, len(query_data_list) + 1))

    # ReAct-style planners return (plan, scratchpad); the others return the plan only.
    uses_scratchpad = method in ('react', 'reflexion')
    if method == 'direct':
        planner = Planner(model_name=model_name, agent_prompt=planner_agent_prompt)
    elif method == 'cot':
        planner = Planner(model_name=model_name, agent_prompt=cot_planner_agent_prompt)
    elif method == 'react':
        planner = ReactPlanner(model_name=model_name, agent_prompt=react_planner_agent_prompt)
    elif method == 'reflexion':
        planner = ReactReflectPlanner(model_name=model_name, agent_prompt=react_reflect_planner_agent_prompt, reflect_prompt=reflect_prompt)
    else:
        # Fail fast rather than with a NameError on `planner` inside the loop.
        raise ValueError(f'Unknown planning method: {method}')

    # Create the output directory once instead of re-checking it per query.
    os.makedirs(results_dir, exist_ok=True)

    with get_openai_callback() as cb:
        for number in tqdm(numbers):
            info_path = os.path.join(directory, f'plan/human_collected_info_{number}.json')
            with open(info_path) as f:
                human_collected_info_data = json.load(f)
            query_data = query_data_list[number - 1]

            # Re-run the planner until it yields a non-None plan.
            # NOTE(review): retries without a limit or back-off, so a query that
            # always returns None will stall the run here.
            while True:
                if uses_scratchpad:
                    planner_results, scratchpad = planner.run(human_collected_info_data, query_data['query'])
                else:
                    planner_results = planner.run(human_collected_info_data, query_data['query'])
                if planner_results is not None:
                    break
            print(planner_results)

            # Merge into an existing plan file when present. Keys embed the model
            # name and method, so results from other runs are preserved.
            plan_path = os.path.join(results_dir, f'plan_{number}.json')
            if os.path.exists(plan_path):
                with open(plan_path) as f:
                    result = json.load(f)
            else:
                result = [{}]
            if uses_scratchpad:
                result[-1][f'{model_name}_{method}_collected_info_results_logs'] = scratchpad
            result[-1][f'{model_name}_{method}_collected_info_results'] = planner_results
            with open(plan_path, 'w') as f:
                json.dump(result, f, indent=4)
        print(cb)