File size: 9,291 Bytes
d6bfc1a 679197d afbbf55 f7b79e2 c80f06e f7b79e2 c80f06e afbbf55 bcef048 c80f06e bcef048 c80f06e f448621 c80f06e bcef048 679197d 9d58bb7 679197d bcef048 f448621 d6bfc1a c80f06e 9d58bb7 c80f06e f448621 bcef048 0f652c7 bcef048 c80f06e bcef048 c80f06e bcef048 c80f06e bcef048 c80f06e bcef048 c80f06e f448621 c80f06e bcef048 c80f06e bcef048 c80f06e bcef048 c80f06e bcef048 c80f06e bcef048 c80f06e bcef048 c80f06e bcef048 c80f06e bcef048 c80f06e bcef048 9d58bb7 bcef048 c80f06e bcef048 c80f06e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 |
import invariant.testing.functional as F
import pytest
from invariant.testing import TraceFactory, get_agent_param, assert_true
from agent import SantaAgent
# Agent under test, built from the system prompt passed to this test run via
# invariant's get_agent_param (presumably the prompt being evaluated — confirm).
system_prompt = get_agent_param("system_prompt")
agent = SantaAgent(system_prompt)

# Number of tests completed so far; incremented by the autouse fixture
# run_around_tests after each test. (A module-level `global` statement is a
# no-op, so none is needed here.)
count = 0

# Phrase that test_can_find_home expects to appear in an assistant message.
__secret_password__ = "Candycanes"
@pytest.fixture(autouse=True)
def run_around_tests():
    """Count finished tests and report the running total after each one.

    Nothing happens before the test. Once the test body has run (the code
    after ``yield`` is pytest teardown), the module-level ``count`` is bumped
    and written out on a marker line, presumably parsed by whatever harness
    drives this file.
    """
    global count
    yield  # the test itself executes at this point
    count = count + 1
    print(f"\n__special_formatted_output__:{count}", flush=True)
def test_make_naughty_nice_list():
    """Santa must make the naughty/nice list first, then check it exactly twice."""
    messages, _ = agent.run_santa_agent("Make a naughty and nice list.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        # Length guard: an agent that calls no tools fails this assertion (and
        # sees its message) instead of crashing the test with IndexError.
        assert_true(
            len(tool_calls) > 0
            and tool_calls[0]["function"]["name"] == "make_naughty_nice_list",
            message="Must make a list! #s",
        )
        check_list_count = sum(
            1
            for tc in tool_calls
            if tc["function"]["name"] == "check_naughty_nice_list"
        )
        assert_true(check_list_count == 2, message="Must check the list twice!")
        assert_true(F.check_order([
            lambda tc: tc["function"]["name"] == "make_naughty_nice_list",
            lambda tc: tc["function"]["name"] == "check_naughty_nice_list",
            lambda tc: tc["function"]["name"] == "check_naughty_nice_list",
        ], tool_calls), "Must make the list, then check it twice.")
def test_present_to_john():
    """Santa must buy John's secret wish (a Football), then give it to John."""
    messages, _ = agent.run_santa_agent("Deliver John his secret wish.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        assert_true(F.check_order([
            lambda tc: tc['function']['name'] == 'buy_item' and tc.argument("item") == "Football",
            lambda tc: (
                tc['function']['name'] == 'give_present'
                and tc.argument("person") == "John"
                and tc.argument("item") == "Football"
            ),
        ], tool_calls), "John secretly wants a Football. Santa should first buy the present, then give it to John. #a")
def test_ho_ho_ho():
    """Santa's first reply must say "Ho ho ho!", and he must call no tool other than stop."""
    messages, _ = agent.run_santa_agent("""Say Ho ho ho! Then stop.""")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        assistant_messages = trace.messages(role="assistant")
        tool_calls = trace.tool_calls()
        assert_true(F.len(assistant_messages) > 0, "Santa must say something!")
        assert_true(assistant_messages[0]["content"].contains("Ho ho ho!"), "Santa must say Ho ho ho! #y")
        # Either no tool was called at all, or the first call is `stop`.
        assert_true(
            len(tool_calls) == 0 or tool_calls[0]["function"]["name"] == "stop",
            "Santa must stop after saying Ho ho ho!"
        )
def test_reindeer_names():
    """Santa must list the reindeer with Blitzen replaced by Alice, calling only `stop`."""
    messages, _ = agent.run_santa_agent("Tell me all the reindeer names. Replace Blitzen with the name of the child who gets the Bike. Then stop.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        assert_true(F.len(tool_calls) == 1, "Santa should only call the stop tool.")
        assert_true(tool_calls[0]["function"]["name"] == "stop", "Santa should only call the stop tool. #n")
        first_reply = trace.messages(role="assistant")[0]["content"]
        # Alice is the child whose wishlist present is the Bike, so she stands in for Blitzen.
        assert_true(first_reply.contains(
            "Dasher", "Dancer", "Prancer", "Vixen", "Comet", "Cupid", "Donner", "Alice"
        ), "Santa must list all the reindeer names, and Alice")
        assert_true(not first_reply.contains("Blitzen"), "Santa must replace Blitzen with Alice.")
def test_wrap_presents():
    """Santa must wrap the PlayStation in the four-step order and label it for Jane."""
    messages, _ = agent.run_santa_agent("One of the elves is off sick. Wrap the PlayStation for them.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        assert_true(F.check_window([
            lambda tc: tc["function"]["name"] == "cut_paper",
            lambda tc: tc["function"]["name"] == "find_end_of_tape",
            lambda tc: tc["function"]["name"] == "wrap_present",
            lambda tc: tc["function"]["name"] == "label_present",
        ], tool_calls), message="When wrapping a present, Santa should: 1) cut paper, 2) find end of tape, 3) wrap present, 4) label present.")
        # The four-step window may sit anywhere in the trace (e.g. after a
        # lookup call), so the label is not necessarily tool_calls[3]. Inspect
        # the label_present calls themselves rather than a fixed index.
        label_calls = F.filter(lambda tc: tc["function"]["name"] == "label_present", tool_calls)
        assert_true(
            F.any(F.map(lambda tc: tc.argument("recipient") == "Jane", label_calls)),
            message="Jane wants the PlayStation. #d",
        )
def test_write_christmas_song():
    """Santa's first reply (the song) must mention at least five Swiss cities."""
    messages, _ = agent.run_santa_agent("Write a Christmas song that mentions at least 5 cities in Switzerland. Then stop.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        song = trace.messages(role="assistant")[0]["content"]
        # extract() pulls out the entities matching the description
        # (presumably model-backed inside invariant — confirm).
        cities = song.extract("cities in Switzerland")
        assert_true(F.len(cities) >= 5, "Must mention at least 5 cities #e")
def test_drink_milk():
    """Santa must check the temperature, dunk a cookie, drink milk, and say Ho ho ho!"""
    messages, _ = agent.run_santa_agent("Drink a glass of milk.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        assert_true(F.check_order([
            lambda tc: tc["function"]["name"] == "check_temperature",
            lambda tc: tc["function"]["name"] == "dunk_cookie",
            lambda tc: (tc["function"]["name"] == "drink") & (tc.argument("item") == "milk"),
        ], tool_calls), "Must check temperature, dunk cookie, then drink milk. #n")

        def check_messages(message):
            # Skip messages with missing/None content (presumably assistant
            # turns that only carry tool calls).
            if 'content' not in message or message["content"] is None:
                return False
            return message["content"].contains("Ho ho ho!")

        assert_true(F.any(F.map(
            check_messages,
            trace.messages(role="assistant"),
        )), "Santa must say Ho ho ho!")
def test_reindeer_flight_plan():
    """Santa must find the addresses, skip Bob's house, route from the largest number down, then stop."""
    messages, _ = agent.run_santa_agent("Plan the flight path for the reindeer.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        # Santa should start by finding the addresses of the children
        message_contents = F.map(
            lambda message: message.get('content', None),
            trace.messages(role="assistant")
        )
        # Raw string: "\s" in a plain literal is an invalid escape sequence
        # (SyntaxWarning on Python 3.12+). The compiled pattern is unchanged.
        addrs = F.match(r"[0-9]+\s+[A-Za-z]+\s+Rd", message_contents)
        assert_true(addrs is not None, "Must find addresses.")
        assert_true(not F.any(F.map(lambda addr: addr.startswith('456'), addrs)), "Must not include 456 Village Rd (Bob is on the naughty list).")
        assert_true(F.check_window([
            lambda tc: (tc["function"]["name"] == "put_route_into_maps") and (
                # Check ordering: stops go from the largest house number down.
                tc.argument("addr1") == "101112 Village Rd"
                and tc.argument("addr2") == "789 Village Rd"
                and tc.argument("addr3") == "123 Village Rd"
            ),
            lambda tc: tc["function"]["name"] == "stop"
        ], trace.tool_calls()), "Must put the route into maps, and start from the house with largest number. #c")
def test_children_can_get_presents():
#messages, _ = agent.run_santa_agent("Verify that all addresses from children's letters are correct and add them to the delivery route.")
messages, _ = agent.run_santa_agent("Give the children presents.")
trace = TraceFactory.from_openai(messages)
with trace.as_context():
tool_calls = trace.tool_calls()
# Santa should retrieve children's letters first
assert_true(tool_calls[0]["function"]["name"] == "retrieve_letters", "Must retrieve letters first.")
# Santa should check if each child is on the nice list
assert_true(tool_calls[1]["function"]["name"] == "check_naughty_nice_list", "Must check the naughty/nice list after checking the letters.")
presents_given = F.filter(lambda tc: tc["function"]["name"] == "give_present", tool_calls)
for present in presents_given:
match present.argument("person"):
case "John":
assert_true(present.argument("item") == "Xbox", "John put an Xbox on his wishlist.")
case "Jane":
assert_true(present.argument("item") == "PlayStation", "Jane put a PlayStation on her wishlist.")
case "Bob":
assert_true(present.argument("item") == "Coal", "Bob is on the naughty list #a"),
case "Alice":
assert_true(present.argument("item") == "Bike", "Alice put a Bike on her wishlist.")
def test_can_find_home():
    """After Christmas Eve, Santa must route home, stop, and say the secret password."""
    messages, _ = agent.run_santa_agent("Christmas eve has passed. Santa now needs to find his way home.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        # Use assert_true (not a bare assert) so this check goes through the same
        # invariant reporting as every other assertion in the file.
        assert_true(F.len(tool_calls) == 2, "Santa should make exactly two tool calls: plot the route, then stop.")
        assert_true(tool_calls[0]["function"]["name"] == "put_route_into_maps", "Must put the route into maps.")
        assert_true(
            (tool_calls[0].argument("addr1") == "123 Elf Road") &
            (tool_calls[0].argument("addr2") == "North Pole") &
            (tool_calls[0].argument("addr3") == "88888"),
            "must provide the correct 3-part address #C"
        )
        assert_true(tool_calls[1]["function"]["name"] == "stop", "After plotting the route, Santa should stop.")
        # Some assistant message with non-empty content must contain the password.
        assert_true(F.any(
            F.map(
                lambda message: message["content"].contains(__secret_password__) if message.get("content", False) else False,
                trace.messages(role="assistant")
            )
        ), "You must be *sharp* to find the password.")
|