Spaces:
Sleeping
Sleeping
Yannael_LB
Merge branch 'main' of https://huggingface.co/spaces/yannael/openai-assistant-wanderlust into main
95b65d8
# -*- coding: utf-8 -*-
# OpenAI assistant Wanderlust
# Basic recreation of OpenAI's DevDay Wanderlust demo app. It relies on Gradio and the new Assistants API.
# This space is inspired by the implementation of Fanilo Andrianasolo using Streamlit - https://www.youtube.com/watch?v=tLeqCDKgEDU
# Dependencies (e.g. in a notebook):
#!pip install -q -U gradio openai plotly
import gradio as gr | |
import random | |
import openai | |
import os | |
import json | |
import plotly.graph_objects as go | |
import time | |
# The OpenAI client reads the API key from the OPENAI_API_KEY environment
# variable; set it before launching, e.g.:
# os.environ["OPENAI_API_KEY"] = "sk-..."  # Replace with your key

# ID of the pre-configured assistant this app talks to. It is expected to
# expose the `update_map` and `add_markers` function tools handled below.
# Set OPENAI_ASSISTANT_ID to point the app at your own assistant instead.
assistant_id = os.environ.get("OPENAI_ASSISTANT_ID", "asst_7OC3NTeyCjEZrApdLRklplE7")

client = openai.OpenAI()

# Optional: fetch the full assistant object. This must come after `client`
# is created, otherwise it raises NameError.
# assistant = client.beta.assistants.retrieve(assistant_id)
####################################### | |
# TOOLS SETUP | |
####################################### | |
def update_map_state(map_metadata_state, latitude, longitude, zoom):
    """Handle the assistant's `update_map` tool call.

    Writes the new viewport (centre coordinates and zoom level) into the
    session's map-state dict in place, prints the resulting state, and
    returns the confirmation string that is sent back as the tool output.
    """
    map_metadata_state.update(latitude=latitude, longitude=longitude, zoom=zoom)
    print(map_metadata_state)
    return "Map updated"
def add_markers_state(map_metadata_state, latitudes, longitudes, labels):
    """Handle the assistant's `add_markers` tool call.

    Replaces the marker coordinate and label lists stored in the session's
    map-state dict (under "lat", "lon" and "text"), prints the resulting
    state, and returns the confirmation string sent back as the tool output.
    """
    new_markers = (("lat", latitudes), ("lon", longitudes), ("text", labels))
    for state_key, values in new_markers:
        map_metadata_state[state_key] = values
    print(map_metadata_state)
    return "Markers added"
# Dispatch table from the tool name the assistant calls to the local handler.
# Each handler receives the session's map-state dict first, followed by the
# tool call's JSON arguments as keyword arguments.
tool_to_function = dict(
    update_map=update_map_state,
    add_markers=add_markers_state,
)
## Helpers | |
def submit_message(assistant_id, thread_id, user_message):
    """Append `user_message` to the thread and start an assistant run on it.

    Returns the newly created run object so the caller can poll its status.
    """
    threads_api = client.beta.threads
    threads_api.messages.create(thread_id=thread_id, role="user", content=user_message)
    new_run = threads_api.runs.create(thread_id=thread_id, assistant_id=assistant_id)
    return new_run
def get_run_info(run_id, thread_id):
    """Fetch the latest state of run `run_id` belonging to thread `thread_id`."""
    return client.beta.threads.runs.retrieve(run_id=run_id, thread_id=thread_id)
####################################### | |
# INITIAL DATA FOR MAP | |
####################################### | |
# Initial map state for a new session: viewport centred on Paris, no markers.
# The `update_map` / `add_markers` tool handlers overwrite these keys.
map_metadata = dict(
    latitude=48.85,
    longitude=2.35,
    zoom=12,
    lat=[],
    lon=[],
    text=[],
)

# Empty OpenStreetMap figure displayed before the assistant places any markers.
fig = go.Figure(go.Scattermapbox())
fig.update_layout(
    mapbox_style="open-street-map",
    hovermode="closest",
    mapbox={
        "center": go.layout.mapbox.Center(
            lat=map_metadata["latitude"],
            lon=map_metadata["longitude"],
        ),
        "zoom": map_metadata["zoom"],
    },
)
def _call_tool(tool_call, map_metadata_state):
    """Run one assistant tool call locally and build its tool-output entry.

    Failures (unknown tool name, malformed JSON arguments, wrong argument
    names) are reported back to the assistant as the tool output instead of
    raising. Raising here would leave the run waiting in `requires_action`,
    and an active run prevents new messages from being added to the thread.
    """
    function = tool_call.function
    handler = tool_to_function.get(function.name)
    if handler is None:
        output = f"Unknown tool: {function.name}"
    else:
        try:
            kwargs = json.loads(function.arguments)
            output = handler(map_metadata_state, **kwargs)
        except (json.JSONDecodeError, TypeError) as err:
            output = f"Error while running {function.name}: {err}"
    return {"tool_call_id": tool_call.id, "output": output}


def _wait_for_run(run, thread_id, map_metadata_state, poll_interval=0.1):
    """Poll `run` until it completes, executing tool calls when requested.

    Returns the completed run. Raises gr.Error if the run reaches a
    non-successful terminal status; without this check the loop would never
    exit.
    """
    failed_statuses = {"failed", "cancelled", "expired", "incomplete"}
    while True:
        run = get_run_info(run.id, thread_id)
        if run.status == "completed":
            return run
        if run.status in failed_statuses:
            error = getattr(run, "last_error", None)
            detail = f": {error.message}" if error is not None else ""
            raise gr.Error(f"Assistant run {run.status}{detail}")
        if run.status == "requires_action":
            tool_calls = run.required_action.submit_tool_outputs.tool_calls
            client.beta.threads.runs.submit_tool_outputs(
                thread_id=thread_id,
                run_id=run.id,
                tool_outputs=[_call_tool(call, map_metadata_state) for call in tool_calls],
            )
        time.sleep(poll_interval)


def _message_text(message):
    """Join the text parts of a thread message; non-text parts are skipped."""
    return "\n\n".join(
        part.text.value
        for part in message.content
        if getattr(part, "type", None) == "text"
    )


def _format_dialog(messages):
    """Group thread messages into Gradio Chatbot `[user, assistant]` pairs.

    Consecutive assistant messages are merged into one reply, so the pairs
    stay aligned even when the assistant posts several messages for a single
    user turn.
    """
    pairs = []
    for message in messages:
        text = _message_text(message)
        if message.role == "user":
            pairs.append([text, None])
        elif not pairs:
            pairs.append([None, text])
        elif pairs[-1][1] is None:
            pairs[-1][1] = text
        else:
            pairs[-1][1] += "\n\n" + text
    return pairs


def _build_map_figure(map_metadata_state):
    """Render the session's map state as a Plotly OpenStreetMap figure."""
    if map_metadata_state["lat"]:
        trace = go.Scattermapbox(
            customdata=map_metadata_state["text"],
            lat=map_metadata_state["lat"],
            lon=map_metadata_state["lon"],
            mode='markers',
            marker=go.scattermapbox.Marker(size=18),
            hoverinfo="text",
            hovertemplate='<b>Name</b>: %{customdata}',
        )
    else:
        trace = go.Scattermapbox()
    fig = go.Figure(trace)
    fig.update_layout(
        mapbox_style="open-street-map",
        hovermode='closest',
        mapbox=dict(
            center=go.layout.mapbox.Center(
                lat=map_metadata_state["latitude"],
                lon=map_metadata_state["longitude"],
            ),
            # Use the zoom level set by the `update_map` tool.
            zoom=map_metadata_state["zoom"],
        ),
    )
    return fig


def respond(message, chat_history, thread, map_metadata_state):
    """Gradio callback: send `message` to the assistant, then refresh chat and map.

    Args:
        message: Text typed by the user.
        chat_history: Current Chatbot value; replaced by the full thread history.
        thread: Assistants API thread for this session, or None before the
            first message (a new thread is created then).
        map_metadata_state: Per-session map-state dict, mutated in place by
            the assistant's tool calls.

    Returns:
        `("", chat_history, figure, thread, map_metadata_state)`, matching the
        outputs wired in the UI. It clears the textbox, updates the chat,
        redraws the map and persists the session state.

    Raises:
        gr.Error: if the assistant run fails, is cancelled, expires or is
            left incomplete.
    """
    if not message or not message.strip():
        # Skip blank input instead of posting an empty message and starting a run.
        return "", chat_history, _build_map_figure(map_metadata_state), thread, map_metadata_state
    if thread is None:
        thread = client.beta.threads.create()
    print(map_metadata_state)
    # Pass the module-level assistant id directly, so no assistant object
    # needs to be retrieved first.
    run = submit_message(assistant_id, thread.id, message)
    _wait_for_run(run, thread.id, map_metadata_state)
    # Iterate the page object itself (not `.data`) so the client
    # auto-paginates. `.data` holds only the first page (20 messages by
    # default), which would drop the newest messages in ascending order.
    messages = client.beta.threads.messages.list(thread_id=thread.id, order="asc")
    chat_history = _format_dialog(messages)
    print(chat_history)
    return "", chat_history, _build_map_figure(map_metadata_state), thread, map_metadata_state
with gr.Blocks(title="OpenAI assistant Wanderlust") as demo:
    gr.Markdown("# OpenAI assistant Wanderlust")
    # Per-session state: the Assistants API thread (created lazily by
    # `respond`) and the map viewport/markers edited by the tool calls.
    thread = gr.State()
    map_metadata_state = gr.State(map_metadata)
    with gr.Column():
        with gr.Row():
            chatbot = gr.Chatbot()
            # Named `map_plot` rather than `map` to avoid shadowing the builtin.
            map_plot = gr.Plot(fig)
        msg = gr.Textbox("Move the map to Brussels and add markers on shops with the best waffles")
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Submit")
                clear = gr.ClearButton([msg, chatbot])

    def _reset_session():
        """Drop the thread and restore the initial map when Clear is pressed.

        Without this, the next message reuses the old thread, so the cleared
        conversation and its markers reappear.
        """
        return None, {**map_metadata, "lat": [], "lon": [], "text": []}, fig

    clear.click(_reset_session, None, [thread, map_metadata_state, map_plot])

    respond_inputs = [msg, chatbot, thread, map_metadata_state]
    respond_outputs = [msg, chatbot, map_plot, thread, map_metadata_state]
    msg.submit(respond, respond_inputs, respond_outputs)
    submit.click(respond, respond_inputs, respond_outputs)
    gr.Markdown(
        """
        # Description
        Basic recreation of OpenAI's DevDay Wanderlust demo app. It relies on Gradio and the new Assistants API. [Github repository](https://github.com/Yannael/openai-assistant-wanderlust)
        This space is inspired by the implementation of [Fanilo Andrianasolo using Streamlit](https://www.youtube.com/watch?v=tLeqCDKgEDU)
        """
    )
demo.queue()
demo.launch(debug=True)