Spaces:
Running
Running
import sys | |
import os | |
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), ".."))) | |
os.chdir(os.path.dirname(os.path.abspath(__file__))) | |
import gradio as gr | |
import json | |
from tools.flights.apis import Flights | |
from tools.accommodations.apis import Accommodations | |
from tools.restaurants.apis import Restaurants | |
from tools.googleDistanceMatrix.apis import GoogleDistanceMatrix | |
from tools.attractions.apis import Attractions | |
import pandas as pd | |
from gradio.events import Changeable, Streamable, Submittable, on | |
from src.utils import * | |
from gradio.events import EventListenerMethod | |
import math | |
# ---------------------------------------------------------------------------
# Module-level state shared by the Gradio callbacks below.
# ---------------------------------------------------------------------------

# One parsed query per line of query.jsonl (parsed by src.utils.load_line_json_data).
data = load_line_json_data('query.jsonl')
# Dropdown labels "Query 1" .. "Query N"; submit_update_options relabels finished ones.
query_dropdown_text = [f"Query {index}" for index in range(1, len(data) + 1)]
# Read the stylesheet with an explicit encoding so decoding does not depend on the
# platform's locale default.
with open('assets/styles.css', 'r', encoding='utf-8') as css_file:
    css_content = css_file.read()

# Sandbox data sources used both for the search tabs and for cost lookups.
flight = Flights()
accommodations = Accommodations()
restaurants = Restaurants()
googleDistanceMatrix = GoogleDistanceMatrix()
attractions = Attractions()

# 1-based index into `data` of the query currently being annotated.
annotation_idx = 4
# Running plan cost; recomputed by auto_update_current_cost, checked against the budget on submit.
cost = 0
# Annotator sub-directory under ../data/annotation/ that annotation files are written into.
global_user_name = ""
# Last value picked in each search tab for the active day; reset when the day/query changes.
transportaion_val = ""
accommodation_val = ""
breakfast_val = ""
lunch_val = ""
dinner_val = ""
attraction_val = ""
def _transportation_cost(entry, people_number):
    """Return the party's cost for one transportation entry (0 if it is not priced).

    Flight entries are priced from the sandbox 'Price' column times the number of
    people. Self-driving / taxi entries use the integer after "cost: " multiplied by
    ceil(people / 5) for self-driving and ceil(people / 4) for taxis.
    """
    if 'Flight Number' in entry:
        flight_number = entry.split('Flight Number: ')[1].split(',')[0]
        prices = flight.data[flight.data['Flight Number'] == flight_number]['Price']
        return prices.values[0] * people_number
    if 'Self-driving' in entry:
        return int(entry.split('cost: ')[1]) * math.ceil(people_number * 1.0 / 5)
    if 'Taxi' in entry:
        return int(entry.split('cost: ')[1]) * math.ceil(people_number * 1.0 / 4)
    return 0


def _restaurant_cost(entry, people_number):
    """Return the party's cost for one meal entry: 'Average Cost' times the number of people."""
    name, city = get_valid_name_city(entry)
    table = restaurants.data
    match = table[(table['Name'] == name) & (table['City'] == city)]
    return match['Average Cost'].values[0] * people_number


def _accommodation_cost(entry, people_number):
    """Return one night's cost: 'price' times ceil(people / 'maximum occupancy')."""
    name, city = get_valid_name_city(entry)
    table = accommodations.data
    res = table[(table['NAME'] == name) & (table['city'] == city)]
    return res['price'].values[0] * math.ceil(people_number * 1.0 / res['maximum occupancy'].values[0])


def auto_update_current_cost(*components):
    """Recompute the running plan cost and return it as the "Current Cost" number.

    ``components`` are the day-textbox values, grouped by field with one value per
    day in each group of 7: current city, transportation, breakfast, attraction,
    lunch, dinner, accommodation. Entries rejected by ``judge_valid_info`` or equal
    to the "-" skip marker add nothing; current city and attractions are never
    priced. The total is also stored in the module-level ``cost``.
    """
    global cost
    cost = 0
    people_number = data[annotation_idx-1]['people_number']
    # (component group, pricing function), in the order the costs are summed per day.
    priced_groups = (
        (1, _transportation_cost),
        (2, _restaurant_cost),     # breakfast
        (4, _restaurant_cost),     # lunch
        (5, _restaurant_cost),     # dinner
        (6, _accommodation_cost),
    )
    for day in range(7):
        for group, price_of in priced_groups:
            entry = components[7 * group + day]
            if judge_valid_info(entry) and entry != '-':
                cost += price_of(entry, people_number)
    return gr.Number(label="Current Cost", value=cost, interactive=False, elem_classes="current_cost")
def auto_update_current_cost_number(fn, input_components, output_components):
    """Register ``fn`` to run whenever any of ``input_components`` changes or streams.

    A streaming component contributes its ``stream`` trigger, any other
    changeable component its ``change`` trigger; components that are neither
    contribute nothing. A single ``on`` listener is attached with
    ``input_components`` as inputs and ``output_components`` as outputs.
    """
    def _trigger_of(component):
        if isinstance(component, Streamable) and component.streaming:
            return component.stream
        if isinstance(component, Changeable):
            return component.change
        return None

    triggers: list[EventListenerMethod] = [
        trigger for trigger in map(_trigger_of, input_components) if trigger is not None
    ]
    on(triggers, fn, input_components, output_components)
def get_google_distance_matrix(mode, departure_city, destination_city):
    """Return the distance-matrix result between two cities as a one-row DataFrame.

    ``mode`` is the Driving tab's choice: "Self-driving" queries with 'driving',
    "Taxi" with 'taxi'. Any other value (including no selection) yields None.
    """
    for label, api_mode in (("Self-driving", "driving"), ("Taxi", "taxi")):
        if mode == label:
            info = googleDistanceMatrix.run_for_annotation(departure_city, destination_city, api_mode)
            return pd.DataFrame(info, index=[0])
    return None
def return_list_items(items):
    """Return the first seven elements of ``items`` as a tuple (IndexError if fewer)."""
    return tuple(items[position] for position in range(7))
def submit_update_options(query_candidate, *components):
    """Validate the current plan, save it as JSON, and reset the day textboxes.

    Args:
        query_candidate: The query dropdown value (not used here).
        *components: The day-textbox values, grouped by field with one value per
            day in each group of 7: current city, transportation, breakfast,
            attraction, lunch, dinner, accommodation.

    Returns:
        A list of Gradio updates: the query dropdown (this query relabelled
        "Query N --- Done"), one entry per day textbox (a cleared read-only
        Textbox for the trip's days, ``None`` for the rest), and 14 cleared
        Textboxes.

    Raises:
        gr.Error: if the number of distinct cuisines collected by
            ``judge_valid_cuisine`` differs from the query's cuisine constraint,
            an accommodation run is shorter than its 'minimum nights', or the
            current cost exceeds the budget. The src.utils validators
            (``judge_submit_info``, ``judge_valid_cuisine``) may raise too —
            presumably gr.Error; confirm in src.utils.
    """
    global transportaion_val, accommodation_val, breakfast_val, lunch_val, dinner_val, attraction_val
    query = data[annotation_idx-1]
    days = query['days']
    # Text tab_hint puts in textboxes past the trip length; never a real accommodation.
    not_needed_msg = "You don't need to fill in the information for this or later days."

    # Collect each day's validated fields into its JSON record. All 7 records are
    # written, so days past the trip length stay empty dicts.
    json_data = [{} for _ in range(7)]
    cuisine_set = set()
    day_fields = (
        # (component group, judge_submit_info field name, JSON key, is a meal)
        (0, 'current city', 'current_city', False),
        (1, 'transportation', 'transportation', False),
        (2, 'breakfast', 'breakfast', True),
        (3, 'attraction', 'attraction', False),
        (4, 'lunch', 'lunch', True),
        (5, 'dinner', 'dinner', True),
    )
    for i in range(days):
        record = json_data[i]
        record['days'] = i + 1
        for group, field, key, is_meal in day_fields:
            value = components[7 * group + i]
            if judge_submit_info(value, i + 1, field, query):
                if is_meal:
                    cuisine_set = judge_valid_cuisine(value, query, restaurants.data, cuisine_set)
                record[key] = value
        accommodation = components[7 * 6 + i]
        if judge_submit_info(accommodation, i + 1, 'accommodation', query, accommodations.data):
            record['accommodation'] = accommodation

    cuisine_constraint = query['local_constraint']['cuisine']
    if cuisine_constraint and len(cuisine_set) != len(cuisine_constraint):
        raise gr.Error("Please note the cuisine constraints.")

    # Each unit is presumably (accommodation value, consecutive nights) — see
    # count_consecutive_values in src.utils.
    for unit in count_consecutive_values(components[7 * 6:7 * 6 + days]):
        # Skip empty runs, skipped days ('-'), blanks and the placeholder text. The
        # placeholder has to be compared with unit[0]: the unit itself is a sequence
        # and never equals the string.
        if not unit or unit[0] in ('-', '', not_needed_msg):
            continue
        name, city = get_valid_name_city(unit[0])
        hotels = accommodations.data
        match = hotels[(hotels['NAME'] == name) & (hotels['city'] == city)]
        if unit[1] < match.iloc[0]['minimum nights']:
            raise gr.Error(f"Please note the minimum nights of {unit[0]}.")

    if cost > query['budget']:
        raise gr.Error("The current cost is higher than the budget, please modify the information.")

    # Clear the trip's day textboxes; transportation is reset to the "-" skip marker.
    new_components = [None] * len(components)
    reset_values = (
        ("Current City", None),
        ("Transportation", "-"),
        ("Breakfast", None),
        ("Attractions", None),
        ("Lunch", None),
        ("Dinner", None),
        ("Accommodation", None),
    )
    for i in range(days):
        for group, (label, value) in enumerate(reset_values):
            new_components[7 * group + i] = gr.Textbox(label=label, value=value, interactive=False)

    path = '../data/annotation/{}/annotation_{}.json'.format(global_user_name, annotation_idx)
    # Create the output directory on first save instead of failing with FileNotFoundError.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'w') as f:
        json.dump(json_data, f)

    transportaion_val = accommodation_val = breakfast_val = ""
    lunch_val = dinner_val = attraction_val = ""
    # Mark this query as done in the shared dropdown labels (mutated in place).
    query_dropdown_text[annotation_idx-1] = f"Query {annotation_idx} --- Done"
    return [gr.Dropdown(choices=query_dropdown_text)] + new_components + [gr.Textbox(value=None) for _ in range(14)]
def update_gloabal_idx(selected_option):
    """Switch to the query chosen in the dropdown and reset the annotation widgets.

    Sets the module-level ``annotation_idx`` from ``selected_option`` and clears the
    cached per-tab selections. Returns updates, in order, for: the query display,
    the hidden column (made visible), the 49 day textboxes, the 63 city radios,
    the 7 date boxes, and the current-cost number.
    """
    global annotation_idx, cost
    global transportaion_val, accommodation_val, breakfast_val, lunch_val, dinner_val, attraction_val
    transportaion_val = accommodation_val = breakfast_val = ""
    lunch_val = dinner_val = attraction_val = ""

    annotation_idx = extract_query_number(selected_option)
    query = data[annotation_idx-1]
    trip_days = query['days']

    # Day textboxes: only the trip's days are cleared; the rest get a no-op update.
    day_boxes = [gr.Textbox() for _ in range(49)]
    box_labels = ("Current City", "Transportation", "Breakfast", "Attractions", "Lunch", "Dinner", "Accommodation")
    for day in range(trip_days):
        for group, label in enumerate(box_labels):
            day_boxes[7 * group + day] = gr.Textbox(label=label, value="", interactive=False)

    # 9 city radios per day, in the order the UI registers them: City Selection,
    # Flight and Driving (departure + destination each), then Accommodation,
    # Restaurant and Attraction (one "City" radio each).
    city_list = get_city_list(trip_days, query['org'], query['dest'])
    slot_labels = ("Departure City", "Destination City") * 3 + ("City",) * 3
    city_drop_down = [gr.Radio(label=slot_labels[slot % 9], choices=city_list) for slot in range(63)]

    dates = [gr.Textbox(value="", interactive=False) for _ in range(7)]
    for day in range(trip_days):
        dates[day] = gr.Textbox(label="Date", value=query['date'][day], interactive=False)

    header = [create_data_display(css_content, data, annotation_idx), gr.update(visible=True)]
    return header + day_boxes + city_drop_down + dates + [gr.Number(value=cost, interactive=False)]
def Recommend_new(x):
    """Return ``x`` with the ": Highest Cosine Similarity" tag appended."""
    suffix = ": Highest Cosine Similarity"
    return x + suffix
def current_city_selection(departure_city, dest_city):
    """Build the day's "Current City" textbox from the City Selection radios.

    Returns the city itself when departure and destination are the same, otherwise
    "from <departure> to <destination>".

    Raises:
        gr.Error: if either radio has no selection. Without this check a missing
            selection surfaces as an opaque TypeError from concatenating None.
    """
    if not departure_city or not dest_city:
        raise gr.Error("Please select both the departure city and the destination city.")
    if departure_city == dest_city:
        final_city = departure_city
    else:
        final_city = f"from {departure_city} to {dest_city}"
    return gr.Textbox(label="Current City", value=final_city, interactive=False)
def tab_hint(current_city, transportaion, breakfast, attraction, lunch, dinner, accommodation, evt: gr.SelectData):
    """Sync the cached tab selections when a Day tab is selected.

    Transportation/accommodation caches are reset; the meal and attraction caches
    take the selected day's current textbox values so later picks extend that day.
    Days at or beyond the query's trip length get a "not needed" placeholder in
    all seven fields.

    Returns:
        The 7 day-field values, a refreshed query dropdown, and one cleared
        textbox per filtering-keyword output. The event wires
        ``filtering_keyword_list[-2:]`` as its trailing outputs, so the same count
        is returned (currently zero, since no keyword boxes are registered).
    """
    global transportaion_val, accommodation_val, breakfast_val, lunch_val, dinner_val, attraction_val
    transportaion_val = ""
    accommodation_val = ""
    breakfast_val = breakfast
    lunch_val = lunch
    dinner_val = dinner
    attraction_val = attraction
    # Match the wired outputs exactly instead of always returning two extra values.
    keyword_resets = [gr.Textbox(value="") for _ in filtering_keyword_list[-2:]]
    dropdown = gr.Dropdown(choices=query_dropdown_text)
    if data[annotation_idx-1]['days'] <= evt.index:
        placeholder = gr.Textbox(value="You don't need to fill in the information for this or later days.", interactive=False)
        return [placeholder] * 7 + [dropdown] + keyword_resets
    return [current_city, transportaion, breakfast, attraction, lunch, dinner, accommodation, dropdown] + keyword_resets
# Extra page CSS: background colours for the elements with ids submit_button / clear_button.
css = "\n".join([
    "",
    "#submit_button {background: #00a1b8;}",
    "#clear_button {background: #ff8069;}",
    "",
])
with gr.Blocks(css=css) as demo:
    gr.HTML("""
        <div style="text-align: center;">
            <h2>TravelBench Environment</h2>
        </div>
    """)
    gr.HTML("<hr>")
    query_selection = gr.Dropdown(query_dropdown_text, label="Query")
    # Clicking "Search" clears every component registered via annotation_button.add(...)
    # and also runs update_gloabal_idx (wired at the bottom of this block).
    annotation_button = gr.ClearButton(value="Search")
    gr.HTML("<hr>")

    # Per-day components collected while the 7 day tabs are built.
    transportaion_list = []
    breakfast_list = []
    attraction_list = []
    lunch_list = []
    dinner_list = []
    accommodation_list = []
    current_city_list = []
    city_component_list = []      # 9 city radios per day, in tab order
    date_component_list = []      # the Flight tab's date box, one per day
    filtering_keyword_list = []   # stays empty: keyword filtering is disabled

    with gr.Column(visible=False) as component:
        query_display = gr.HTML(create_data_display(css_content, data, annotation_idx))
        current_cost = gr.Number(label="Current Cost", value=0, interactive=False, elem_classes="current_cost")
        with gr.Row():
            tabs = gr.Tabs()
            with tabs:
                for i in range(7):
                    with gr.Tab(label=f"Day {i+1}", id=(i+1)) as tab:
                        with gr.Row():
                            # Left column: the day's plan, filled in from the search tabs.
                            with gr.Column():
                                current_city = gr.Textbox(label="Current City", interactive=False, value="")
                                transportaion = gr.Textbox(label="Transportation", interactive=False, value="")
                                breakfast = gr.Textbox(label="Breakfast", interactive=False, value="")
                                attraction = gr.Textbox(label="Attractions", interactive=False, value="")
                                lunch = gr.Textbox(label="Lunch", interactive=False, value="")
                                dinner = gr.Textbox(label="Dinner", interactive=False, value="")
                                accommodation = gr.Textbox(label="Accommodation", interactive=False, value="")
                                current_city_list.append(current_city)
                                transportaion_list.append(transportaion)
                                breakfast_list.append(breakfast)
                                attraction_list.append(attraction)
                                lunch_list.append(lunch)
                                dinner_list.append(dinner)
                                accommodation_list.append(accommodation)
                            # Right column: one search tab per information source.
                            with gr.Column():
                                with gr.Tab(label='City Selection'):
                                    departure_city = gr.Radio(label="Departure City")
                                    destination_city = gr.Radio(label="Destination City")
                                    city_component_list += [departure_city, destination_city]
                                    btn = gr.Button("Done")
                                    annotation_button.add(components=[departure_city, destination_city])
                                    btn.click(current_city_selection, inputs=[departure_city, destination_city], outputs=current_city)

                                with gr.Tab(label='Flight'):
                                    transportaion_val = ""

                                    def on_row_selected(dep_city, dest_city, evt: gr.SelectData):
                                        """Describe the clicked flight when a cell in column 0 (flight number) is clicked."""
                                        global transportaion_val
                                        if evt.index[1] == 0:
                                            # Look the flight row up once for both times.
                                            row = flight.data[flight.data['Flight Number'] == evt.value].iloc[0]
                                            transportaion_val = "Flight Number: " + str(evt.value) + f", from {dep_city} to {dest_city}, Departure Time: {row['DepTime']}, Arrival Time: {row['ArrTime']}"
                                        return transportaion_val

                                    departure_city = gr.Radio(label="Departure City")
                                    destination_city = gr.Radio(label="Destination City")
                                    city_component_list += [departure_city, destination_city]
                                    date = gr.Textbox(label="Date", interactive=False)
                                    date_component_list.append(date)
                                    btn = gr.Button("Search")
                                    clean_btn = gr.Button("Skip")
                                    output = gr.Dataframe(flight.run_for_annotation('', '', ''), interactive=False)
                                    annotation_button.add(components=[output])
                                    btn.click(flight.run_for_annotation, inputs=[departure_city, destination_city, date], outputs=output)
                                    output.select(on_row_selected, inputs=[departure_city, destination_city], outputs=[transportaion])
                                    clean_btn.click(lambda: "-", inputs=None, outputs=[transportaion])

                                with gr.Tab(label='Driving'):
                                    def on_row_selected(org, dest, mode, frame, evt: gr.SelectData):
                                        """Describe the self-driving/taxi route shown in the result row."""
                                        global transportaion_val
                                        data_dict = frame.to_dict(orient='split')['data'][0]
                                        # NOTE(review): positional columns — [0] is reported as cost,
                                        # [1] distance, [2] duration. This assumes the column order of
                                        # googleDistanceMatrix.run_for_annotation's result (the placeholder
                                        # headers below are listed the other way round); confirm.
                                        transportaion_val = f"{mode}, from {org} to {dest}, duration: {data_dict[2]}, distance: {data_dict[1]}, cost: {data_dict[0]}"
                                        return transportaion_val

                                    transportation_mode = gr.Radio(["Self-driving", "Taxi"], label="Mode")
                                    departure_city = gr.Radio(label="Departure City")
                                    destination_city = gr.Radio(label="Destination City")
                                    city_component_list += [departure_city, destination_city]
                                    btn = gr.Button("Search")
                                    clean_btn = gr.Button("Skip")
                                    output = gr.Dataframe(
                                        headers=["duration", "distance", "cost"],
                                        datatype=["str", "str", "str"],
                                        value=["", "", ""],
                                        row_count=1,
                                        col_count=(3, "fixed"),
                                        interactive=False,
                                    )
                                    annotation_button.add(components=[output])
                                    btn.click(get_google_distance_matrix, inputs=[transportation_mode, departure_city, destination_city], outputs=output)
                                    clean_btn.click(lambda: "-", inputs=None, outputs=[transportaion])
                                    output.select(on_row_selected, inputs=[departure_city, destination_city, transportation_mode, output], outputs=transportaion)

                                with gr.Tab(label='Accommodation'):
                                    accommodation_val = ""

                                    def on_row_selected(city, evt: gr.SelectData):
                                        """Store "<name>, <city>" when a cell in column 0 is clicked."""
                                        global accommodation_val
                                        if evt.index[1] == 0:
                                            accommodation_val = str(evt.value) + f", {city}"
                                        return accommodation_val

                                    def search(city, filtering_keyword=None):
                                        # Keyword filtering on house rules is disabled; the argument is ignored.
                                        return accommodations.run_for_annotation(city)

                                    city = gr.Radio(label="City")
                                    city_component_list += [city]
                                    btn = gr.Button("Search")
                                    skip_btn = gr.Button("Skip")
                                    output = gr.Dataframe(accommodations.run_for_annotation(''), interactive=False)
                                    annotation_button.add(components=[output])
                                    btn.click(search, inputs=[city], outputs=output)
                                    output.select(on_row_selected, inputs=[city], outputs=accommodation)
                                    skip_btn.click(lambda: "-", inputs=None, outputs=accommodation)

                                with gr.Tab(label='Restaurant'):
                                    breakfast_val = ""
                                    lunch_val = ""
                                    dinner_val = ""

                                    def on_row_selected(city, meal_type, evt: gr.SelectData):
                                        """Store "<name>, <city>" for the chosen meal type; return all three meals."""
                                        global breakfast_val, lunch_val, dinner_val
                                        if evt.index[1] == 0:
                                            if meal_type == "Breakfast":
                                                breakfast_val = str(evt.value) + f", {city}"
                                            elif meal_type == "Lunch":
                                                lunch_val = str(evt.value) + f", {city}"
                                            elif meal_type == "Dinner":
                                                dinner_val = str(evt.value) + f", {city}"
                                        return breakfast_val, lunch_val, dinner_val

                                    def skip(meal_type):
                                        """Mark the chosen meal type as skipped ("-"); return all three meals."""
                                        global breakfast_val, lunch_val, dinner_val
                                        if meal_type == "Breakfast":
                                            breakfast_val = "-"
                                        elif meal_type == "Lunch":
                                            lunch_val = "-"
                                        elif meal_type == "Dinner":
                                            dinner_val = "-"
                                        return breakfast_val, lunch_val, dinner_val

                                    def search(city, filtering_keyword=None):
                                        # Cuisine keyword filtering is disabled; the argument is ignored.
                                        return restaurants.run_for_annotation(city)

                                    city = gr.Radio(label="City")
                                    city_component_list += [city]
                                    btn = gr.Button("Search")
                                    skip_btn = gr.Button("Skip")
                                    choice = gr.Radio(["Breakfast", "Lunch", "Dinner"], label="Meal Type")
                                    annotation_button.add(components=[choice])
                                    output = gr.Dataframe(restaurants.run_for_annotation(''), interactive=False)
                                    annotation_button.add(components=[output])
                                    btn.click(search, inputs=[city], outputs=output)
                                    # The meal type decides which of breakfast/lunch/dinner changes.
                                    output.select(on_row_selected, inputs=[city, choice], outputs=[breakfast, lunch, dinner])
                                    skip_btn.click(skip, inputs=[choice], outputs=[breakfast, lunch, dinner])

                                with gr.Tab(label='Attraction'):
                                    attraction_val = ""

                                    def on_row_selected(city, evt: gr.SelectData):
                                        """Append "<attraction>, <city>;" when a cell in column 0 is clicked."""
                                        global attraction_val
                                        # Drop a bare "-" skip marker before appending. Only the marker
                                        # itself is removed so hyphens inside chosen names survive.
                                        if attraction_val == "-":
                                            attraction_val = ""
                                        if evt.index[1] == 0:
                                            attraction_val += evt.value + ', ' + city + ";"
                                        return attraction_val

                                    def clear():
                                        """Forget all attractions picked for the day."""
                                        global attraction_val
                                        attraction_val = ""
                                        return attraction_val

                                    city = gr.Radio(label="City")
                                    city_component_list += [city]
                                    btn = gr.Button("Search")
                                    btn_clear = gr.Button("Clear")
                                    btn_skip = gr.Button("Skip")
                                    output = gr.Dataframe(attractions.run_for_annotation(''), interactive=False)
                                    annotation_button.add(components=[output])
                                    btn.click(attractions.run_for_annotation, inputs=[city], outputs=output)
                                    output.select(on_row_selected, inputs=[city], outputs=attraction)
                                    btn_clear.click(clear, inputs=None, outputs=attraction)
                                    btn_skip.click(lambda: "-", inputs=None, outputs=attraction)

                    tab.select(
                        tab_hint,
                        inputs=[current_city, transportaion, breakfast, attraction, lunch, dinner, accommodation],
                        outputs=[current_city, transportaion, breakfast, attraction, lunch, dinner, accommodation, query_selection] + filtering_keyword_list[-2:],
                    )

    # Field-major order (all current cities, then all transportation, ...), 7 per field;
    # auto_update_current_cost and submit_update_options index into this layout.
    component_list = current_city_list + transportaion_list + breakfast_list + attraction_list + lunch_list + dinner_list + accommodation_list
    input_args = [query_selection] + component_list
    # Recompute the running cost whenever any day textbox changes.
    auto_update_current_cost_number(auto_update_current_cost, component_list, [current_cost])
    annotation_button.add(components=city_component_list)
    annotation_button.click(
        update_gloabal_idx,
        inputs=[query_selection],
        outputs=[query_display, component] + component_list + city_component_list + date_component_list + [current_cost],
    )

if __name__ == "__main__":
    demo.launch(debug=True)