|
import json |
|
import openai |
|
from gradio import ChatMessage |
|
|
|
class SantaAgent:
    """Toy tool-calling agent that plays Santa via the OpenAI chat completions API.

    Every tool advertised in ``self.tools`` is backed by a method of the same
    name that returns a canned string result. ``run_santa_agent`` drives the
    model <-> tool loop and records a parallel, display-oriented transcript of
    Gradio ``ChatMessage`` objects.
    """

    def __init__(self, system_prompt: str):
        # System prompt sent as the first message of every run.
        self.system_prompt = system_prompt
        # OpenAI client; configuration (API key etc.) comes from openai's defaults.
        self.client = openai.OpenAI()
        # Function-calling schemas sent to the model. For tools that take
        # arguments, the order of each "required" list matches the positional
        # parameter order of the backing method (used by _execute_tool_call).
        self.tools = [
            {
                "type": "function",
                "function": {
                    "name": "buy_item",
                    "description": "Buy an item from the store.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "item": {
                                "type": "string",
                                "description": "The item to buy from the store.",
                            }
                        },
                        "required": ["item"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "give_present",
                    "description": "Give a present to a person.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "person": {
                                "type": "string",
                                "description": "The person to give the present to.",
                            },
                            "item": {
                                "type": "string",
                                "description": "The item to give to the person.",
                            },
                        },
                        "required": ["person", "item"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "make_naughty_nice_list",
                    "description": "Make a list of children that have been naughty and nice. This function cannot make other lists.",
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "check_naughty_nice_list",
                    "description": "Check which children have been naughty and nice. This is the only information in the list.",
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "cut_paper",
                    "description": "Cut wrapping paper to wrap a present.",
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "find_end_of_tape",
                    "description": "Find the end of the tape to wrap a present.",
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "wrap_present",
                    "description": "Wrap a present.",
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "label_present",
                    "description": "Label a present with the recipient's name.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "recipient": {
                                "type": "string",
                                "description": "The name of the recipient.",
                            }
                        },
                        "required": ["recipient"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "retrieve_letters",
                    "description": "Retrieve letters from children.",
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "check_temperature",
                    "description": "Use this tool to check the temperature of an object.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "object": {
                                "type": "string",
                                "description": "The object to check the temperature of.",
                            }
                        },
                        "required": ["object"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "dunk_cookie",
                    "description": "Dunk a cookie in milk.",
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "drink",
                    "description": "Drink an item.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "item": {
                                "type": "string",
                                "description": "The item to drink.",
                            }
                        },
                        "required": ["item"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "put_route_into_maps",
                    "description": "Put a route into Google Maps.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "addr1": {
                                "type": "string",
                                "description": "First Address to Visit.",
                            },
                            "addr2": {
                                "type": "string",
                                "description": "Second Address to Visit.",
                            },
                            "addr3": {
                                "type": "string",
                                "description": "Third Address to Visit.",
                            },
                        },
                        "required": ["addr1", "addr2", "addr3"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "stop",
                    "description": "Use this tool if you are finished and want to stop.",
                },
            },
        ]

    # ------------------------------------------------------------------
    # Simulated tools: each returns a canned string, nothing is persisted.
    # ------------------------------------------------------------------

    def buy_item(self, item: str) -> str:
        """Simulate buying *item*; return a confirmation string."""
        return f"Bought {item} from the store."

    def give_present(self, person: str, item: str) -> str:
        """Simulate giving *item* to *person*; return a confirmation string."""
        return f"Gave {item} to {person}."

    def make_naughty_nice_list(self) -> str:
        """Simulate making the naughty/nice list; return a confirmation string."""
        return "Made a list."

    def check_naughty_nice_list(self) -> str:
        """Return a fixed naughty/nice list of children, encoded as JSON.

        The JSON object has a ``"children"`` key holding ``{"name", "status"}``
        records, where status is ``"naughty"`` or ``"nice"``.
        """
        return json.dumps({
            "children": [
                {"name": "Alice", "status": "nice"},
                {"name": "Bob", "status": "naughty"},
                {"name": "John", "status": "nice"},
                {"name": "Jane", "status": "nice"},
            ]
        })

    def cut_paper(self) -> str:
        """Simulate cutting wrapping paper; return a confirmation string."""
        return "Cut the wrapping paper."

    def find_end_of_tape(self) -> str:
        """Simulate finding the end of the tape; return a confirmation string."""
        return "Found the end of the tape."

    def wrap_present(self) -> str:
        """Simulate wrapping a present; return a confirmation string."""
        return "Wrapped the present."

    def label_present(self, recipient: str) -> str:
        """Simulate labelling a present for *recipient*; return a confirmation."""
        return f"Labeled the present for {recipient}."

    def check_temperature(self, object: str) -> str:
        """Return a fixed "just right" temperature report for *object*."""
        # NOTE: the parameter name shadows the builtin ``object``; it is kept
        # because it matches the tool schema and callers may pass it by keyword.
        return f"The temperature of the {object} is just right."

    def dunk_cookie(self) -> str:
        """Simulate dunking a cookie in milk; return a confirmation string."""
        return "Dunked a cookie in milk."

    def drink(self, item: str) -> str:
        """Simulate drinking *item*; return a confirmation string."""
        return f"Drank {item}."

    def retrieve_letters(self) -> str:
        """Return a fixed batch of children's letters, encoded as JSON.

        The JSON object has a ``"letters"`` key holding records with ``text``,
        ``sender_address`` and ``sender_name`` fields.
        """
        return json.dumps({
            "letters": [
                {"text": "Dear Santa, I would like a Bike for Christmas.", "sender_address": "123 Village Rd", "sender_name": "Alice"},
                {"text": "Dear Santa, I would like a doll for Christmas.", "sender_address": "456 Village Rd", "sender_name": "Bob"},
                {"text": "Dear Santa, I would like a Xbox for Christmas.", "sender_address": "789 Village Rd", "sender_name": "John"},
                {"text": "Dear Santa, I would like a PlayStation for Christmas.", "sender_address": "101112 Village Rd", "sender_name": "Jane"},
            ]
        })

    def put_route_into_maps(self, addr1: str, addr2: str, addr3: str) -> str:
        """Return the three stops as a JSON object ``{"route": [addr1, addr2, addr3]}``."""
        return json.dumps({
            'route': [addr1, addr2, addr3]
        })

    def stop(self) -> str:
        """Return the sentinel ``"STOP"``; run_santa_agent ends the loop after this tool."""
        return "STOP"

    # ------------------------------------------------------------------
    # Agent loop
    # ------------------------------------------------------------------

    def mock_run_santa_agent(self):
        """Return a fixed two-message exchange without calling the API.

        Returns:
            ``(messages, gradio_messages)`` in the same shapes as
            ``run_santa_agent``: OpenAI-style role/content dicts and the
            matching Gradio ``ChatMessage`` objects.
        """
        messages = [
            {"role": "user", "content": "Hi there"},
            {"role": "assistant", "content": "Bye bye"},
        ]
        gradio_messages = [
            ChatMessage(role="user", content="Hi there"),
            ChatMessage(role="assistant", content="Bye bye"),
        ]
        return messages, gradio_messages

    def _execute_tool_call(self, name: str, raw_arguments: str) -> tuple:
        """Run the declared tool *name* with JSON-encoded *raw_arguments*.

        Only names present in ``self.tools`` are dispatched, so the model can
        never reach other attributes of this object. Required arguments are
        passed positionally in the order of the schema's ``"required"`` list;
        extra arguments supplied by the model are ignored.

        Returns:
            ``(call_display, output)``. ``call_display`` renders the call for
            the transcript: ``"name(v1, v2)"``, or the bare ``"name"`` when the
            tool takes no arguments. ``output`` is the tool's string result.
            Recoverable problems (unknown tool, malformed JSON, missing
            arguments) come back as an ``"Error: ..."`` output instead of
            raising, so the model can read the error and retry.
        """
        spec = next(
            (
                tool.get("function", {})
                for tool in self.tools
                if tool.get("function", {}).get("name") == name
            ),
            None,
        )
        handler = getattr(self, name, None) if spec is not None else None
        if not callable(handler):
            return name, f"Error: unknown tool '{name}'."

        # Parameterless tools may arrive with an empty arguments string.
        try:
            arguments = json.loads(raw_arguments or "{}")
        except json.JSONDecodeError as err:
            return name, f"Error: could not parse arguments for '{name}': {err}."
        if not isinstance(arguments, dict):
            return name, f"Error: arguments for '{name}' must be a JSON object."

        arg_names = spec.get("parameters", {}).get("required", [])
        missing = [arg for arg in arg_names if arg not in arguments]
        if missing:
            return name, f"Error: missing argument(s) for '{name}': {', '.join(missing)}."

        values = [arguments[arg] for arg in arg_names]
        call_display = f"{name}({', '.join(str(v) for v in values)})" if values else name
        return call_display, handler(*values)

    def run_santa_agent(
        self,
        user_prompt: str,
        *,
        model: str = "gpt-4o-mini",
        max_messages: int = 10,
    ):
        """Run the Santa agent until it stops, goes quiet, or hits the message cap.

        Each iteration sends the full history to the model, appends its reply,
        and executes every requested tool call. Each result is fed back as a
        ``"tool"`` message. The loop ends when:

        * the model calls the ``stop`` tool;
        * the model answers twice in a row without calling any tool; or
        * the history grows beyond *max_messages* entries. This is checked
          once per model turn, so a turn with many tool calls can overshoot.

        Args:
            user_prompt: The user's request.
            model: Chat-completions model name.
            max_messages: Soft cap on the length of the message history,
                counting the system and user prompts.

        Returns:
            ``(messages, gradio_messages)``: the raw OpenAI message history,
            and a display transcript in which each tool call and its output
            appear as titled assistant messages. ``stop`` is not shown.
        """
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        gradio_messages = [
            ChatMessage(role="system", content=self.system_prompt),
            ChatMessage(role="user", content=user_prompt),
        ]

        non_tool_count = 0
        while True:
            response = self.client.chat.completions.create(
                messages=messages,
                model=model,
                tools=self.tools,
                tool_choice="auto",
            )
            message = response.choices[0].message
            messages.append(message.to_dict())
            if message.content is not None:
                gradio_messages.append(ChatMessage(role="assistant", content=message.content))

            should_stop = False
            if message.tool_calls:
                non_tool_count = 0
                # Execute every call in the batch, even after ``stop``: each
                # tool_call_id the model emitted must receive a tool response.
                for tool_call in message.tool_calls:
                    name = tool_call.function.name
                    call_display, output = self._execute_tool_call(name, tool_call.function.arguments)
                    messages.append({"role": "tool", "content": output, "tool_call_id": tool_call.id})
                    if name == "stop":
                        should_stop = True
                        continue  # the stop tool is not shown in the transcript
                    gradio_messages.append(ChatMessage(
                        role="assistant",
                        content=call_display,
                        metadata={"title": f"🔧 Tool Call: {name}"},
                    ))
                    gradio_messages.append(ChatMessage(
                        role="assistant",
                        content=output,
                        metadata={"title": f"🔧 Tool Output: {name}"},
                    ))
            else:
                non_tool_count += 1
                if non_tool_count >= 2:
                    break

            if should_stop or len(messages) > max_messages:
                break

        return messages, gradio_messages
|
|