kn404 committed on
Commit
c80f06e
·
1 Parent(s): 86d8adf

added tests

Browse files
Files changed (3) hide show
  1. agent.py +223 -0
  2. app.py +10 -9
  3. test_agent.py +154 -16
agent.py CHANGED
@@ -46,6 +46,131 @@ class SantaAgent:
46
  }
47
  }
48
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  {
50
  "type": "function",
51
  "function": {
@@ -62,8 +187,69 @@ class SantaAgent:
62
  def give_present(self, person: str, item: str):
63
  """Give a present to a person."""
64
  return f"Gave {item} to {person}."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  def stop(self):
 
67
  return "STOP"
68
 
69
  def mock_run_santa_agent(self):
@@ -112,6 +298,43 @@ class SantaAgent:
112
  person, item = arguments["person"], arguments["item"]
113
  gradio_messages.append(ChatMessage(role="assistant", content=f"give_present({person}, {item})", metadata={"title": "🔧 Tool Call: give_present"}))
114
  output = self.give_present(person, item)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  elif tool_call.function.name == "stop":
116
  output = self.stop()
117
  should_stop = True
 
46
  }
47
  }
48
  },
49
+ {
50
+ "type": "function",
51
+ "function": {
52
+ "name": "make_naughty_nice_list",
53
+ "description": "Make a list of children that have been naughty and nice. This function cannot make other lists.",
54
+ }
55
+ },
56
+ {
57
+ "type": "function",
58
+ "function": {
59
+ "name": "check_naughty_nice_list",
60
+ "description": "Check which children have been naughty and nice. This is the only information in the list.",
61
+ }
62
+ },
63
+ {
64
+ "type": "function",
65
+ "function": {
66
+ "name": "cut_paper",
67
+ "description": "Cut wrapping paper to wrap a present.",
68
+ }
69
+ },
70
+ {
71
+ "type": "function",
72
+ "function": {
73
+ "name": "find_end_of_tape",
74
+ "description": "Find the end of the tape to wrap a present.",
75
+ }
76
+ },
77
+ {
78
+ "type": "function",
79
+ "function": {
80
+ "name": "wrap_present",
81
+ "description": "Wrap a present.",
82
+ }
83
+ },
84
+ {
85
+ "type": "function",
86
+ "function": {
87
+ "name": "label_present",
88
+ "description": "Label a present with the recipient's name.",
89
+ "parameters": {
90
+ "type": "object",
91
+ "properties": {
92
+ "recipient": {
93
+ "type": "string",
94
+ "description": "The name of the recipient."
95
+ }
96
+ },
97
+ "required": ["recipient"]
98
+ }
99
+ }
100
+ },
101
+ {
102
+ "type": "function",
103
+ "function": {
104
+ "name": "retrieve_letters",
105
+ "description": "Retrieve letters from children."
106
+ }
107
+ },
108
+ {
109
+ "type": "function",
110
+ "function": {
111
+ "name": "check_temperature",
112
+ "description": "Use this tool to check the temperature of an object.",
113
+ "parameters": {
114
+ "type": "object",
115
+ "properties": {
116
+ "object": {
117
+ "type": "string",
118
+ "description": "The object to check the temperature of."
119
+ }
120
+ },
121
+ "required": ["object"]
122
+ }
123
+ }
124
+ },
125
+ {
126
+ "type": "function",
127
+ "function": {
128
+ "name": "dunk_cookie",
129
+ "description": "Dunk a cookie in milk."
130
+ }
131
+ },
132
+ {
133
+ "type": "function",
134
+ "function": {
135
+ "name": "drink",
136
+ "description": "Drink an item.",
137
+ "parameters": {
138
+ "type": "object",
139
+ "properties": {
140
+ "item": {
141
+ "type": "string",
142
+ "description": "The item to drink."
143
+ }
144
+ },
145
+ "required": ["item"]
146
+ }
147
+ }
148
+ },
149
+ {
150
+ "type": "function",
151
+ "function": {
152
+ "name": "put_route_into_maps",
153
+ "description": "Put a route into Google Maps.",
154
+ "parameters": {
155
+ "type": "object",
156
+ "properties": {
157
+ "addr1": {
158
+ "type": "string",
159
+ "description": "First Address to Visit."
160
+ },
161
+ "addr2": {
162
+ "type": "string",
163
+ "description": "Second Address to Visit."
164
+ },
165
+ "addr3": {
166
+ "type": "string",
167
+ "description": "Third Address to Visit."
168
+ }
169
+ },
170
+ "required": ["addr1", "addr2", "addr3"]
171
+ }
172
+ }
173
+ },
174
  {
175
  "type": "function",
176
  "function": {
 
187
  def give_present(self, person: str, item: str):
188
  """Give a present to a person."""
189
  return f"Gave {item} to {person}."
190
+
191
+ def make_naughty_nice_list(self):
192
+ """Make a list of all the children that have been naughty and nice."""
193
+ return "Made a list."
194
+
195
+ def check_naughty_nice_list(self):
196
+ """Check a list of items to see if they are naughty or nice."""
197
+ return json.dumps({
198
+ "children": [
199
+ {"name": "Alice", "status": "nice"},
200
+ {"name": "Bob", "status": "naughty"},
201
+ {"name": "John", "status": "nice"},
202
+ {"name": "Jane", "status": "nice"},
203
+ ]
204
+ })
205
+
206
+ def cut_paper(self):
207
+ """Cut wrapping paper to wrap a present."""
208
+ return "Cut the wrapping paper."
209
+
210
+ def find_end_of_tape(self):
211
+ """Find the end of the tape to wrap a present."""
212
+ return "Found the end of the tape."
213
+
214
+ def wrap_present(self):
215
+ """Wrap a present."""
216
+ return "Wrapped the present."
217
+
218
+ def label_present(self, recipient: str):
219
+ """Label a present with the recipient's name."""
220
+ return f"Labeled the present for {recipient}."
221
+
222
+ def check_temperature(self, object: str):
223
+ """Check the temperature of the object"""
224
+ return f"The temperature of the {object} is just right."
225
+
226
+ def dunk_cookie(self):
227
+ """Dunk a cookie in milk."""
228
+ return "Dunked a cookie in milk."
229
+
230
+ def drink(self, item: str):
231
+ """Drink an item."""
232
+ return f"Drank {item}."
233
+
234
+ def retrieve_letters(self):
235
+ """Retrieve letters from children."""
236
+ return json.dumps({
237
+ "letters": [
238
+ {"text": "Dear Santa, I would like a Bike for Christmas.", "sender_address": "123 Village Rd", "sender_name": "Alice"},
239
+ {"text": "Dear Santa, I would like a doll for Christmas.", "sender_address": "456 Village Rd", "sender_name": "Bob"},
240
+ {"text": "Dear Santa, I would like a Xbox for Christmas.", "sender_address": "789 Village Rd", "sender_name": "John"},
241
+ {"text": "Dear Santa, I would like a PlayStation for Christmas.", "sender_address": "101112 Village Rd", "sender_name": "Jane"},
242
+ ]
243
+ })
244
+
245
+ def put_route_into_maps(self, addr1: str, addr2: str, addr3: str):
246
+ """Put a route into Google Maps."""
247
+ return json.dumps({
248
+ 'route': [addr1, addr2, addr3]
249
+ })
250
 
251
  def stop(self):
252
+ """Use this tool if you are finished and want to stop."""
253
  return "STOP"
254
 
255
  def mock_run_santa_agent(self):
 
298
  person, item = arguments["person"], arguments["item"]
299
  gradio_messages.append(ChatMessage(role="assistant", content=f"give_present({person}, {item})", metadata={"title": "🔧 Tool Call: give_present"}))
300
  output = self.give_present(person, item)
301
+ elif tool_call.function.name == "make_naughty_nice_list":
302
+ output = self.make_naughty_nice_list()
303
+ gradio_messages.append(ChatMessage(role="assistant", content="make_naughty_nice_list", metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
304
+ elif tool_call.function.name == "check_naughty_nice_list":
305
+ output = self.check_naughty_nice_list()
306
+ gradio_messages.append(ChatMessage(role="assistant", content="check_naughty_nice_list", metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
307
+ elif tool_call.function.name == "cut_paper":
308
+ output = self.cut_paper()
309
+ gradio_messages.append(ChatMessage(role="assistant", content="cut_paper", metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
310
+ elif tool_call.function.name == "find_end_of_tape":
311
+ output = self.find_end_of_tape()
312
+ gradio_messages.append(ChatMessage(role="assistant", content="find_end_of_tape", metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
313
+ elif tool_call.function.name == "wrap_present":
314
+ output = self.wrap_present()
315
+ gradio_messages.append(ChatMessage(role="assistant", content="wrap_present", metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
316
+ elif tool_call.function.name == "label_present":
317
+ recipient = arguments["recipient"]
318
+ output = self.label_present(recipient)
319
+ gradio_messages.append(ChatMessage(role="assistant", content=f"label_present({recipient})", metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
320
+ elif tool_call.function.name == "retrieve_letters":
321
+ output = self.retrieve_letters()
322
+ gradio_messages.append(ChatMessage(role="assistant", content="retrieve_letters", metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
323
+ elif tool_call.function.name == "check_temperature":
324
+ object = arguments["object"]
325
+ output = self.check_temperature(object)
326
+ gradio_messages.append(ChatMessage(role="assistant", content=f"check_temperature({object})", metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
327
+ elif tool_call.function.name == "dunk_cookie":
328
+ output = self.dunk_cookie()
329
+ gradio_messages.append(ChatMessage(role="assistant", content="dunk_cookie", metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
330
+ elif tool_call.function.name == "drink":
331
+ item = arguments["item"]
332
+ output = self.drink(item)
333
+ gradio_messages.append(ChatMessage(role="assistant", content=f"drink({item})", metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
334
+ elif tool_call.function.name == "put_route_into_maps":
335
+ addr1, addr2, addr3 = arguments["addr1"], arguments["addr2"], arguments["addr3"]
336
+ output = self.put_route_into_maps(addr1, addr2, addr3)
337
+ gradio_messages.append(ChatMessage(role="assistant", content=f"put_route_into_maps({addr1}, {addr2}, {addr3})", metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
338
  elif tool_call.function.name == "stop":
339
  output = self.stop()
340
  should_stop = True
app.py CHANGED
@@ -7,11 +7,12 @@ import subprocess
7
 
8
 
9
  INITIAL_SYTSTEM_PROMPT = "You are a Santa Claus. Buy presents and deliver them to the children."
 
10
  INITIAL_CHABOT = [
11
- {"role": "user", "content": "Could you please deliver an Xbox to John?"},
12
  ]
13
  INITIAL_STATE = ""
14
- TOTAL_TESTS = 2
15
 
16
 
17
  agent = SantaAgent(INITIAL_SYTSTEM_PROMPT)
@@ -23,17 +24,16 @@ with open("styling.css", "r") as f:
23
 
24
  # Define helper functions
25
  def run_agent_with_state(user_prompt, history, invariant_api_key, state, is_example=False):
26
- messages, gradio_messages = agent.run_santa_agent(user_prompt)
27
 
28
  if not invariant_api_key.startswith("inv"):
29
- return gradio_messages, "Please enter a valid Invariant API key to get the score!", state
30
-
31
- agent_params = {"system_prompt": user_prompt}
32
 
33
- return gradio_messages, "Testing in progress...", [agent_params, invariant_api_key]
34
 
35
 
36
  def update_run_button(url):
 
37
  return gr.update(link=url, visible=True, interactive=True)
38
 
39
 
@@ -43,7 +43,7 @@ def run_testing(user_prompt, invariant_api_key):
43
 
44
  agent_params = {"system_prompt": user_prompt}
45
 
46
- yield 'Starting tests...', '', 'button-loading'
47
  env={
48
  "INVARIANT_API_KEY": invariant_api_key,
49
  "OPENAI_API_KEY": os.environ["OPENAI_API_KEY"],
@@ -67,6 +67,7 @@ def run_testing(user_prompt, invariant_api_key):
67
 
68
  # Iterate over the output lines as they are produced
69
  for line in process.stdout:
 
70
  if line.startswith("__special_formatted_output__:"):
71
  yield 'Tests: ' + line.split(":")[1].strip() + f'/{TOTAL_TESTS} Done.', '', 'button-loading'
72
 
@@ -184,7 +185,7 @@ with gr.Blocks(
184
  outputs=[test_progress_state, test_url_state, test_button_class_state],
185
  )
186
  reset_button.click(reset_state, None, [input, chatbot, test_url_state, run_button])
187
- submit_button.click(run_agent_with_state, [input, chatbot, invariant_api_key, test_url_state], [chatbot, run_button, test_url_state])
188
  test_progress_state.change(lambda ts: ts, test_progress_state, run_button)
189
  test_button_class_state.change(lambda ts: gr.update(elem_classes=ts), test_button_class_state, run_button)
190
 
 
7
 
8
 
9
  INITIAL_SYTSTEM_PROMPT = "You are a Santa Claus. Buy presents and deliver them to the children."
10
+ EXAMPLE_PROMPT = "Make a naughty and nice list."
11
  INITIAL_CHABOT = [
12
+ {"role": "user", "content": EXAMPLE_PROMPT},
13
  ]
14
  INITIAL_STATE = ""
15
+ TOTAL_TESTS = 10
16
 
17
 
18
  agent = SantaAgent(INITIAL_SYTSTEM_PROMPT)
 
24
 
25
  # Define helper functions
26
def run_agent_with_state(user_prompt, history, invariant_api_key, state, is_example=False):
    """Run the Santa agent on the example prompt and return its chat transcript.

    Only ``invariant_api_key`` is consulted, and only for a prefix check whose
    outcome does not change the result. ``user_prompt``, ``history``,
    ``state`` and ``is_example`` are accepted but not used here.

    Returns:
        The Gradio chat messages produced by ``agent.run_santa_agent``.
    """
    # NOTE(review): the agent is always driven by EXAMPLE_PROMPT rather than
    # by ``user_prompt`` -- confirm this is intended.
    _, chat_log = agent.run_santa_agent(EXAMPLE_PROMPT)

    # Both outcomes of the key check return the same transcript. The call is
    # kept so a key object without ``startswith`` still fails as it did before.
    invariant_api_key.startswith("inv")
    return chat_log
33
 
34
 
35
def update_run_button(url):
    """Print ``url`` and return a Gradio update that makes the button a visible, interactive link to it."""
    print('url', url)
    button_update = gr.update(link=url, visible=True, interactive=True)
    return button_update
38
 
39
 
 
43
 
44
  agent_params = {"system_prompt": user_prompt}
45
 
46
+ yield f'Tests: 0/{TOTAL_TESTS} Done.', '', 'button-loading'
47
  env={
48
  "INVARIANT_API_KEY": invariant_api_key,
49
  "OPENAI_API_KEY": os.environ["OPENAI_API_KEY"],
 
67
 
68
  # Iterate over the output lines as they are produced
69
  for line in process.stdout:
70
+ print(line, end="")
71
  if line.startswith("__special_formatted_output__:"):
72
  yield 'Tests: ' + line.split(":")[1].strip() + f'/{TOTAL_TESTS} Done.', '', 'button-loading'
73
 
 
185
  outputs=[test_progress_state, test_url_state, test_button_class_state],
186
  )
187
  reset_button.click(reset_state, None, [input, chatbot, test_url_state, run_button])
188
+ submit_button.click(run_agent_with_state, [input, chatbot, invariant_api_key, test_url_state], [chatbot])
189
  test_progress_state.change(lambda ts: ts, test_progress_state, run_button)
190
  test_button_class_state.change(lambda ts: gr.update(elem_classes=ts), test_button_class_state, run_button)
191
 
test_agent.py CHANGED
@@ -9,6 +9,9 @@ agent = SantaAgent(system_prompt)
9
  global count
10
  count = 0
11
 
 
 
 
12
  @pytest.fixture(autouse=True)
13
  def run_around_tests():
14
  global count
@@ -19,31 +22,166 @@ def run_around_tests():
19
  print(f"\n__special_formatted_output__:{count}", flush=True)
20
 
21
 
22
- def test_xbox_to_john():
23
- messages, _ = agent.run_santa_agent("Deliver Xbox to John.")
24
  trace = TraceFactory.from_openai(messages)
25
  with trace.as_context():
26
  tool_calls = trace.tool_calls()
27
- assert_true(tool_calls[0]["function"]["name"] == "buy_item")
28
- assert_true(tool_calls[0].argument("item") == "Xbox")
29
- assert_true(tool_calls[1]["function"]["name"] == "give_present")
30
- assert_true(tool_calls[1].argument("person") == "John")
31
- assert_true(tool_calls[1].argument("item") == "Xbox")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
 
34
  def test_ho_ho_ho():
35
- messages, _ = agent.run_santa_agent("""Just reply with: "Ho ho ho!" and stop""")
36
  trace = TraceFactory.from_openai(messages)
37
  with trace.as_context():
38
  assert_true(F.len(trace.messages(role="assistant")) > 0)
39
- assert_true(trace.messages(role="assistant")[0]["content"].contains("Ho ho ho!"))
40
  assert_true(trace.tool_calls()[0]["function"]["name"] == "stop")
41
 
42
 
43
- # @pytest.mark.parametrize("country", ["Finland", "Iceland"])
44
- # def test_cities(country):
45
- # messages, _ = agent.run_santa_agent(f"""Write a Christmas song that mentions exactly 5 cities in {country}.""")
46
- # trace = TraceFactory.from_openai(messages)
47
- # with trace.as_context():
48
- # cities = trace.messages(role="assistant")[0]["content"].extract(f"cities in {country}")
49
- # assert_true(F.len(cities) == 5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  global count
10
  count = 0
11
 
12
+ __secret_password__ = "Candycanes"
13
+
14
+
15
  @pytest.fixture(autouse=True)
16
  def run_around_tests():
17
  global count
 
22
  print(f"\n__special_formatted_output__:{count}", flush=True)
23
 
24
 
25
def test_make_naughty_nice_list():
    """Check the tool-call sequence produced for "Make a naughty and nice list.".

    Expects exactly four tool calls, the first being
    ``make_naughty_nice_list``, exactly two ``check_naughty_nice_list`` calls,
    and the order make -> check -> check.
    """
    def calls(tool_name):
        # Predicate factory: binds the tool name per call, which avoids
        # late-binding surprises.
        return lambda tc: tc['function']['name'] == tool_name

    messages, _ = agent.run_santa_agent("Make a naughty and nice list.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        assert_true(F.len(tool_calls) == 4)
        # NOTE(review): the trailing string forms a discarded tuple; it is not
        # passed to assert_true as a message -- confirm whether that is intended.
        assert_true(tool_calls[0]["function"]["name"] == "make_naughty_nice_list"), "must make a list! #s"

        check_list_count = sum(
            1 if tc["function"]["name"] == "check_naughty_nice_list" else 0
            for tc in tool_calls
        )
        assert_true(check_list_count == 2)

        expected_order = [
            calls('make_naughty_nice_list'),
            calls('check_naughty_nice_list'),
            calls('check_naughty_nice_list'),
        ]
        assert_true(F.check_order(expected_order, tool_calls))
42
+
43
+
44
def test_present_to_john():
    """Check that "Deliver John his secret wish." buys a Football and then gives it to John."""
    def bought_football(tc):
        return tc['function']['name'] == 'buy_item' and tc.argument("item") == "Football"

    def gave_football_to_john(tc):
        return tc['function']['name'] == 'give_present' and tc.argument("person") == "John" and tc.argument("item") == "Football"

    messages, _ = agent.run_santa_agent("Deliver John his secret wish.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        # NOTE(review): the trailing string is a discarded tuple element, not
        # an assertion message.
        assert_true(F.check_order([bought_football, gave_football_to_john], tool_calls)), "#a"
53
 
54
 
55
def test_ho_ho_ho():
    """Check that the agent says "Ho ho ho!" in its first assistant message and that its first tool call is ``stop``."""
    messages, _ = agent.run_santa_agent("""Say Ho ho ho! and stop.""")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        # The length check comes first so an empty reply list is reported
        # before the first reply is indexed.
        assert_true(F.len(trace.messages(role="assistant")) > 0)
        first_reply = trace.messages(role="assistant")[0]
        # NOTE(review): the trailing string is a discarded tuple element, not
        # an assertion message.
        assert_true(first_reply["content"].contains("Ho ho ho!")), "#y"
        first_call = trace.tool_calls()[0]
        assert_true(first_call["function"]["name"] == "stop")
62
 
63
 
64
def test_reindeer_names():
    """Check the reindeer roll call with Blitzen swapped for the Xbox recipient.

    Expects a single ``stop`` tool call, a first assistant message that
    contains the seven other reindeer names plus "John", and no mention of
    "Blitzen" in that message.
    """
    messages, _ = agent.run_santa_agent("Tell me all the reindeer names. Replace Blitzen with the name of the child who gets the Xbox. Then stop.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        assert_true(F.len(trace.tool_calls()) == 1)
        # NOTE(review): the trailing string is a discarded tuple element, not
        # an assertion message.
        assert_true(trace.tool_calls()[0]["function"]["name"] == "stop"), "must stop #n"

        expected_names = ("Dasher", "Dancer", "Prancer", "Vixen", "Comet", "Cupid", "Donner", "John")
        first_reply = trace.messages(role="assistant")[0]
        assert_true(first_reply["content"].contains(*expected_names))
        blitzen_mentioned = trace.messages(role="assistant")[0]["content"].contains("Blitzen")
        assert_true(not blitzen_mentioned)
72
+
73
+
74
def test_wrap_presents():
    """Check the wrapping workflow for the PlayStation.

    Expects the window cut_paper -> find_end_of_tape -> wrap_present ->
    label_present among the tool calls, and that the fourth tool call labels
    the present for "Jane".
    """
    def calls(tool_name):
        # Predicate factory: binds the tool name per call.
        return lambda tc: tc["function"]["name"] == tool_name

    messages, _ = agent.run_santa_agent("One of the elves is off sick. Wrap the PlayStation for them.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        wrapping_steps = [
            calls("cut_paper"),
            calls("find_end_of_tape"),
            calls("wrap_present"),
            calls("label_present"),
        ]
        assert_true(F.check_window(wrapping_steps, tool_calls))

        # NOTE(review): the trailing string is a discarded tuple element, not
        # an assertion message.
        assert_true(tool_calls[3].argument("recipient") == "Jane"), "right person? #d"
87
+
88
+
89
def test_write_christmas_song():
    """Check that the first assistant message yields at least five "cities in Switzerland"."""
    messages, _ = agent.run_santa_agent("Write a Christmas song that mentions at least 5 cities in Switzerland. Then stop.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        song = trace.messages(role="assistant")[0]["content"]
        # presumably extract() returns the matching entities found in the
        # text; verify against the testing library.
        cities = song.extract("cities in Switzerland")
        # NOTE(review): the trailing string is a discarded tuple element, not
        # an assertion message.
        assert_true(F.len(cities) >= 5), "#e"
95
+
96
+
97
def test_drink_milk():
    """Check the milk-drinking ritual.

    Expects the tool calls check_temperature -> dunk_cookie -> drink(item="milk")
    in that order, and at least one assistant message containing "Ho ho ho!".
    """
    def says_ho_ho_ho(message):
        # Messages with no content, or None content, never match.
        has_text = 'content' in message and message["content"] is not None
        if not has_text:
            return False
        return message["content"].contains("Ho ho ho!")

    messages, _ = agent.run_santa_agent("Drink a glass of milk.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()

        ritual = [
            lambda tc: tc["function"]["name"] == "check_temperature",
            lambda tc: tc["function"]["name"] == "dunk_cookie",
            lambda tc: (tc["function"]["name"] == "drink") & (tc.argument("item") == "milk"),
        ]
        # NOTE(review): the trailing string is a discarded tuple element, not
        # an assertion message.
        assert_true(F.check_order(ritual, tool_calls)), "proper milk drinking is in order #n"

        greetings = F.map(says_ho_ho_ho, trace.messages(role="assistant"))
        assert_true(F.any(greetings))
118
+
119
+
120
def test_reindeer_flight_plan():
    """Check the route planned for "Plan the flight path for the reindeer.".

    Expects that:
      * street addresses of the form "<number> <word> Rd" can be matched in
        the assistant messages, and none of them starts with "456";
      * the tool calls contain the window
        put_route_into_maps("101112 Village Rd", "789 Village Rd",
        "123 Village Rd") -> stop.
    """
    messages, _ = agent.run_santa_agent("Plan the flight path for the reindeer.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        # Santa should start by finding the addresses of the children
        message_contents = F.map(
            lambda message: message.get('content', None),
            trace.messages(role="assistant")
        )

        # Raw string: "\s" in a plain literal is an invalid escape sequence
        # and warns on current Python. The pattern value itself is unchanged.
        addrs = F.match(r"[0-9]+\s+[A-Za-z]+\s+Rd", message_contents)
        assert_true(addrs is not None)
        assert_true(not F.any(F.map(lambda message: message.startswith('456'), addrs)))

        # NOTE(review): the trailing string is a discarded tuple element, not
        # an assertion message.
        assert_true(F.check_window([
            lambda tc: (tc["function"]["name"] == "put_route_into_maps") and (
                # Check ordering
                tc.argument("addr1") == "101112 Village Rd" and tc.argument("addr2") == "789 Village Rd" and tc.argument("addr3") == "123 Village Rd"
            ),
            lambda tc: tc["function"]["name"] == "stop"
        ], trace.tool_calls())), "order? #c"
141
+
142
+
143
def test_children_can_get_presents():
    """Check present delivery for "Give the children presents.".

    Expects the first tool call to be ``retrieve_letters`` and the second to
    be ``check_naughty_nice_list``. Every ``give_present`` call to John, Jane,
    Bob or Alice must then carry the expected item (Xbox, PlayStation, Coal,
    Bike). Presents for any other person are not checked.
    """
    #messages, _ = agent.run_santa_agent("Verify that all addresses from children's letters are correct and add them to the delivery route.")
    messages, _ = agent.run_santa_agent("Give the children presents.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()

        # Santa should retrieve children's letters first
        assert_true(tool_calls[0]["function"]["name"] == "retrieve_letters")

        # Santa should check if each child is on the nice list
        assert_true(tool_calls[1]["function"]["name"] == "check_naughty_nice_list")

        presents_given = F.filter(lambda tc: tc["function"]["name"] == "give_present", tool_calls)

        for present in presents_given:
            recipient = present.argument("person")
            if recipient == "John":
                assert_true(present.argument("item") == "Xbox")
            elif recipient == "Jane":
                assert_true(present.argument("item") == "PlayStation")
            elif recipient == "Bob":
                # NOTE(review): the trailing string is a discarded tuple
                # element, not an assertion message.
                assert_true(present.argument("item") == "Coal"), "Bob is on the naughty list #a"
            elif recipient == "Alice":
                assert_true(present.argument("item") == "Bike")
168
+
169
+
170
def test_can_find_home():
    """Check that Santa routes himself home and reveals the password.

    Expects exactly two tool calls: ``put_route_into_maps`` with
    addr1="123 Elf Road", addr2="North Pole", addr3="88888", followed by
    ``stop``. At least one assistant message with content must also contain
    the module-level ``__secret_password__``.
    """
    messages, _ = agent.run_santa_agent("Christmas eve has passed. Santa now needs to find his way home.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        # Uses the library assertion (same form as the other tests) instead of
        # a bare ``assert``, so the failure goes through assert_true like
        # every other check in this file.
        assert_true(F.len(tool_calls) == 2)

        assert_true(tool_calls[0]["function"]["name"] == "put_route_into_maps")
        # NOTE(review): the trailing string is a discarded tuple element, not
        # an assertion message.
        assert_true((tool_calls[0].argument("addr1") == "123 Elf Road") & (tool_calls[0].argument("addr2") == "North Pole") & (tool_calls[0].argument("addr3") == "88888")), "must provide the correct address #C"
        assert_true(tool_calls[1]["function"]["name"] == "stop")

        # Messages with missing or falsy content are treated as not
        # containing the password.
        assert_true(F.any(
            F.map(
                lambda message: message["content"].contains(__secret_password__) if message.get("content", False) else False,
                trace.messages(role="assistant")
            )
        )), "you must be *sharp* to find the password."
187
+