kn404's picture
limit non-toolcall messages, better assertion messages
f448621
raw
history blame
16 kB
import json
import openai
from gradio import ChatMessage
class SantaAgent:
    """Tool-calling "Santa" agent driven by the OpenAI chat completions API.

    The agent exposes a fixed set of toy tools (buying items, wrapping
    presents, reading letters, ...) to the model.  Every tool is an ordinary
    method on this class that returns a string; ``run_santa_agent`` feeds
    those strings back to the model as tool results.

    Attributes:
        system_prompt: System message placed at the start of every run.
        client: ``openai.OpenAI`` client used for completions.
        tools: Tool schemas passed as ``tools=`` to the completions call.
    """

    # Model used for every completion request.
    MODEL = "gpt-4o-mini"
    # The run ends once the raw transcript holds more than this many messages.
    MAX_MESSAGES = 10
    # The run ends after this many consecutive replies without a tool call.
    MAX_NON_TOOL_REPLIES = 2
    # Wrench emoji shown in front of tool titles in the UI; written as an
    # escape so the source stays ASCII and cannot be mis-decoded.
    _TOOL_ICON = "\N{WRENCH}"

    def __init__(self, system_prompt: str):
        """Store the system prompt, create the API client and tool schemas.

        Args:
            system_prompt: Text sent as the system message of every run.
        """
        self.system_prompt = system_prompt
        self.client = openai.OpenAI()
        self.tools = self._build_tools()

    @staticmethod
    def _tool_spec(name: str, description: str, params: dict = None) -> dict:
        """Return one function-tool schema entry.

        Args:
            name: Tool name; must match a method name on this class.
            description: Description shown to the model.
            params: Optional mapping of parameter name to description.  Every
                parameter is declared as a required string.  When omitted or
                empty, the schema carries no ``"parameters"`` key at all.
        """
        function = {"name": name, "description": description}
        if params:
            function["parameters"] = {
                "type": "object",
                "properties": {
                    param: {"type": "string", "description": text}
                    for param, text in params.items()
                },
                "required": list(params),
            }
        return {"type": "function", "function": function}

    @staticmethod
    def _build_tools() -> list:
        """Return the schemas of all tools the model may call."""
        spec = SantaAgent._tool_spec
        return [
            spec(
                "buy_item",
                "Buy an item from the store.",
                {"item": "The item to buy from the store."},
            ),
            spec(
                "give_present",
                "Give a present to a person.",
                {
                    "person": "The person to give the present to.",
                    "item": "The item to give to the person.",
                },
            ),
            spec(
                "make_naughty_nice_list",
                "Make a list of children that have been naughty and nice. "
                "This function cannot make other lists.",
            ),
            spec(
                "check_naughty_nice_list",
                "Check which children have been naughty and nice. "
                "This is the only information in the list.",
            ),
            spec("cut_paper", "Cut wrapping paper to wrap a present."),
            spec(
                "find_end_of_tape",
                "Find the end of the tape to wrap a present.",
            ),
            spec("wrap_present", "Wrap a present."),
            spec(
                "label_present",
                "Label a present with the recipient's name.",
                {"recipient": "The name of the recipient."},
            ),
            spec("retrieve_letters", "Retrieve letters from children."),
            spec(
                "check_temperature",
                "Use this tool to check the temperature of an object.",
                {"object": "The object to check the temperature of."},
            ),
            spec("dunk_cookie", "Dunk a cookie in milk."),
            spec("drink", "Drink an item.", {"item": "The item to drink."}),
            spec(
                "put_route_into_maps",
                "Put a route into Google Maps.",
                {
                    "addr1": "First Address to Visit.",
                    "addr2": "Second Address to Visit.",
                    "addr3": "Third Address to Visit.",
                },
            ),
            spec(
                "stop",
                "Use this tool if you are finished and want to stop.",
            ),
        ]

    # ------------------------------------------------------------------
    # Tool implementations.  Each returns the string that is sent back to
    # the model as the tool result.
    # ------------------------------------------------------------------

    def buy_item(self, item: str):
        """Buy an item from the store."""
        return f"Bought {item} from the store."

    def give_present(self, person: str, item: str):
        """Give a present to a person."""
        return f"Gave {item} to {person}."

    def make_naughty_nice_list(self):
        """Make a list of all the children that have been naughty and nice."""
        return "Made a list."

    def check_naughty_nice_list(self):
        """Return the naughty/nice status of each child as a JSON string."""
        return json.dumps({
            "children": [
                {"name": "Alice", "status": "nice"},
                {"name": "Bob", "status": "naughty"},
                {"name": "John", "status": "nice"},
                {"name": "Jane", "status": "nice"},
            ]
        })

    def cut_paper(self):
        """Cut wrapping paper to wrap a present."""
        return "Cut the wrapping paper."

    def find_end_of_tape(self):
        """Find the end of the tape to wrap a present."""
        return "Found the end of the tape."

    def wrap_present(self):
        """Wrap a present."""
        return "Wrapped the present."

    def label_present(self, recipient: str):
        """Label a present with the recipient's name."""
        return f"Labeled the present for {recipient}."

    # The parameter is named ``object`` because it must match the key declared
    # in the tool schema (arguments are passed by keyword).
    def check_temperature(self, object: str):
        """Check the temperature of the object."""
        return f"The temperature of the {object} is just right."

    def dunk_cookie(self):
        """Dunk a cookie in milk."""
        return "Dunked a cookie in milk."

    def drink(self, item: str):
        """Drink an item."""
        return f"Drank {item}."

    def retrieve_letters(self):
        """Return the children's letters as a JSON string."""
        return json.dumps({
            "letters": [
                {"text": "Dear Santa, I would like a Bike for Christmas.", "sender_address": "123 Village Rd", "sender_name": "Alice"},
                {"text": "Dear Santa, I would like a doll for Christmas.", "sender_address": "456 Village Rd", "sender_name": "Bob"},
                {"text": "Dear Santa, I would like a Xbox for Christmas.", "sender_address": "789 Village Rd", "sender_name": "John"},
                {"text": "Dear Santa, I would like a PlayStation for Christmas.", "sender_address": "101112 Village Rd", "sender_name": "Jane"},
            ]
        })

    def put_route_into_maps(self, addr1: str, addr2: str, addr3: str):
        """Put a route into Google Maps; returns the route as a JSON string."""
        return json.dumps({
            'route': [addr1, addr2, addr3]
        })

    def stop(self):
        """Use this tool if you are finished and want to stop."""
        return "STOP"

    # ------------------------------------------------------------------
    # Agent loop
    # ------------------------------------------------------------------

    def mock_run_santa_agent(self):
        """Return a canned two-message conversation without calling the API.

        Returns:
            ``(messages, gradio_messages)`` in the same shape as
            ``run_santa_agent``.
        """
        messages = [
            {"role": "user", "content": "Hi there"},
            {"role": "assistant", "content": "Bye bye"},
        ]
        gradio_messages = [
            ChatMessage(role="user", content="Hi there"),
            ChatMessage(role="assistant", content="Bye bye"),
        ]
        return messages, gradio_messages

    def _required_parameters(self, name: str):
        """Return the declared parameter names of tool ``name`` in order.

        Returns ``None`` when no tool of that name is declared in
        ``self.tools``.
        """
        for tool in self.tools:
            function = tool["function"]
            if function["name"] == name:
                return list(function.get("parameters", {}).get("required", []))
        return None

    def _execute_tool(self, name: str, raw_arguments) -> tuple:
        """Run one model-requested tool call.

        Only tools declared in ``self.tools`` can be invoked, so the model
        cannot reach arbitrary methods of this object.  Problems with the
        request (unknown tool, malformed JSON, missing arguments) are returned
        as an error string rather than raised, so the caller can still answer
        the tool call and let the model recover.

        Args:
            name: Tool name requested by the model.
            raw_arguments: JSON-encoded arguments as received from the API;
                an empty value is treated as ``{}``.

        Returns:
            ``(output, display)``: the tool result string, and a short
            human-readable rendering of the call (``name(arg1, arg2)``, or
            just ``name`` for tools without arguments or on error).
        """
        param_names = self._required_parameters(name)
        handler = getattr(self, name, None) if param_names is not None else None
        if not callable(handler):
            return f"Error: unknown tool '{name}'.", name

        try:
            arguments = json.loads(raw_arguments or "{}")
        except (TypeError, ValueError):
            arguments = None
        if not isinstance(arguments, dict):
            return (
                f"Error: arguments for '{name}' are not a valid JSON object.",
                name,
            )

        missing = [param for param in param_names if param not in arguments]
        if missing:
            return (
                f"Error: missing required argument(s) for '{name}': "
                f"{', '.join(missing)}.",
                name,
            )

        # Only declared parameters are forwarded; extras are ignored.
        call_args = {param: arguments[param] for param in param_names}
        output = handler(**call_args)
        if call_args:
            shown = ", ".join(f"{value}" for value in call_args.values())
            display = f"{name}({shown})"
        else:
            display = name
        return output, display

    def run_santa_agent(self, user_prompt: str):
        """Run the Santa agent on ``user_prompt`` until it stops.

        The model is queried in a loop.  Each requested tool call is executed
        and its result appended to the transcript.  The loop ends when the
        model calls ``stop``, when the transcript exceeds ``MAX_MESSAGES``
        messages, or after ``MAX_NON_TOOL_REPLIES`` consecutive replies that
        contain no tool call.

        Args:
            user_prompt: The user's request.

        Returns:
            ``(messages, gradio_messages)``: the raw chat transcript as a list
            of dicts, and a parallel list of ``ChatMessage`` objects for
            display.  A ``stop`` call appears only in the raw transcript.
        """
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        gradio_messages = [
            ChatMessage(role="system", content=self.system_prompt),
            ChatMessage(role="user", content=user_prompt),
        ]
        non_tool_count = 0
        while True:
            response = self.client.chat.completions.create(
                messages=messages,
                model=self.MODEL,
                tools=self.tools,
                tool_choice="auto",
            )
            reply = response.choices[0].message
            messages.append(reply.to_dict())
            if reply.content is not None:
                gradio_messages.append(
                    ChatMessage(role="assistant", content=reply.content)
                )

            should_stop = False
            if reply.tool_calls:
                non_tool_count = 0
                for tool_call in reply.tool_calls:
                    name = tool_call.function.name
                    output, display = self._execute_tool(
                        name, tool_call.function.arguments
                    )
                    if name == "stop":
                        should_stop = True
                    else:
                        gradio_messages.append(ChatMessage(
                            role="assistant",
                            content=display,
                            metadata={
                                "title": f"{self._TOOL_ICON} Tool Call: {name}"
                            },
                        ))
                    # Every tool call must be answered in the raw transcript,
                    # including ``stop`` and failed calls.
                    messages.append({
                        "role": "tool",
                        "content": output,
                        "tool_call_id": tool_call.id,
                    })
                    # Once ``stop`` has been seen in this reply, no further
                    # tool outputs are shown in the UI.
                    if not should_stop:
                        gradio_messages.append(ChatMessage(
                            role="assistant",
                            content=output,
                            metadata={
                                "title": f"{self._TOOL_ICON} Tool Output: {name}"
                            },
                        ))
            else:
                non_tool_count += 1
                if non_tool_count >= self.MAX_NON_TOOL_REPLIES:
                    break

            if should_stop or len(messages) > self.MAX_MESSAGES:
                break
        return messages, gradio_messages