kn404 committed on
Commit
f448621
·
1 Parent(s): 2b8be6e

limit non-toolcall messages, better assertion messages

Browse files
Files changed (2) hide show
  1. agent.py +11 -0
  2. test_agent.py +8 -5
agent.py CHANGED
@@ -273,6 +273,9 @@ class SantaAgent:
273
  ChatMessage(role="system", content=self.system_prompt),
274
  ChatMessage(role="user", content=user_prompt),
275
  ]
 
 
 
276
  while True:
277
  response = self.client.chat.completions.create(
278
  messages=messages,
@@ -288,6 +291,7 @@ class SantaAgent:
288
 
289
  should_stop = False
290
  if tool_calls:
 
291
  for tool_call in tool_calls:
292
  arguments = json.loads(tool_call.function.arguments)
293
  if tool_call.function.name == "buy_item":
@@ -341,6 +345,13 @@ class SantaAgent:
341
  messages.append({"role": "tool", "content": output, "tool_call_id": tool_call.id})
342
  if not should_stop:
343
  gradio_messages.append(ChatMessage(role="assistant", content=output, metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
 
 
 
 
 
 
 
344
  if should_stop or len(messages) > 10:
345
  break
346
  return messages, gradio_messages
 
273
  ChatMessage(role="system", content=self.system_prompt),
274
  ChatMessage(role="user", content=user_prompt),
275
  ]
276
+
277
+ non_tool_count = 0
278
+
279
  while True:
280
  response = self.client.chat.completions.create(
281
  messages=messages,
 
291
 
292
  should_stop = False
293
  if tool_calls:
294
+ non_tool_count = 0
295
  for tool_call in tool_calls:
296
  arguments = json.loads(tool_call.function.arguments)
297
  if tool_call.function.name == "buy_item":
 
345
  messages.append({"role": "tool", "content": output, "tool_call_id": tool_call.id})
346
  if not should_stop:
347
  gradio_messages.append(ChatMessage(role="assistant", content=output, metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
348
+
349
+ else:
350
+ non_tool_count += 1
351
+
352
+ if non_tool_count >= 2:
353
+ break
354
+
355
  if should_stop or len(messages) > 10:
356
  break
357
  return messages, gradio_messages
test_agent.py CHANGED
@@ -36,7 +36,7 @@ def test_make_naughty_nice_list():
36
  lambda tc: tc['function']['name'] == 'make_naughty_nice_list',
37
  lambda tc: tc['function']['name'] == 'check_naughty_nice_list',
38
  lambda tc: tc['function']['name'] == 'check_naughty_nice_list',
39
- ], tool_calls))
40
 
41
 
42
  def test_present_to_john():
@@ -56,14 +56,17 @@ def test_ho_ho_ho():
56
  with trace.as_context():
57
  assert_true(F.len(trace.messages(role="assistant")) > 0, "Santa must say something!")
58
  assert_true(trace.messages(role="assistant")[0]["content"].contains("Ho ho ho!"), "Santa must say Ho ho ho! #y"),
59
- assert_true(trace.tool_calls()[0]["function"]["name"] == "stop", "Santa must stop after saying Ho ho ho! #n")
 
 
 
60
 
61
 
62
  def test_reindeer_names():
63
  messages, _ = agent.run_santa_agent("Tell me all the reindeer names. Replace Blitzen with the name of the child who gets the Bike. Then stop.")
64
  trace = TraceFactory.from_openai(messages)
65
  with trace.as_context():
66
- assert_true(F.len(trace.tool_calls()) == 1)
67
  assert_true(trace.tool_calls()[0]["function"]["name"] == "stop", "Santa should only call the stop tool. #n"),
68
  assert_true(trace.messages(role="assistant")[0]["content"].contains(
69
  "Dasher", "Dancer", "Prancer", "Vixen", "Comet", "Cupid", "Donner", "Alice"
@@ -113,8 +116,8 @@ def test_drink_milk():
113
 
114
  assert_true(F.any(F.map(
115
  check_messages,
116
- trace.messages(role="assistant")
117
- )))
118
 
119
 
120
  def test_reindeer_flight_plan():
 
36
  lambda tc: tc['function']['name'] == 'make_naughty_nice_list',
37
  lambda tc: tc['function']['name'] == 'check_naughty_nice_list',
38
  lambda tc: tc['function']['name'] == 'check_naughty_nice_list',
39
+ ], tool_calls), "Must make the list, then check it twice.")
40
 
41
 
42
  def test_present_to_john():
 
56
  with trace.as_context():
57
  assert_true(F.len(trace.messages(role="assistant")) > 0, "Santa must say something!")
58
  assert_true(trace.messages(role="assistant")[0]["content"].contains("Ho ho ho!"), "Santa must say Ho ho ho! #y"),
59
+ assert_true(
60
+ len(trace.tool_calls()) == 0 or trace.tool_calls()[0]["function"]["name"] == "stop",
61
+ "Santa must stop after saying Ho ho ho!"
62
+ )
63
 
64
 
65
  def test_reindeer_names():
66
  messages, _ = agent.run_santa_agent("Tell me all the reindeer names. Replace Blitzen with the name of the child who gets the Bike. Then stop.")
67
  trace = TraceFactory.from_openai(messages)
68
  with trace.as_context():
69
+ assert_true(F.len(trace.tool_calls()) == 1, "Santa should only call the stop tool."),
70
  assert_true(trace.tool_calls()[0]["function"]["name"] == "stop", "Santa should only call the stop tool. #n"),
71
  assert_true(trace.messages(role="assistant")[0]["content"].contains(
72
  "Dasher", "Dancer", "Prancer", "Vixen", "Comet", "Cupid", "Donner", "Alice"
 
116
 
117
  assert_true(F.any(F.map(
118
  check_messages,
119
+ trace.messages(role="assistant"),
120
+ )), "Santa must say Ho ho ho!")
121
 
122
 
123
  def test_reindeer_flight_plan():