File size: 9,291 Bytes
d6bfc1a
 
679197d
afbbf55
 
 
 
 
f7b79e2
 
 
c80f06e
 
 
f7b79e2
 
 
 
 
 
 
 
 
 
c80f06e
 
afbbf55
 
 
bcef048
c80f06e
 
 
bcef048
c80f06e
 
 
 
f448621
c80f06e
 
 
 
 
 
 
 
 
 
bcef048
679197d
 
 
9d58bb7
679197d
 
bcef048
 
f448621
 
 
 
d6bfc1a
 
c80f06e
9d58bb7
c80f06e
 
f448621
bcef048
0f652c7
 
 
bcef048
c80f06e
 
 
 
 
 
 
 
 
 
 
 
bcef048
c80f06e
bcef048
c80f06e
 
 
 
 
 
 
bcef048
c80f06e
 
 
 
 
 
 
 
 
 
 
 
bcef048
c80f06e
 
 
 
 
 
 
 
f448621
 
c80f06e
 
 
 
 
 
 
 
 
 
 
 
 
bcef048
 
c80f06e
 
 
 
 
 
 
bcef048
c80f06e
 
 
 
 
 
 
 
 
 
bcef048
c80f06e
 
bcef048
c80f06e
 
 
 
 
 
bcef048
c80f06e
bcef048
c80f06e
bcef048
c80f06e
bcef048
c80f06e
 
 
 
 
 
 
 
 
bcef048
9d58bb7
 
 
bcef048
 
 
 
c80f06e
 
 
 
 
 
bcef048
c80f06e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import invariant.testing.functional as F
import pytest
from invariant.testing import TraceFactory, get_agent_param, assert_true
from agent import SantaAgent

# Read the "system_prompt" agent parameter and build the one agent instance
# that every test in this module drives.
system_prompt = get_agent_param("system_prompt")
agent = SantaAgent(system_prompt)

# Number of tests that have finished; the autouse fixture in this module
# increments it after each test. No `global` declaration is needed here:
# at module scope the statement has no effect.
count = 0

# Password that test_can_find_home expects to find in an assistant message.
__secret_password__ = "Candycanes"


@pytest.fixture(autouse=True)
def run_around_tests():
    """Wrap every test: once it finishes, bump the counter and print it."""
    global count
    # Nothing to set up; the test itself runs at the yield.
    yield
    count = count + 1

    # Report how many tests have completed so far.
    print(f"\n__special_formatted_output__:{count}", flush=True)


def test_make_naughty_nice_list():
    """Santa must make the list first and then check it exactly twice."""
    messages, _ = agent.run_santa_agent("Make a naughty and nice list.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        calls = trace.tool_calls()

        # The very first tool call has to create the list.
        first_call_makes_list = calls[0]["function"]["name"] == "make_naughty_nice_list"
        assert_true(first_call_makes_list, message="Must make a list! #s")

        # Tally every list check anywhere in the trace.
        times_checked = sum(
            1 if call["function"]["name"] == "check_naughty_nice_list" else 0
            for call in calls
        )
        assert_true(times_checked == 2, message="Must check the list twice!")

        def calls_tool(tool_name):
            # Predicate over a tool call: does it invoke `tool_name`?
            return lambda call: call['function']['name'] == tool_name

        expected_sequence = [
            calls_tool('make_naughty_nice_list'),
            calls_tool('check_naughty_nice_list'),
            calls_tool('check_naughty_nice_list'),
        ]
        assert_true(
            F.check_order(expected_sequence, calls),
            "Must make the list, then check it twice.",
        )


def test_present_to_john():
    """Santa must buy a Football and then give that Football to John."""
    messages, _ = agent.run_santa_agent("Deliver John his secret wish.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        # Expected steps, passed to F.check_order: buy the item, then hand it over.
        assert_true(F.check_order([
            lambda tc: tc['function']['name'] == 'buy_item' and tc.argument("item") == "Football",
            lambda tc: tc['function']['name'] == 'give_present' and tc.argument("person") == "John" and tc.argument("item") == "Football",
        ], tool_calls), "John secretly wants a Football. Santa should first buy the present, then give it to John. #a")


def test_ho_ho_ho():
    """Santa's first assistant message must contain "Ho ho ho!", then he stops.

    Stopping means either no tool call at all, or a first tool call to
    ``stop``.
    """
    messages, _ = agent.run_santa_agent("""Say Ho ho ho! Then stop.""")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        assistant_messages = trace.messages(role="assistant")
        assert_true(F.len(assistant_messages) > 0, "Santa must say something!")
        assert_true(assistant_messages[0]["content"].contains("Ho ho ho!"), "Santa must say Ho ho ho! #y")

        tool_calls = trace.tool_calls()
        assert_true(
            len(tool_calls) == 0 or tool_calls[0]["function"]["name"] == "stop",
            "Santa must stop after saying Ho ho ho!"
        )


def test_reindeer_names():
    """Santa must list the reindeer with Blitzen replaced by Alice, then stop.

    The only tool call allowed is ``stop``; the names are checked against the
    content of the first assistant message.
    """
    messages, _ = agent.run_santa_agent("Tell me all the reindeer names. Replace Blitzen with the name of the child who gets the Bike. Then stop.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        assert_true(F.len(tool_calls) == 1, "Santa should only call the stop tool.")
        assert_true(tool_calls[0]["function"]["name"] == "stop", "Santa should only call the stop tool. #n")

        first_reply = trace.messages(role="assistant")[0]["content"]
        assert_true(first_reply.contains(
            "Dasher", "Dancer", "Prancer", "Vixen", "Comet", "Cupid", "Donner", "Alice"
        ), "Santa must list all the reindeer names, and Alice")
        assert_true(not first_reply.contains("Blitzen"), "Santa must replace Blitzen with Alice.")


def test_wrap_presents():
    """Santa must run the four wrapping steps and label the present for Jane.

    The step sequence is checked with F.check_window; the recipient is read
    from the fourth tool call of the trace.
    """
    messages, _ = agent.run_santa_agent("One of the elves is off sick. Wrap the PlayStation for them.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        calls = trace.tool_calls()

        def calls_tool(tool_name):
            # Predicate over a tool call: does it invoke `tool_name`?
            return lambda call: call["function"]["name"] == tool_name

        wrapping_steps = [
            calls_tool("cut_paper"),
            calls_tool("find_end_of_tape"),
            calls_tool("wrap_present"),
            calls_tool("label_present"),
        ]
        assert_true(
            F.check_window(wrapping_steps, calls),
            message="When wrapping a present, Sant should: 1) cut paper, 2) find end of tape, 3) wrap present, 4) label present.",
        )

        fourth_call = calls[3]
        assert_true(
            fourth_call.argument("recipient") == "Jane",
            message="Jane wants the PlayStation. #d",
        )


def test_write_christmas_song():
    """The song in the first assistant message must mention at least 5 Swiss cities."""
    messages, _ = agent.run_santa_agent("Write a Christmas song that mentions at least 5 cities in Switzerland. Then stop.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        # Pull the mentioned cities out of the first assistant message.
        cities = trace.messages(role="assistant")[0]["content"].extract("cities in Switzerland")
        assert_true(F.len(cities) >= 5, "Must mention at least 5 cities #e")


def test_drink_milk():
    """Santa must check the temperature, dunk a cookie, drink milk and say "Ho ho ho!"."""
    messages, _ = agent.run_santa_agent("Drink a glass of milk.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()

        assert_true(F.check_order([
            lambda tc: tc["function"]["name"] == "check_temperature",
            lambda tc: tc["function"]["name"] == "dunk_cookie",
            lambda tc: (tc["function"]["name"] == "drink") & (tc.argument("item") == "milk"),
        ], tool_calls), "Must check temperature, dunk cookie, then drink milk. #n")

        def check_messages(message):
            # Messages without content (missing key or None) cannot contain the phrase.
            if 'content' not in message or message["content"] is None:
                return False
            return message["content"].contains("Ho ho ho!")

        # Any assistant message may carry the phrase.
        assert_true(F.any(F.map(
            check_messages,
            trace.messages(role="assistant"),
        )), "Santa must say Ho ho ho!")


def test_reindeer_flight_plan():
    """Santa must find the addresses, skip 456 Village Rd, and map the route.

    The route must be put into maps with the three addresses in the expected
    order, followed by a ``stop`` tool call.
    """
    messages, _ = agent.run_santa_agent("Plan the flight path for the reindeer.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        # Santa should start by finding the addresses of the children
        message_contents = F.map(
            lambda message: message.get('content', None),
            trace.messages(role="assistant")
        )

        # Raw string so that `\s` reaches the regex engine instead of being an
        # invalid string escape sequence.
        addrs = F.match(r"[0-9]+\s+[A-Za-z]+\s+Rd", message_contents)
        assert_true(addrs is not None, "Must find addresses.")
        assert_true(not F.any(F.map(lambda message: message.startswith('456'), addrs)), "Must not include 456 Village Rd (Bob is on the naughty list).")

        assert_true(F.check_window([
            lambda tc: (tc["function"]["name"] == "put_route_into_maps") and (
                # Check ordering
                tc.argument("addr1") == "101112 Village Rd" and tc.argument("addr2") == "789 Village Rd" and tc.argument("addr3") == "123 Village Rd"
            ),
            lambda tc: tc["function"]["name"] == "stop"
        ], trace.tool_calls()), "Must put the route into maps, and start from the house with largest number. #c")


def test_children_can_get_presents():
    #messages, _ = agent.run_santa_agent("Verify that all addresses from children's letters are correct and add them to the delivery route.")
    messages, _ = agent.run_santa_agent("Give the children presents.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()

        # Santa should retrieve children's letters first
        assert_true(tool_calls[0]["function"]["name"] == "retrieve_letters", "Must retrieve letters first.")

        # Santa should check if each child is on the nice list
        assert_true(tool_calls[1]["function"]["name"] == "check_naughty_nice_list", "Must check the naughty/nice list after checking the letters.")

        presents_given = F.filter(lambda tc: tc["function"]["name"] == "give_present", tool_calls)

        for present in presents_given:
            match present.argument("person"):
                case "John":
                    assert_true(present.argument("item") == "Xbox", "John put an Xbox on his wishlist.")
                case "Jane":
                    assert_true(present.argument("item") == "PlayStation", "Jane put a PlayStation on her wishlist.")
                case "Bob":
                    assert_true(present.argument("item") == "Coal", "Bob is on the naughty list #a"), 
                case "Alice":
                    assert_true(present.argument("item") == "Bike", "Alice put a Bike on her wishlist.")


def test_can_find_home():
    """Santa must map his three-part home address, stop, and state the password.

    Exactly two tool calls are expected: ``put_route_into_maps`` with the
    correct address parts, then ``stop``. Some assistant message must contain
    the module-level secret password.
    """
    messages, _ = agent.run_santa_agent("Christmas eve has passed. Santa now needs to find his way home.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        # Plain assert: fails immediately, so the indexing below is safe.
        assert len(tool_calls) == 2

        assert_true(tool_calls[0]["function"]["name"] == "put_route_into_maps", "Must put the route into maps.")
        assert_true(
            (tool_calls[0].argument("addr1") == "123 Elf Road") &
            (tool_calls[0].argument("addr2") == "North Pole") &
            (tool_calls[0].argument("addr3") == "88888"),
            "must provide the correct 3-part address #C"
        )
        assert_true(tool_calls[1]["function"]["name"] == "stop", "After plotting the route, Santa should stop.")

        # Messages without content count as not containing the password.
        assert_true(F.any(
            F.map(
                lambda message: message["content"].contains(__secret_password__) if message.get("content", False) else False,
                trace.messages(role="assistant")
            )
        ), "You must be *sharp* to find the password.")