mislavb's picture
wip
d6bfc1a
raw
history blame
5.04 kB
import json
import openai
from gradio import ChatMessage
class SantaAgent:
    """Chat agent that lets an OpenAI model buy items and give presents via tools.

    The model is offered three function tools (``buy_item``, ``give_present``
    and ``stop``). ``run_santa_agent`` keeps querying the model and executing
    the tool calls it makes until the model calls ``stop`` or the message
    history grows past ``max_messages`` entries.
    """

    def __init__(self, system_prompt: str, model: str = "gpt-4o-mini", max_messages: int = 10):
        """Create the agent.

        Args:
            system_prompt: System message placed at the start of every run.
            model: Chat-completions model name passed to the OpenAI client.
            max_messages: The run loop ends once the OpenAI message history
                holds more than this many messages. This is a safety cap in
                case the model never calls ``stop``.
        """
        self.system_prompt = system_prompt
        self.model = model
        self.max_messages = max_messages
        self.client = openai.OpenAI()
        self.tools = [
            {
                "type": "function",
                "function": {
                    "name": "buy_item",
                    "description": "Buy an item from the store.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "item": {
                                "type": "string",
                                "description": "The item to buy from the store."
                            }
                        },
                        "required": ["item"]
                    }
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "give_present",
                    "description": "Give a present to a person.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "person": {
                                "type": "string",
                                "description": "The person to give the present to."
                            },
                            "item": {
                                "type": "string",
                                "description": "The item to give to the person."
                            }
                        },
                        "required": ["person", "item"]
                    }
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "stop",
                    "description": "Use this tool if you are finished and want to stop."
                }
            }
        ]

    def buy_item(self, item: str):
        """Buy an item from the store."""
        return f"Bought {item} from the store."

    def give_present(self, person: str, item: str):
        """Give a present to a person."""
        return f"Gave {item} to {person}."

    def stop(self):
        """Return the sentinel output for the ``stop`` tool."""
        return "STOP"

    def mock_run_santa_agent(self):
        """Return a fixed two-message conversation without calling the API."""
        messages = [
            {"role": "user", "content": "Hi there"},
            {"role": "assistant", "content": "Bye bye"},
        ]
        gradio_messages = [
            ChatMessage(role="user", content="Hi there"),
            ChatMessage(role="assistant", content="Bye bye"),
        ]
        return messages, gradio_messages

    def _execute_tool(self, name: str, raw_arguments: str):
        """Run the tool called *name* with JSON-encoded *raw_arguments*.

        Returns:
            A ``(call_label, output)`` pair. ``call_label`` is the call string
            shown in the chat UI; it is ``None`` for the ``stop`` tool, which
            is not displayed. ``output`` is the string sent back to the model
            as the tool result.

            Malformed arguments or an unknown tool name yield an ``"Error: ..."``
            output instead of raising, so every tool call still gets a reply.
        """
        # ``stop`` takes no arguments, so don't let bad/empty JSON block it.
        if name == "stop":
            return None, self.stop()
        fallback_label = f"{name}(...)"
        try:
            # An empty argument string is treated as "no arguments".
            arguments = json.loads(raw_arguments or "{}")
        except json.JSONDecodeError as err:
            return fallback_label, f"Error: could not parse arguments for {name}: {err}"
        if not isinstance(arguments, dict):
            return fallback_label, f"Error: arguments for {name} must be a JSON object."
        try:
            if name == "buy_item":
                item = arguments["item"]
                return f"buy_item({item})", self.buy_item(item)
            if name == "give_present":
                person, item = arguments["person"], arguments["item"]
                return f"give_present({person}, {item})", self.give_present(person, item)
        except KeyError as err:
            return fallback_label, f"Error: missing argument {err} for {name}."
        return fallback_label, f"Error: unknown tool {name!r}."

    def run_santa_agent(self, user_prompt: str):
        """Run the Santa agent on *user_prompt* until it stops.

        Each iteration sends the whole history to the model and records its
        reply. It then executes every tool call the model requested, answering
        each one with a ``tool`` message.

        The loop ends when the model calls ``stop`` or the history exceeds
        ``self.max_messages`` messages. A turn without tool calls does not end
        the loop; the model is simply queried again.

        Returns:
            A ``(messages, gradio_messages)`` pair. ``messages`` is the raw
            OpenAI message history (a list of dicts). ``gradio_messages`` is a
            list of ``ChatMessage`` objects for display, including tool calls
            and tool outputs.
        """
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        gradio_messages = [
            ChatMessage(role="system", content=self.system_prompt),
            ChatMessage(role="user", content=user_prompt),
        ]
        while True:
            response = self.client.chat.completions.create(
                messages=messages,
                model=self.model,
                tools=self.tools,
                tool_choice="auto",
            )
            reply = response.choices[0].message
            messages.append(reply.to_dict())
            if reply.content is not None:
                gradio_messages.append(ChatMessage(role="assistant", content=reply.content))
            should_stop = False
            for tool_call in reply.tool_calls or []:
                name = tool_call.function.name
                call_label, output = self._execute_tool(name, tool_call.function.arguments)
                if call_label is not None:
                    gradio_messages.append(ChatMessage(role="assistant", content=call_label, metadata={"title": f"🔧 Tool Call: {name}"}))
                # Every tool_call_id must be answered with a tool message,
                # otherwise the next chat-completions request is rejected.
                messages.append({"role": "tool", "content": output, "tool_call_id": tool_call.id})
                if name == "stop":
                    should_stop = True
                else:
                    gradio_messages.append(ChatMessage(role="assistant", content=output, metadata={"title": f"🔧 Tool Output: {name}"}))
            if should_stop or len(messages) > self.max_messages:
                break
        return messages, gradio_messages