# christmas-challenge / test_agent.py
# kn404's picture
# added progress, refactor
# f7b79e2
# raw
# history blame
# 1.85 kB
import invariant.testing.functional as F
import pytest
from invariant.testing import TraceFactory, get_agent_param, assert_true
from agent import SantaAgent
# The "system_prompt" agent parameter, looked up through get_agent_param.
system_prompt = get_agent_param("system_prompt")

# A single SantaAgent built from that prompt is shared by every test below.
agent = SantaAgent(system_prompt)

# Running total of finished tests. The autouse fixture in this module bumps it
# after each test. No `global` statement is needed here: a name bound at
# module level is already a module global.
count = 0
@pytest.fixture(autouse=True)
def run_around_tests():
    """Wrap every test in this module and report progress once it finishes.

    There is no setup work: control is handed to the test right away. After
    the test, the module-level ``count`` is incremented and printed on its own
    line behind a fixed marker prefix.
    """
    global count

    # Hand control to the test; everything below runs afterwards.
    yield

    count = count + 1

    # Output the number of tests done
    # NOTE(review): the marker prefix is presumably parsed by an external
    # harness -- keep it unchanged; confirm against the consumer.
    progress_line = f"\n__special_formatted_output__:{count}"
    print(progress_line, flush=True)
def test_xbox_to_john():
    """Check that a delivery request buys the item and then hands it over.

    The agent is asked to deliver an Xbox to John. The first tool call must be
    ``buy_item`` for the Xbox, and the second must be ``give_present`` to John
    with that same item.
    """
    conversation, _ = agent.run_santa_agent("Deliver Xbox to John.")
    santa_trace = TraceFactory.from_openai(conversation)

    with santa_trace.as_context():
        calls = santa_trace.tool_calls()

        # First call: the purchase.
        purchase = calls[0]
        assert_true(purchase["function"]["name"] == "buy_item")
        assert_true(purchase.argument("item") == "Xbox")

        # Second call: the hand-over to the recipient.
        delivery = calls[1]
        assert_true(delivery["function"]["name"] == "give_present")
        assert_true(delivery.argument("person") == "John")
        assert_true(delivery.argument("item") == "Xbox")
def test_ho_ho_ho():
    """Check that the agent can answer with a fixed phrase and then stop.

    Expects at least one assistant message, the first of which contains
    "Ho ho ho!", and a first tool call named ``stop``.
    """
    conversation, _ = agent.run_santa_agent("""Just reply with: "Ho ho ho!" and stop""")
    santa_trace = TraceFactory.from_openai(conversation)

    with santa_trace.as_context():
        # The assistant must have said something at all...
        replies = santa_trace.messages(role="assistant")
        assert_true(F.len(replies) > 0)

        # ...and its first message must carry the requested phrase.
        first_reply_text = replies[0]["content"]
        assert_true(first_reply_text.contains("Ho ho ho!"))

        # The first tool call must be the stop tool.
        first_call_name = santa_trace.tool_calls()[0]["function"]["name"]
        assert_true(first_call_name == "stop")
# Disabled test (currently commented out):
# @pytest.mark.parametrize("country", ["Finland", "Iceland"])
# def test_cities(country):
#     messages, _ = agent.run_santa_agent(f"""Write a Christmas song that mentions exactly 5 cities in {country}.""")
#     trace = TraceFactory.from_openai(messages)
#     with trace.as_context():
#         cities = trace.messages(role="assistant")[0]["content"].extract(f"cities in {country}")
#         assert_true(F.len(cities) == 5)