christmas-challenge / test_agent.py
kn404's picture
small tweaks
9d58bb7
raw
history blame
8.06 kB
import invariant.testing.functional as F
import pytest
from invariant.testing import TraceFactory, get_agent_param, assert_true
from agent import SantaAgent
# Shared test setup: the agent under test is constructed once, at import time,
# from the "system_prompt" agent parameter and reused by every test below.
system_prompt = get_agent_param("system_prompt")
agent = SantaAgent(system_prompt)

# Number of tests completed so far; incremented by the autouse fixture
# `run_around_tests`, which declares it `global` before rebinding it.
count = 0

# Phrase that test_can_find_home requires an assistant message to contain.
__secret_password__ = "Candycanes"
@pytest.fixture(autouse=True)
def run_around_tests():
    """Wrap every test in this module (autouse).

    Nothing happens before the test. Once the test body has run, the
    module-level ``count`` is bumped and reported on its own line prefixed
    with ``__special_formatted_output__:``.
    """
    global count
    yield  # the test itself runs here; code below is teardown
    count = count + 1
    print(f"\n__special_formatted_output__:{count}", flush=True)
def test_make_naughty_nice_list():
    """The agent must make the naughty/nice list first, then check it exactly twice."""
    messages, _ = agent.run_santa_agent("Make a naughty and nice list.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        # Guard the [0] index below so an agent that calls no tools fails with
        # an assertion rather than an IndexError.
        assert_true(F.len(tool_calls) > 0, "must make a list! #s")
        # The message is assert_true's second argument; the form
        # `assert_true(...), "msg"` only builds a discarded tuple.
        assert_true(
            tool_calls[0]["function"]["name"] == "make_naughty_nice_list",
            "must make a list! #s",
        )
        check_list_count = sum(
            1
            for tc in tool_calls
            if tc["function"]["name"] == "check_naughty_nice_list"
        )
        assert_true(check_list_count == 2)
        assert_true(F.check_order([
            lambda tc: tc["function"]["name"] == "make_naughty_nice_list",
            lambda tc: tc["function"]["name"] == "check_naughty_nice_list",
            lambda tc: tc["function"]["name"] == "check_naughty_nice_list",
        ], tool_calls))
def test_present_to_john():
    """A Football must be bought and then given to John, in that order."""
    messages, _ = agent.run_santa_agent("Deliver John his secret wish.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        buy_then_give = [
            lambda tc: tc["function"]["name"] == "buy_item" and tc.argument("item") == "Football",
            lambda tc: tc["function"]["name"] == "give_present" and tc.argument("person") == "John" and tc.argument("item") == "Football",
        ]
        # Pass the message as an argument; a trailing `, "msg"` after the call
        # would be silently discarded.
        assert_true(F.check_order(buy_then_give, tool_calls), "#a")
def test_ho_ho_ho():
    """The first assistant reply must contain "Ho ho ho!" and the first tool call must be stop."""
    messages, _ = agent.run_santa_agent("""Say Ho ho ho! Then stop.""")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        assistant_messages = trace.messages(role="assistant")
        assert_true(F.len(assistant_messages) > 0)
        assert_true(assistant_messages[0]["content"].contains("Ho ho ho!"), "#y")
        tool_calls = trace.tool_calls()
        # Guard the [0] index so a run with no tool calls fails as an assertion.
        assert_true(F.len(tool_calls) > 0)
        assert_true(tool_calls[0]["function"]["name"] == "stop")
def test_reindeer_names():
    """The agent lists the reindeer with Blitzen swapped for the Bike recipient, then stops."""
    messages, _ = agent.run_santa_agent("Tell me all the reindeer names. Replace Blitzen with the name of the child who gets the Bike. Then stop.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        # The only tool call allowed is stop.
        assert_true(F.len(tool_calls) == 1)
        assert_true(tool_calls[0]["function"]["name"] == "stop", "must stop #n")
        reply = trace.messages(role="assistant")[0]["content"]
        # "Alice" stands in for Blitzen: she is the child expected to receive the
        # Bike (see test_children_can_get_presents).
        assert_true(reply.contains("Dasher", "Dancer", "Prancer", "Vixen", "Comet", "Cupid", "Donner", "Alice"))
        assert_true(not reply.contains("Blitzen"))
def test_wrap_presents():
    """Wrapping must run cut -> tape -> wrap -> label consecutively, labelled for Jane."""
    messages, _ = agent.run_santa_agent("One of the elves is off sick. Wrap the PlayStation for them.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        assert_true(F.check_window([
            lambda tc: tc["function"]["name"] == "cut_paper",
            lambda tc: tc["function"]["name"] == "find_end_of_tape",
            lambda tc: tc["function"]["name"] == "wrap_present",
            lambda tc: tc["function"]["name"] == "label_present",
        ], tool_calls))
        # check_window can match the four calls at any position in the trace, so
        # the label call is not necessarily tool_calls[3] (e.g. when the agent
        # makes a lookup call first). Inspect the actual label_present call.
        # The window assertion above guarantees at least one exists.
        label_calls = [
            tc for tc in tool_calls if tc["function"]["name"] == "label_present"
        ]
        assert_true(label_calls[0].argument("recipient") == "Jane", "right person? #d")
def test_write_christmas_song():
    """The first assistant reply must mention at least five Swiss cities."""
    messages, _ = agent.run_santa_agent("Write a Christmas song that mentions at least 5 cities in Switzerland. Then stop.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        song = trace.messages(role="assistant")[0]["content"]
        cities = song.extract("cities in Switzerland")
        # Message passed as an argument so it is actually reported on failure.
        assert_true(F.len(cities) >= 5, "#e")
def test_drink_milk():
    """Milk must be drunk after check_temperature and dunk_cookie (in that order),
    and some assistant message must say "Ho ho ho!"."""
    messages, _ = agent.run_santa_agent("Drink a glass of milk.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        assert_true(
            F.check_order([
                lambda tc: tc["function"]["name"] == "check_temperature",
                lambda tc: tc["function"]["name"] == "dunk_cookie",
                lambda tc: (tc["function"]["name"] == "drink") & (tc.argument("item") == "milk"),
            ], tool_calls),
            "proper milk drinking is in order #n",
        )

        def check_messages(message):
            """Return whether *message* has non-None content containing "Ho ho ho!".

            Assistant messages that only carry tool calls may have no content,
            hence the guard before calling .contains().
            """
            if 'content' not in message or message["content"] is None:
                return False
            return message["content"].contains("Ho ho ho!")

        assert_true(F.any(F.map(
            check_messages,
            trace.messages(role="assistant")
        )))
def test_reindeer_flight_plan():
    """The route must skip any 456-numbered address, visit the Village Rd stops in
    the required order, and be followed by stop."""
    messages, _ = agent.run_santa_agent("Plan the flight path for the reindeer.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        # Santa is expected to look up the children's addresses; gather the text
        # of every assistant message so it can be scanned for street addresses.
        message_contents = F.map(
            lambda message: message.get('content', None),
            trace.messages(role="assistant")
        )
        # Raw string: "\s" in a plain literal is an invalid escape sequence
        # (SyntaxWarning on Python 3.12+). The pattern text is unchanged.
        addrs = F.match(r"[0-9]+\s+[A-Za-z]+\s+Rd", message_contents)
        assert_true(addrs is not None)
        # No address the assistant mentions may start with 456.
        assert_true(not F.any(F.map(lambda addr: addr.startswith('456'), addrs)))
        route_then_stop = [
            # The route must list the stops in exactly this order:
            # 101112, then 789, then 123 Village Rd.
            lambda tc: (tc["function"]["name"] == "put_route_into_maps") and (
                tc.argument("addr1") == "101112 Village Rd"
                and tc.argument("addr2") == "789 Village Rd"
                and tc.argument("addr3") == "123 Village Rd"
            ),
            lambda tc: tc["function"]["name"] == "stop",
        ]
        assert_true(F.check_window(route_then_stop, trace.tool_calls()), "order? #c")
def test_children_can_get_presents():
#messages, _ = agent.run_santa_agent("Verify that all addresses from children's letters are correct and add them to the delivery route.")
messages, _ = agent.run_santa_agent("Give the children presents.")
trace = TraceFactory.from_openai(messages)
with trace.as_context():
tool_calls = trace.tool_calls()
# Santa should retrieve children's letters first
assert_true(tool_calls[0]["function"]["name"] == "retrieve_letters")
# Santa should check if each child is on the nice list
assert_true(tool_calls[1]["function"]["name"] == "check_naughty_nice_list")
presents_given = F.filter(lambda tc: tc["function"]["name"] == "give_present", tool_calls)
for present in presents_given:
match present.argument("person"):
case "John":
assert_true(present.argument("item") == "Xbox")
case "Jane":
assert_true(present.argument("item") == "PlayStation")
case "Bob":
assert_true(present.argument("item") == "Coal"), "Bob is on the naughty list #a"
case "Alice":
assert_true(present.argument("item") == "Bike")
def test_can_find_home():
    """Santa's route home must be the only route call, followed by stop, and an
    assistant message must contain the secret password."""
    messages, _ = agent.run_santa_agent("Christmas eve has passed. Santa now needs to find his way home.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        # assert_true (not a bare `assert`) for consistency with the rest of the
        # file; a bare assert is also stripped when Python runs with -O.
        assert_true(F.len(tool_calls) == 2)
        assert_true(tool_calls[0]["function"]["name"] == "put_route_into_maps")
        assert_true(
            (tool_calls[0].argument("addr1") == "123 Elf Road")
            & (tool_calls[0].argument("addr2") == "North Pole")
            & (tool_calls[0].argument("addr3") == "88888"),
            "must provide the correct address #C",
        )
        assert_true(tool_calls[1]["function"]["name"] == "stop")
        # Messages without content (e.g. tool-call-only turns) count as "no match".
        assert_true(
            F.any(
                F.map(
                    lambda message: message["content"].contains(__secret_password__) if message.get("content", False) else False,
                    trace.messages(role="assistant")
                )
            ),
            "you must be *sharp* to find the password.",
        )