Limit consecutive non-tool-call messages; add better assertion messages
Browse files
- agent.py +11 -0
- test_agent.py +8 -5
agent.py
CHANGED
@@ -273,6 +273,9 @@ class SantaAgent:
|
|
273 |
ChatMessage(role="system", content=self.system_prompt),
|
274 |
ChatMessage(role="user", content=user_prompt),
|
275 |
]
|
|
|
|
|
|
|
276 |
while True:
|
277 |
response = self.client.chat.completions.create(
|
278 |
messages=messages,
|
@@ -288,6 +291,7 @@ class SantaAgent:
|
|
288 |
|
289 |
should_stop = False
|
290 |
if tool_calls:
|
|
|
291 |
for tool_call in tool_calls:
|
292 |
arguments = json.loads(tool_call.function.arguments)
|
293 |
if tool_call.function.name == "buy_item":
|
@@ -341,6 +345,13 @@ class SantaAgent:
|
|
341 |
messages.append({"role": "tool", "content": output, "tool_call_id": tool_call.id})
|
342 |
if not should_stop:
|
343 |
gradio_messages.append(ChatMessage(role="assistant", content=output, metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
344 |
if should_stop or len(messages) > 10:
|
345 |
break
|
346 |
return messages, gradio_messages
|
|
|
273 |
ChatMessage(role="system", content=self.system_prompt),
|
274 |
ChatMessage(role="user", content=user_prompt),
|
275 |
]
|
276 |
+
|
277 |
+
non_tool_count = 0
|
278 |
+
|
279 |
while True:
|
280 |
response = self.client.chat.completions.create(
|
281 |
messages=messages,
|
|
|
291 |
|
292 |
should_stop = False
|
293 |
if tool_calls:
|
294 |
+
non_tool_count = 0
|
295 |
for tool_call in tool_calls:
|
296 |
arguments = json.loads(tool_call.function.arguments)
|
297 |
if tool_call.function.name == "buy_item":
|
|
|
345 |
messages.append({"role": "tool", "content": output, "tool_call_id": tool_call.id})
|
346 |
if not should_stop:
|
347 |
gradio_messages.append(ChatMessage(role="assistant", content=output, metadata={"title": f"🔧 Tool Output: {tool_call.function.name}"}))
|
348 |
+
|
349 |
+
else:
|
350 |
+
non_tool_count += 1
|
351 |
+
|
352 |
+
if non_tool_count >= 2:
|
353 |
+
break
|
354 |
+
|
355 |
if should_stop or len(messages) > 10:
|
356 |
break
|
357 |
return messages, gradio_messages
|
test_agent.py
CHANGED
@@ -36,7 +36,7 @@ def test_make_naughty_nice_list():
|
|
36 |
lambda tc: tc['function']['name'] == 'make_naughty_nice_list',
|
37 |
lambda tc: tc['function']['name'] == 'check_naughty_nice_list',
|
38 |
lambda tc: tc['function']['name'] == 'check_naughty_nice_list',
|
39 |
-
], tool_calls))
|
40 |
|
41 |
|
42 |
def test_present_to_john():
|
@@ -56,14 +56,17 @@ def test_ho_ho_ho():
|
|
56 |
with trace.as_context():
|
57 |
assert_true(F.len(trace.messages(role="assistant")) > 0, "Santa must say something!")
|
58 |
assert_true(trace.messages(role="assistant")[0]["content"].contains("Ho ho ho!"), "Santa must say Ho ho ho! #y"),
|
59 |
-
assert_true(
|
|
|
|
|
|
|
60 |
|
61 |
|
62 |
def test_reindeer_names():
|
63 |
messages, _ = agent.run_santa_agent("Tell me all the reindeer names. Replace Blitzen with the name of the child who gets the Bike. Then stop.")
|
64 |
trace = TraceFactory.from_openai(messages)
|
65 |
with trace.as_context():
|
66 |
-
assert_true(F.len(trace.tool_calls()) == 1)
|
67 |
assert_true(trace.tool_calls()[0]["function"]["name"] == "stop", "Santa should only call the stop tool. #n"),
|
68 |
assert_true(trace.messages(role="assistant")[0]["content"].contains(
|
69 |
"Dasher", "Dancer", "Prancer", "Vixen", "Comet", "Cupid", "Donner", "Alice"
|
@@ -113,8 +116,8 @@ def test_drink_milk():
|
|
113 |
|
114 |
assert_true(F.any(F.map(
|
115 |
check_messages,
|
116 |
-
trace.messages(role="assistant")
|
117 |
-
)))
|
118 |
|
119 |
|
120 |
def test_reindeer_flight_plan():
|
|
|
36 |
lambda tc: tc['function']['name'] == 'make_naughty_nice_list',
|
37 |
lambda tc: tc['function']['name'] == 'check_naughty_nice_list',
|
38 |
lambda tc: tc['function']['name'] == 'check_naughty_nice_list',
|
39 |
+
], tool_calls), "Must make the list, then check it twice.")
|
40 |
|
41 |
|
42 |
def test_present_to_john():
|
|
|
56 |
with trace.as_context():
|
57 |
assert_true(F.len(trace.messages(role="assistant")) > 0, "Santa must say something!")
|
58 |
assert_true(trace.messages(role="assistant")[0]["content"].contains("Ho ho ho!"), "Santa must say Ho ho ho! #y"),
|
59 |
+
assert_true(
|
60 |
+
len(trace.tool_calls()) == 0 or trace.tool_calls()[0]["function"]["name"] == "stop",
|
61 |
+
"Santa must stop after saying Ho ho ho!"
|
62 |
+
)
|
63 |
|
64 |
|
65 |
def test_reindeer_names():
|
66 |
messages, _ = agent.run_santa_agent("Tell me all the reindeer names. Replace Blitzen with the name of the child who gets the Bike. Then stop.")
|
67 |
trace = TraceFactory.from_openai(messages)
|
68 |
with trace.as_context():
|
69 |
+
assert_true(F.len(trace.tool_calls()) == 1, "Santa should only call the stop tool."),
|
70 |
assert_true(trace.tool_calls()[0]["function"]["name"] == "stop", "Santa should only call the stop tool. #n"),
|
71 |
assert_true(trace.messages(role="assistant")[0]["content"].contains(
|
72 |
"Dasher", "Dancer", "Prancer", "Vixen", "Comet", "Cupid", "Donner", "Alice"
|
|
|
116 |
|
117 |
assert_true(F.any(F.map(
|
118 |
check_messages,
|
119 |
+
trace.messages(role="assistant"),
|
120 |
+
)), "Santa must say Ho ho ho!")
|
121 |
|
122 |
|
123 |
def test_reindeer_flight_plan():
|