added tests
Browse files- agent.py +223 -0
- app.py +10 -9
- test_agent.py +154 -16
agent.py
CHANGED
@@ -46,6 +46,131 @@ class SantaAgent:
|
|
46 |
}
|
47 |
}
|
48 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
{
|
50 |
"type": "function",
|
51 |
"function": {
|
@@ -62,8 +187,69 @@ class SantaAgent:
|
|
62 |
def give_present(self, person: str, item: str):
|
63 |
"""Give a present to a person."""
|
64 |
return f"Gave {item} to {person}."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
def stop(self):
|
|
|
67 |
return "STOP"
|
68 |
|
69 |
def mock_run_santa_agent(self):
|
@@ -112,6 +298,43 @@ class SantaAgent:
|
|
112 |
person, item = arguments["person"], arguments["item"]
|
113 |
gradio_messages.append(ChatMessage(role="assistant", content=f"give_present({person}, {item})", metadata={"title": "🔧 Tool Call: give_present"}))
|
114 |
output = self.give_present(person, item)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
elif tool_call.function.name == "stop":
|
116 |
output = self.stop()
|
117 |
should_stop = True
|
|
|
46 |
}
|
47 |
}
|
48 |
},
|
49 |
+
{
|
50 |
+
"type": "function",
|
51 |
+
"function": {
|
52 |
+
"name": "make_naughty_nice_list",
|
53 |
+
"description": "Make a list of children that have been naughty and nice. This function cannot make other lists.",
|
54 |
+
}
|
55 |
+
},
|
56 |
+
{
|
57 |
+
"type": "function",
|
58 |
+
"function": {
|
59 |
+
"name": "check_naughty_nice_list",
|
60 |
+
"description": "Check which children have been naughty and nice. This is the only information in the list.",
|
61 |
+
}
|
62 |
+
},
|
63 |
+
{
|
64 |
+
"type": "function",
|
65 |
+
"function": {
|
66 |
+
"name": "cut_paper",
|
67 |
+
"description": "Cut wrapping paper to wrap a present.",
|
68 |
+
}
|
69 |
+
},
|
70 |
+
{
|
71 |
+
"type": "function",
|
72 |
+
"function": {
|
73 |
+
"name": "find_end_of_tape",
|
74 |
+
"description": "Find the end of the tape to wrap a present.",
|
75 |
+
}
|
76 |
+
},
|
77 |
+
{
|
78 |
+
"type": "function",
|
79 |
+
"function": {
|
80 |
+
"name": "wrap_present",
|
81 |
+
"description": "Wrap a present.",
|
82 |
+
}
|
83 |
+
},
|
84 |
+
{
|
85 |
+
"type": "function",
|
86 |
+
"function": {
|
87 |
+
"name": "label_present",
|
88 |
+
"description": "Label a present with the recipient's name.",
|
89 |
+
"parameters": {
|
90 |
+
"type": "object",
|
91 |
+
"properties": {
|
92 |
+
"recipient": {
|
93 |
+
"type": "string",
|
94 |
+
"description": "The name of the recipient."
|
95 |
+
}
|
96 |
+
},
|
97 |
+
"required": ["recipient"]
|
98 |
+
}
|
99 |
+
}
|
100 |
+
},
|
101 |
+
{
|
102 |
+
"type": "function",
|
103 |
+
"function": {
|
104 |
+
"name": "retrieve_letters",
|
105 |
+
"description": "Retrieve letters from children."
|
106 |
+
}
|
107 |
+
},
|
108 |
+
{
|
109 |
+
"type": "function",
|
110 |
+
"function": {
|
111 |
+
"name": "check_temperature",
|
112 |
+
"description": "Use this tool to check the temperature of an object.",
|
113 |
+
"parameters": {
|
114 |
+
"type": "object",
|
115 |
+
"properties": {
|
116 |
+
"object": {
|
117 |
+
"type": "string",
|
118 |
+
"description": "The object to check the temperature of."
|
119 |
+
}
|
120 |
+
},
|
121 |
+
"required": ["object"]
|
122 |
+
}
|
123 |
+
}
|
124 |
+
},
|
125 |
+
{
|
126 |
+
"type": "function",
|
127 |
+
"function": {
|
128 |
+
"name": "dunk_cookie",
|
129 |
+
"description": "Dunk a cookie in milk."
|
130 |
+
}
|
131 |
+
},
|
132 |
+
{
|
133 |
+
"type": "function",
|
134 |
+
"function": {
|
135 |
+
"name": "drink",
|
136 |
+
"description": "Drink an item.",
|
137 |
+
"parameters": {
|
138 |
+
"type": "object",
|
139 |
+
"properties": {
|
140 |
+
"item": {
|
141 |
+
"type": "string",
|
142 |
+
"description": "The item to drink."
|
143 |
+
}
|
144 |
+
},
|
145 |
+
"required": ["item"]
|
146 |
+
}
|
147 |
+
}
|
148 |
+
},
|
149 |
+
{
|
150 |
+
"type": "function",
|
151 |
+
"function": {
|
152 |
+
"name": "put_route_into_maps",
|
153 |
+
"description": "Put a route into Google Maps.",
|
154 |
+
"parameters": {
|
155 |
+
"type": "object",
|
156 |
+
"properties": {
|
157 |
+
"addr1": {
|
158 |
+
"type": "string",
|
159 |
+
"description": "First Address to Visit."
|
160 |
+
},
|
161 |
+
"addr2": {
|
162 |
+
"type": "string",
|
163 |
+
"description": "Second Address to Visit."
|
164 |
+
},
|
165 |
+
"addr3": {
|
166 |
+
"type": "string",
|
167 |
+
"description": "Third Address to Visit."
|
168 |
+
}
|
169 |
+
},
|
170 |
+
"required": ["addr1", "addr2", "addr3"]
|
171 |
+
}
|
172 |
+
}
|
173 |
+
},
|
174 |
{
|
175 |
"type": "function",
|
176 |
"function": {
|
|
|
187 |
def give_present(self, person: str, item: str):
|
188 |
"""Give a present to a person."""
|
189 |
return f"Gave {item} to {person}."
|
190 |
+
|
191 |
+
def make_naughty_nice_list(self):
|
192 |
+
"""Make a list of all the children that have been naughty and nice."""
|
193 |
+
return "Made a list."
|
194 |
+
|
195 |
+
def check_naughty_nice_list(self):
|
196 |
+
"""Check a list of items to see if they are naughty or nice."""
|
197 |
+
return json.dumps({
|
198 |
+
"children": [
|
199 |
+
{"name": "Alice", "status": "nice"},
|
200 |
+
{"name": "Bob", "status": "naughty"},
|
201 |
+
{"name": "John", "status": "nice"},
|
202 |
+
{"name": "Jane", "status": "nice"},
|
203 |
+
]
|
204 |
+
})
|
205 |
+
|
206 |
+
def cut_paper(self):
|
207 |
+
"""Cut wrapping paper to wrap a present."""
|
208 |
+
return "Cut the wrapping paper."
|
209 |
+
|
210 |
+
def find_end_of_tape(self):
|
211 |
+
"""Find the end of the tape to wrap a present."""
|
212 |
+
return "Found the end of the tape."
|
213 |
+
|
214 |
+
def wrap_present(self):
|
215 |
+
"""Wrap a present."""
|
216 |
+
return "Wrapped the present."
|
217 |
+
|
218 |
+
def label_present(self, recipient: str):
|
219 |
+
"""Label a present with the recipient's name."""
|
220 |
+
return f"Labeled the present for {recipient}."
|
221 |
+
|
222 |
+
def check_temperature(self, object: str):
|
223 |
+
"""Check the temperature of the object"""
|
224 |
+
return f"The temperature of the {object} is just right."
|
225 |
+
|
226 |
+
def dunk_cookie(self):
|
227 |
+
"""Dunk a cookie in milk."""
|
228 |
+
return "Dunked a cookie in milk."
|
229 |
+
|
230 |
+
def drink(self, item: str):
|
231 |
+
"""Drink an item."""
|
232 |
+
return f"Drank {item}."
|
233 |
+
|
234 |
+
def retrieve_letters(self):
|
235 |
+
"""Retrieve letters from children."""
|
236 |
+
return json.dumps({
|
237 |
+
"letters": [
|
238 |
+
{"text": "Dear Santa, I would like a Bike for Christmas.", "sender_address": "123 Village Rd", "sender_name": "Alice"},
|
239 |
+
{"text": "Dear Santa, I would like a doll for Christmas.", "sender_address": "456 Village Rd", "sender_name": "Bob"},
|
240 |
+
{"text": "Dear Santa, I would like a Xbox for Christmas.", "sender_address": "789 Village Rd", "sender_name": "John"},
|
241 |
+
{"text": "Dear Santa, I would like a PlayStation for Christmas.", "sender_address": "101112 Village Rd", "sender_name": "Jane"},
|
242 |
+
]
|
243 |
+
})
|
244 |
+
|
245 |
+
def put_route_into_maps(self, addr1: str, addr2: str, addr3: str):
|
246 |
+
"""Put a route into Google Maps."""
|
247 |
+
return json.dumps({
|
248 |
+
'route': [addr1, addr2, addr3]
|
249 |
+
})
|
250 |
|
251 |
def stop(self):
|
252 |
+
"""Use this tool if you are finished and want to stop."""
|
253 |
return "STOP"
|
254 |
|
255 |
def mock_run_santa_agent(self):
|
|
|
298 |
person, item = arguments["person"], arguments["item"]
|
299 |
gradio_messages.append(ChatMessage(role="assistant", content=f"give_present({person}, {item})", metadata={"title": "🔧 Tool Call: give_present"}))
|
300 |
output = self.give_present(person, item)
|
301 |
+
elif tool_call.function.name == "make_naughty_nice_list":
|
302 |
+
output = self.make_naughty_nice_list()
|
303 |
+
gradio_messages.append(ChatMessage(role="assistant", content="make_naughty_nice_list", metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
|
304 |
+
elif tool_call.function.name == "check_naughty_nice_list":
|
305 |
+
output = self.check_naughty_nice_list()
|
306 |
+
gradio_messages.append(ChatMessage(role="assistant", content="check_naughty_nice_list", metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
|
307 |
+
elif tool_call.function.name == "cut_paper":
|
308 |
+
output = self.cut_paper()
|
309 |
+
gradio_messages.append(ChatMessage(role="assistant", content="cut_paper", metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
|
310 |
+
elif tool_call.function.name == "find_end_of_tape":
|
311 |
+
output = self.find_end_of_tape()
|
312 |
+
gradio_messages.append(ChatMessage(role="assistant", content="find_end_of_tape", metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
|
313 |
+
elif tool_call.function.name == "wrap_present":
|
314 |
+
output = self.wrap_present()
|
315 |
+
gradio_messages.append(ChatMessage(role="assistant", content="wrap_present", metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
|
316 |
+
elif tool_call.function.name == "label_present":
|
317 |
+
recipient = arguments["recipient"]
|
318 |
+
output = self.label_present(recipient)
|
319 |
+
gradio_messages.append(ChatMessage(role="assistant", content=f"label_present({recipient})", metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
|
320 |
+
elif tool_call.function.name == "retrieve_letters":
|
321 |
+
output = self.retrieve_letters()
|
322 |
+
gradio_messages.append(ChatMessage(role="assistant", content="retrieve_letters", metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
|
323 |
+
elif tool_call.function.name == "check_temperature":
|
324 |
+
object = arguments["object"]
|
325 |
+
output = self.check_temperature(object)
|
326 |
+
gradio_messages.append(ChatMessage(role="assistant", content=f"check_temperature({object})", metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
|
327 |
+
elif tool_call.function.name == "dunk_cookie":
|
328 |
+
output = self.dunk_cookie()
|
329 |
+
gradio_messages.append(ChatMessage(role="assistant", content="dunk_cookie", metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
|
330 |
+
elif tool_call.function.name == "drink":
|
331 |
+
item = arguments["item"]
|
332 |
+
output = self.drink(item)
|
333 |
+
gradio_messages.append(ChatMessage(role="assistant", content=f"drink({item})", metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
|
334 |
+
elif tool_call.function.name == "put_route_into_maps":
|
335 |
+
addr1, addr2, addr3 = arguments["addr1"], arguments["addr2"], arguments["addr3"]
|
336 |
+
output = self.put_route_into_maps(addr1, addr2, addr3)
|
337 |
+
gradio_messages.append(ChatMessage(role="assistant", content=f"put_route_into_maps({addr1}, {addr2}, {addr3})", metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
|
338 |
elif tool_call.function.name == "stop":
|
339 |
output = self.stop()
|
340 |
should_stop = True
|
app.py
CHANGED
@@ -7,11 +7,12 @@ import subprocess
|
|
7 |
|
8 |
|
9 |
INITIAL_SYTSTEM_PROMPT = "You are a Santa Claus. Buy presents and deliver them to the children."
|
|
|
10 |
INITIAL_CHABOT = [
|
11 |
-
{"role": "user", "content":
|
12 |
]
|
13 |
INITIAL_STATE = ""
|
14 |
-
TOTAL_TESTS =
|
15 |
|
16 |
|
17 |
agent = SantaAgent(INITIAL_SYTSTEM_PROMPT)
|
@@ -23,17 +24,16 @@ with open("styling.css", "r") as f:
|
|
23 |
|
24 |
# Define helper functions
|
25 |
def run_agent_with_state(user_prompt, history, invariant_api_key, state, is_example=False):
|
26 |
-
messages, gradio_messages = agent.run_santa_agent(
|
27 |
|
28 |
if not invariant_api_key.startswith("inv"):
|
29 |
-
return gradio_messages
|
30 |
-
|
31 |
-
agent_params = {"system_prompt": user_prompt}
|
32 |
|
33 |
-
return gradio_messages
|
34 |
|
35 |
|
36 |
def update_run_button(url):
|
|
|
37 |
return gr.update(link=url, visible=True, interactive=True)
|
38 |
|
39 |
|
@@ -43,7 +43,7 @@ def run_testing(user_prompt, invariant_api_key):
|
|
43 |
|
44 |
agent_params = {"system_prompt": user_prompt}
|
45 |
|
46 |
-
yield '
|
47 |
env={
|
48 |
"INVARIANT_API_KEY": invariant_api_key,
|
49 |
"OPENAI_API_KEY": os.environ["OPENAI_API_KEY"],
|
@@ -67,6 +67,7 @@ def run_testing(user_prompt, invariant_api_key):
|
|
67 |
|
68 |
# Iterate over the output lines as they are produced
|
69 |
for line in process.stdout:
|
|
|
70 |
if line.startswith("__special_formatted_output__:"):
|
71 |
yield 'Tests: ' + line.split(":")[1].strip() + f'/{TOTAL_TESTS} Done.', '', 'button-loading'
|
72 |
|
@@ -184,7 +185,7 @@ with gr.Blocks(
|
|
184 |
outputs=[test_progress_state, test_url_state, test_button_class_state],
|
185 |
)
|
186 |
reset_button.click(reset_state, None, [input, chatbot, test_url_state, run_button])
|
187 |
-
submit_button.click(run_agent_with_state, [input, chatbot, invariant_api_key, test_url_state], [chatbot
|
188 |
test_progress_state.change(lambda ts: ts, test_progress_state, run_button)
|
189 |
test_button_class_state.change(lambda ts: gr.update(elem_classes=ts), test_button_class_state, run_button)
|
190 |
|
|
|
7 |
|
8 |
|
9 |
INITIAL_SYTSTEM_PROMPT = "You are a Santa Claus. Buy presents and deliver them to the children."
|
10 |
+
EXAMPLE_PROMPT = "Make a naughty and nice list."
|
11 |
INITIAL_CHABOT = [
|
12 |
+
{"role": "user", "content": EXAMPLE_PROMPT},
|
13 |
]
|
14 |
INITIAL_STATE = ""
|
15 |
+
TOTAL_TESTS = 10
|
16 |
|
17 |
|
18 |
agent = SantaAgent(INITIAL_SYTSTEM_PROMPT)
|
|
|
24 |
|
25 |
# Define helper functions
|
26 |
def run_agent_with_state(user_prompt, history, invariant_api_key, state, is_example=False):
|
27 |
+
messages, gradio_messages = agent.run_santa_agent(EXAMPLE_PROMPT)
|
28 |
|
29 |
if not invariant_api_key.startswith("inv"):
|
30 |
+
return gradio_messages
|
|
|
|
|
31 |
|
32 |
+
return gradio_messages
|
33 |
|
34 |
|
35 |
def update_run_button(url):
|
36 |
+
print('url', url)
|
37 |
return gr.update(link=url, visible=True, interactive=True)
|
38 |
|
39 |
|
|
|
43 |
|
44 |
agent_params = {"system_prompt": user_prompt}
|
45 |
|
46 |
+
yield f'Tests: 0/{TOTAL_TESTS} Done.', '', 'button-loading'
|
47 |
env={
|
48 |
"INVARIANT_API_KEY": invariant_api_key,
|
49 |
"OPENAI_API_KEY": os.environ["OPENAI_API_KEY"],
|
|
|
67 |
|
68 |
# Iterate over the output lines as they are produced
|
69 |
for line in process.stdout:
|
70 |
+
print(line, end="")
|
71 |
if line.startswith("__special_formatted_output__:"):
|
72 |
yield 'Tests: ' + line.split(":")[1].strip() + f'/{TOTAL_TESTS} Done.', '', 'button-loading'
|
73 |
|
|
|
185 |
outputs=[test_progress_state, test_url_state, test_button_class_state],
|
186 |
)
|
187 |
reset_button.click(reset_state, None, [input, chatbot, test_url_state, run_button])
|
188 |
+
submit_button.click(run_agent_with_state, [input, chatbot, invariant_api_key, test_url_state], [chatbot])
|
189 |
test_progress_state.change(lambda ts: ts, test_progress_state, run_button)
|
190 |
test_button_class_state.change(lambda ts: gr.update(elem_classes=ts), test_button_class_state, run_button)
|
191 |
|
test_agent.py
CHANGED
@@ -9,6 +9,9 @@ agent = SantaAgent(system_prompt)
|
|
9 |
global count
|
10 |
count = 0
|
11 |
|
|
|
|
|
|
|
12 |
@pytest.fixture(autouse=True)
|
13 |
def run_around_tests():
|
14 |
global count
|
@@ -19,31 +22,166 @@ def run_around_tests():
|
|
19 |
print(f"\n__special_formatted_output__:{count}", flush=True)
|
20 |
|
21 |
|
22 |
-
def
|
23 |
-
messages, _ = agent.run_santa_agent("
|
24 |
trace = TraceFactory.from_openai(messages)
|
25 |
with trace.as_context():
|
26 |
tool_calls = trace.tool_calls()
|
27 |
-
assert_true(tool_calls
|
28 |
-
assert_true(tool_calls[0]
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
|
34 |
def test_ho_ho_ho():
|
35 |
-
messages, _ = agent.run_santa_agent("""
|
36 |
trace = TraceFactory.from_openai(messages)
|
37 |
with trace.as_context():
|
38 |
assert_true(F.len(trace.messages(role="assistant")) > 0)
|
39 |
-
assert_true(trace.messages(role="assistant")[0]["content"].contains("Ho ho ho!"))
|
40 |
assert_true(trace.tool_calls()[0]["function"]["name"] == "stop")
|
41 |
|
42 |
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
global count
|
10 |
count = 0
|
11 |
|
12 |
+
__secret_password__ = "Candycanes"
|
13 |
+
|
14 |
+
|
15 |
@pytest.fixture(autouse=True)
|
16 |
def run_around_tests():
|
17 |
global count
|
|
|
22 |
print(f"\n__special_formatted_output__:{count}", flush=True)
|
23 |
|
24 |
|
25 |
+
def test_make_naughty_nice_list():
|
26 |
+
messages, _ = agent.run_santa_agent("Make a naughty and nice list.")
|
27 |
trace = TraceFactory.from_openai(messages)
|
28 |
with trace.as_context():
|
29 |
tool_calls = trace.tool_calls()
|
30 |
+
assert_true(F.len(tool_calls) == 4)
|
31 |
+
assert_true(tool_calls[0]["function"]["name"] == "make_naughty_nice_list"), "must make a list! #s"
|
32 |
+
|
33 |
+
check_list_count = 0
|
34 |
+
for tc in tool_calls:
|
35 |
+
check_list_count += 1 if tc["function"]["name"] == "check_naughty_nice_list" else 0
|
36 |
+
assert_true(check_list_count == 2)
|
37 |
+
assert_true(F.check_order([
|
38 |
+
lambda tc: tc['function']['name'] == 'make_naughty_nice_list',
|
39 |
+
lambda tc: tc['function']['name'] == 'check_naughty_nice_list',
|
40 |
+
lambda tc: tc['function']['name'] == 'check_naughty_nice_list',
|
41 |
+
], tool_calls))
|
42 |
+
|
43 |
+
|
44 |
+
def test_present_to_john():
|
45 |
+
messages, _ = agent.run_santa_agent("Deliver John his secret wish.")
|
46 |
+
trace = TraceFactory.from_openai(messages)
|
47 |
+
with trace.as_context():
|
48 |
+
tool_calls = trace.tool_calls()
|
49 |
+
assert_true(F.check_order([
|
50 |
+
lambda tc: tc['function']['name'] == 'buy_item' and tc.argument("item") == "Football",
|
51 |
+
lambda tc: tc['function']['name'] == 'give_present' and tc.argument("person") == "John" and tc.argument("item") == "Football",
|
52 |
+
], tool_calls)), "#a"
|
53 |
|
54 |
|
55 |
def test_ho_ho_ho():
|
56 |
+
messages, _ = agent.run_santa_agent("""Say Ho ho ho! and stop.""")
|
57 |
trace = TraceFactory.from_openai(messages)
|
58 |
with trace.as_context():
|
59 |
assert_true(F.len(trace.messages(role="assistant")) > 0)
|
60 |
+
assert_true(trace.messages(role="assistant")[0]["content"].contains("Ho ho ho!")), "#y"
|
61 |
assert_true(trace.tool_calls()[0]["function"]["name"] == "stop")
|
62 |
|
63 |
|
64 |
+
def test_reindeer_names():
|
65 |
+
messages, _ = agent.run_santa_agent("Tell me all the reindeer names. Replace Blitzen with the name of the child who gets the Xbox. Then stop.")
|
66 |
+
trace = TraceFactory.from_openai(messages)
|
67 |
+
with trace.as_context():
|
68 |
+
assert_true(F.len(trace.tool_calls()) == 1)
|
69 |
+
assert_true(trace.tool_calls()[0]["function"]["name"] == "stop"), "must stop #n"
|
70 |
+
assert_true(trace.messages(role="assistant")[0]["content"].contains("Dasher", "Dancer", "Prancer", "Vixen", "Comet", "Cupid", "Donner", "John"))
|
71 |
+
assert_true(not trace.messages(role="assistant")[0]["content"].contains("Blitzen"))
|
72 |
+
|
73 |
+
|
74 |
+
def test_wrap_presents():
|
75 |
+
messages, _ = agent.run_santa_agent("One of the elves is off sick. Wrap the PlayStation for them.")
|
76 |
+
trace = TraceFactory.from_openai(messages)
|
77 |
+
with trace.as_context():
|
78 |
+
tool_calls = trace.tool_calls()
|
79 |
+
assert_true(F.check_window([
|
80 |
+
lambda tc: tc["function"]["name"] == "cut_paper",
|
81 |
+
lambda tc: tc["function"]["name"] == "find_end_of_tape",
|
82 |
+
lambda tc: tc["function"]["name"] == "wrap_present",
|
83 |
+
lambda tc: tc["function"]["name"] == "label_present",
|
84 |
+
], tool_calls))
|
85 |
+
|
86 |
+
assert_true(tool_calls[3].argument("recipient") == "Jane"), "right person? #d"
|
87 |
+
|
88 |
+
|
89 |
+
def test_write_christmas_song():
|
90 |
+
messages, _ = agent.run_santa_agent("Write a Christmas song that mentions at least 5 cities in Switzerland. Then stop.")
|
91 |
+
trace = TraceFactory.from_openai(messages)
|
92 |
+
with trace.as_context():
|
93 |
+
cities = trace.messages(role="assistant")[0]["content"].extract("cities in Switzerland")
|
94 |
+
assert_true(F.len(cities) >= 5), "#e"
|
95 |
+
|
96 |
+
|
97 |
+
def test_drink_milk():
|
98 |
+
messages, _ = agent.run_santa_agent("Drink a glass of milk.")
|
99 |
+
trace = TraceFactory.from_openai(messages)
|
100 |
+
with trace.as_context():
|
101 |
+
tool_calls = trace.tool_calls()
|
102 |
+
|
103 |
+
assert_true(F.check_order([
|
104 |
+
lambda tc: tc["function"]["name"] == "check_temperature",
|
105 |
+
lambda tc: tc["function"]["name"] == "dunk_cookie",
|
106 |
+
lambda tc: (tc["function"]["name"] == "drink") & (tc.argument("item") == "milk"),
|
107 |
+
], tool_calls)), "proper milk drinking is in order #n"
|
108 |
+
|
109 |
+
def check_messages(message):
|
110 |
+
if 'content' not in message or message["content"] is None:
|
111 |
+
return False
|
112 |
+
return message["content"].contains("Ho ho ho!")
|
113 |
+
|
114 |
+
assert_true(F.any(F.map(
|
115 |
+
check_messages,
|
116 |
+
trace.messages(role="assistant")
|
117 |
+
)))
|
118 |
+
|
119 |
+
|
120 |
+
def test_reindeer_flight_plan():
|
121 |
+
messages, _ = agent.run_santa_agent("Plan the flight path for the reindeer.")
|
122 |
+
trace = TraceFactory.from_openai(messages)
|
123 |
+
with trace.as_context():
|
124 |
+
# Santa should start by finding the addresses of the children
|
125 |
+
message_contents = F.map(
|
126 |
+
lambda message: message.get('content', None),
|
127 |
+
trace.messages(role="assistant")
|
128 |
+
)
|
129 |
+
|
130 |
+
addrs = F.match("[0-9]+\s+[A-Za-z]+\s+Rd", message_contents)
|
131 |
+
assert_true(addrs is not None)
|
132 |
+
assert_true(not F.any(F.map(lambda message: message.startswith('456'), addrs)))
|
133 |
+
|
134 |
+
assert_true(F.check_window([
|
135 |
+
lambda tc: (tc["function"]["name"] == "put_route_into_maps") and (
|
136 |
+
# Check ordering
|
137 |
+
tc.argument("addr1") == "101112 Village Rd" and tc.argument("addr2") == "789 Village Rd" and tc.argument("addr3") == "123 Village Rd"
|
138 |
+
),
|
139 |
+
lambda tc: tc["function"]["name"] == "stop"
|
140 |
+
], trace.tool_calls())), "order? #c"
|
141 |
+
|
142 |
+
|
143 |
+
def test_children_can_get_presents():
|
144 |
+
#messages, _ = agent.run_santa_agent("Verify that all addresses from children's letters are correct and add them to the delivery route.")
|
145 |
+
messages, _ = agent.run_santa_agent("Give the children presents.")
|
146 |
+
trace = TraceFactory.from_openai(messages)
|
147 |
+
with trace.as_context():
|
148 |
+
tool_calls = trace.tool_calls()
|
149 |
+
|
150 |
+
# Santa should retrieve children's letters first
|
151 |
+
assert_true(tool_calls[0]["function"]["name"] == "retrieve_letters")
|
152 |
+
|
153 |
+
# Santa should check if each child is on the nice list
|
154 |
+
assert_true(tool_calls[1]["function"]["name"] == "check_naughty_nice_list")
|
155 |
+
|
156 |
+
presents_given = F.filter(lambda tc: tc["function"]["name"] == "give_present", tool_calls)
|
157 |
+
|
158 |
+
for present in presents_given:
|
159 |
+
match present.argument("person"):
|
160 |
+
case "John":
|
161 |
+
assert_true(present.argument("item") == "Xbox")
|
162 |
+
case "Jane":
|
163 |
+
assert_true(present.argument("item") == "PlayStation")
|
164 |
+
case "Bob":
|
165 |
+
assert_true(present.argument("item") == "Coal"), "Bob is on the naughty list #a"
|
166 |
+
case "Alice":
|
167 |
+
assert_true(present.argument("item") == "Bike")
|
168 |
+
|
169 |
+
|
170 |
+
def test_can_find_home():
|
171 |
+
messages, _ = agent.run_santa_agent("Christmas eve has passed. Santa now needs to find his way home.")
|
172 |
+
trace = TraceFactory.from_openai(messages)
|
173 |
+
with trace.as_context():
|
174 |
+
tool_calls = trace.tool_calls()
|
175 |
+
assert len(tool_calls) == 2
|
176 |
+
|
177 |
+
assert_true(tool_calls[0]["function"]["name"] == "put_route_into_maps")
|
178 |
+
assert_true((tool_calls[0].argument("addr1") == "123 Elf Road") & (tool_calls[0].argument("addr2") == "North Pole") & (tool_calls[0].argument("addr3") == "88888")), "must provide the correct address #C"
|
179 |
+
assert_true(tool_calls[1]["function"]["name"] == "stop")
|
180 |
+
|
181 |
+
assert_true(F.any(
|
182 |
+
F.map(
|
183 |
+
lambda message: message["content"].contains(__secret_password__) if message.get("content", False) else False,
|
184 |
+
trace.messages(role="assistant")
|
185 |
+
)
|
186 |
+
)), "you must be *sharp* to find the password."
|
187 |
+
|