|
import invariant.testing.functional as F |
|
import pytest |
|
from invariant.testing import TraceFactory, get_agent_param, assert_true |
|
from agent import SantaAgent |
|
|
|
# System prompt under test, looked up via get_agent_param("system_prompt").
system_prompt = get_agent_param("system_prompt")

# One SantaAgent instance, built from that prompt and shared by every test
# in this module.
agent = SantaAgent(system_prompt)

# Counter bumped by the run_around_tests fixture. No `global` statement is
# needed here: at module scope it is a no-op, names are already global.
count = 0

# String that test_can_find_home searches for in the assistant messages.
__secret_password__ = "Candycanes"
|
|
|
|
|
@pytest.fixture(autouse=True)
def run_around_tests():
    """Wrap every test and report a running counter afterwards.

    The fixture is autouse, so it applies to each test in the module
    without being requested. Nothing happens before the test; once the
    test has finished, the module-level ``count`` is incremented and its
    new value is printed on a specially tagged line.
    """
    global count

    # Hand control to the test; the code below runs as teardown.
    yield

    count = count + 1
    report_line = f"\n__special_formatted_output__:{count}"
    # flush=True pushes the tagged line out immediately instead of leaving
    # it in the stdout buffer.
    print(report_line, flush=True)
|
|
|
|
|
def test_make_naughty_nice_list():
    """Ask for a naughty/nice list and verify how the agent builds it.

    Asserts that the first tool call is ``make_naughty_nice_list``, that
    exactly two tool calls are named ``check_naughty_nice_list``, and that
    ``F.check_order`` accepts the sequence make -> check -> check.
    """
    messages, _ = agent.run_santa_agent("Make a naughty and nice list.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()

        # The hint string is handed to assert_true as its message argument;
        # placed after the call it would only build a discarded tuple.
        assert_true(
            tool_calls[0]["function"]["name"] == "make_naughty_nice_list",
            "must make a list! #s",
        )

        # Count how many tool calls check the list.
        check_list_count = sum(
            1
            for tc in tool_calls
            if tc["function"]["name"] == "check_naughty_nice_list"
        )
        assert_true(check_list_count == 2)

        assert_true(F.check_order([
            lambda tc: tc['function']['name'] == 'make_naughty_nice_list',
            lambda tc: tc['function']['name'] == 'check_naughty_nice_list',
            lambda tc: tc['function']['name'] == 'check_naughty_nice_list',
        ], tool_calls))
|
|
|
|
|
def test_present_to_john():
    """Ask the agent to deliver John's wish and verify the buy/give order.

    ``F.check_order`` must accept a ``buy_item`` call for "Football"
    followed by a ``give_present`` call giving "Football" to "John".
    """
    messages, _ = agent.run_santa_agent("Deliver John his secret wish.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()

        # "#a" is passed as assert_true's message argument rather than
        # left behind the call as an unused tuple element.
        assert_true(
            F.check_order([
                lambda tc: tc['function']['name'] == 'buy_item' and tc.argument("item") == "Football",
                lambda tc: tc['function']['name'] == 'give_present' and tc.argument("person") == "John" and tc.argument("item") == "Football",
            ], tool_calls),
            "#a",
        )
|
|
|
|
|
def test_ho_ho_ho():
    """Ask the agent to say "Ho ho ho!" and stop, then verify both.

    Requires at least one assistant message, that the first assistant
    message's content contains "Ho ho ho!", and that the first tool call
    is named ``stop``.
    """
    messages, _ = agent.run_santa_agent("""Say Ho ho ho! Then stop.""")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        assert_true(F.len(trace.messages(role="assistant")) > 0)

        # "#y" is supplied as the assertion message instead of trailing
        # the call, where it had no effect.
        assert_true(
            trace.messages(role="assistant")[0]["content"].contains("Ho ho ho!"),
            "#y",
        )

        assert_true(trace.tool_calls()[0]["function"]["name"] == "stop")
|
|
|
|
|
def test_reindeer_names():
    """Ask for the reindeer names with Blitzen swapped out, then verify.

    Requires exactly one tool call, named ``stop``. The first assistant
    message is checked with ``contains`` against seven reindeer names plus
    "Alice", and must not contain "Blitzen".
    """
    messages, _ = agent.run_santa_agent("Tell me all the reindeer names. Replace Blitzen with the name of the child who gets the Bike. Then stop.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        assert_true(F.len(trace.tool_calls()) == 1)

        # The hint is passed as assert_true's message argument; after the
        # call it was just an unused tuple element.
        assert_true(
            trace.tool_calls()[0]["function"]["name"] == "stop",
            "must stop #n",
        )

        # NOTE(review): presumably contains() with several arguments
        # requires all of them to be present - confirm against the library.
        assert_true(trace.messages(role="assistant")[0]["content"].contains("Dasher", "Dancer", "Prancer", "Vixen", "Comet", "Cupid", "Donner", "Alice"))
        assert_true(not trace.messages(role="assistant")[0]["content"].contains("Blitzen"))
|
|
|
|
|
def test_wrap_presents():
    """Ask the agent to wrap the PlayStation and verify the wrapping steps.

    ``F.check_window`` must accept the sequence cut_paper ->
    find_end_of_tape -> wrap_present -> label_present, and the fourth tool
    call must carry ``recipient == "Jane"``.
    """
    messages, _ = agent.run_santa_agent("One of the elves is off sick. Wrap the PlayStation for them.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()

        assert_true(F.check_window([
            lambda tc: tc["function"]["name"] == "cut_paper",
            lambda tc: tc["function"]["name"] == "find_end_of_tape",
            lambda tc: tc["function"]["name"] == "wrap_present",
            lambda tc: tc["function"]["name"] == "label_present",
        ], tool_calls))

        # NOTE(review): index 3 assumes the four wrapping calls are the
        # first four tool calls - confirm check_window pins that.
        # The hint is given to assert_true as its message argument instead
        # of being dropped after the call.
        assert_true(
            tool_calls[3].argument("recipient") == "Jane",
            "right person? #d",
        )
|
|
|
|
|
def test_write_christmas_song():
    """Ask for a Christmas song naming Swiss cities and count the cities.

    Calls ``extract("cities in Switzerland")`` on the first assistant
    message's content and requires at least five results.
    """
    messages, _ = agent.run_santa_agent("Write a Christmas song that mentions at least 5 cities in Switzerland. Then stop.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        cities = trace.messages(role="assistant")[0]["content"].extract("cities in Switzerland")

        # "#e" is passed as the assertion message; previously it followed
        # the call and was never used.
        assert_true(F.len(cities) >= 5, "#e")
|
|
|
|
|
def test_drink_milk():
    """Ask the agent to drink milk and verify the ritual around it.

    ``F.check_order`` must accept check_temperature -> dunk_cookie ->
    drink(item="milk"), and at least one assistant message with content
    must contain "Ho ho ho!".
    """
    messages, _ = agent.run_santa_agent("Drink a glass of milk.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()

        # The hint is handed to assert_true as its message argument rather
        # than left after the call as a discarded tuple element.
        assert_true(
            F.check_order([
                lambda tc: tc["function"]["name"] == "check_temperature",
                lambda tc: tc["function"]["name"] == "dunk_cookie",
                lambda tc: (tc["function"]["name"] == "drink") & (tc.argument("item") == "milk"),
            ], tool_calls),
            "proper milk drinking is in order #n",
        )

        def check_messages(message):
            """Return whether *message* has content containing "Ho ho ho!"."""
            # Messages without content (missing key or None) never match.
            if 'content' not in message or message["content"] is None:
                return False
            return message["content"].contains("Ho ho ho!")

        assert_true(F.any(F.map(
            check_messages,
            trace.messages(role="assistant")
        )))
|
|
|
|
|
def test_reindeer_flight_plan():
    """Ask for the reindeer flight path and verify addresses and routing.

    Road addresses are matched out of the assistant message contents; the
    match result must not be None and no matched address may start with
    "456". ``F.check_window`` must then accept a ``put_route_into_maps``
    call with the three expected Village Rd addresses followed by ``stop``.
    """
    messages, _ = agent.run_santa_agent("Plan the flight path for the reindeer.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        message_contents = F.map(
            lambda message: message.get('content', None),
            trace.messages(role="assistant")
        )

        # Raw string: "\s" in a normal literal is an invalid escape sequence
        # (a warning today, an error in a future Python). The pattern text
        # itself is unchanged.
        addrs = F.match(r"[0-9]+\s+[A-Za-z]+\s+Rd", message_contents)
        assert_true(addrs is not None)
        assert_true(not F.any(F.map(lambda message: message.startswith('456'), addrs)))

        # "order? #c" is passed as assert_true's message argument instead of
        # being dropped after the call.
        assert_true(
            F.check_window([
                lambda tc: (tc["function"]["name"] == "put_route_into_maps") and (
                    tc.argument("addr1") == "101112 Village Rd" and tc.argument("addr2") == "789 Village Rd" and tc.argument("addr3") == "123 Village Rd"
                ),
                lambda tc: tc["function"]["name"] == "stop"
            ], trace.tool_calls()),
            "order? #c",
        )
|
|
|
|
|
def test_children_can_get_presents(): |
|
|
|
messages, _ = agent.run_santa_agent("Give the children presents.") |
|
trace = TraceFactory.from_openai(messages) |
|
with trace.as_context(): |
|
tool_calls = trace.tool_calls() |
|
|
|
|
|
assert_true(tool_calls[0]["function"]["name"] == "retrieve_letters") |
|
|
|
|
|
assert_true(tool_calls[1]["function"]["name"] == "check_naughty_nice_list") |
|
|
|
presents_given = F.filter(lambda tc: tc["function"]["name"] == "give_present", tool_calls) |
|
|
|
for present in presents_given: |
|
match present.argument("person"): |
|
case "John": |
|
assert_true(present.argument("item") == "Xbox") |
|
case "Jane": |
|
assert_true(present.argument("item") == "PlayStation") |
|
case "Bob": |
|
assert_true(present.argument("item") == "Coal"), "Bob is on the naughty list #a" |
|
case "Alice": |
|
assert_true(present.argument("item") == "Bike") |
|
|
|
|
|
def test_can_find_home():
    """Ask the agent to route Santa home and verify route and password.

    Requires exactly two tool calls: ``put_route_into_maps`` with the
    expected three address parts, then ``stop``. At least one assistant
    message with content must contain ``__secret_password__``.
    """
    messages, _ = agent.run_santa_agent("Christmas eve has passed. Santa now needs to find his way home.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        # Kept as a plain assert: it raises immediately, before the
        # indexing below can run on a list of the wrong length.
        assert len(tool_calls) == 2

        assert_true(tool_calls[0]["function"]["name"] == "put_route_into_maps")

        # The hint is supplied as assert_true's message argument rather
        # than trailing the call as a discarded tuple element.
        assert_true(
            (tool_calls[0].argument("addr1") == "123 Elf Road") &
            (tool_calls[0].argument("addr2") == "North Pole") &
            (tool_calls[0].argument("addr3") == "88888"),
            "must provide the correct address #C",
        )
        assert_true(tool_calls[1]["function"]["name"] == "stop")

        # Messages with missing or falsy content count as "no match".
        assert_true(
            F.any(
                F.map(
                    lambda message: message["content"].contains(__secret_password__) if message.get("content", False) else False,
                    trace.messages(role="assistant")
                )
            ),
            "you must be *sharp* to find the password.",
        )
|
|
|
|