File size: 1,846 Bytes
d6bfc1a
 
679197d
afbbf55
 
 
 
 
f7b79e2
 
 
 
 
 
 
 
 
 
 
 
 
afbbf55
679197d
afbbf55
 
 
679197d
 
 
 
 
 
 
 
6b5e356
679197d
 
d6bfc1a
679197d
 
d6bfc1a
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import invariant.testing.functional as F
import pytest
from invariant.testing import TraceFactory, get_agent_param, assert_true
from agent import SantaAgent

# Build the agent once at import time; every test in this module reuses this
# single instance. The system prompt is read via ``get_agent_param``.
system_prompt = get_agent_param("system_prompt")
agent = SantaAgent(system_prompt)

# Number of tests that have finished so far; incremented by the autouse
# fixture ``run_around_tests``. A ``global`` statement only has meaning inside
# a function body, so none is needed here at module level.
count = 0

@pytest.fixture(autouse=True)
def run_around_tests():
    """Wrap every test in this module and report progress afterwards.

    Nothing happens before the test. Once the test body has run (the code
    after ``yield``), the module-level ``count`` is bumped and the running
    total is printed, flushed, as a ``__special_formatted_output__`` line.
    """
    global count
    yield  # the test itself executes at this point
    count = count + 1
    progress_line = f"\n__special_formatted_output__:{count}"
    print(progress_line, flush=True)


def test_xbox_to_john():
    """Asking Santa to deliver an Xbox to John should buy it, then gift it.

    The first tool call must be ``buy_item`` with item "Xbox". The second must
    be ``give_present`` with person "John" and item "Xbox".
    """
    messages, _ = agent.run_santa_agent("Deliver Xbox to John.")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        tool_calls = trace.tool_calls()
        # Check the call count before indexing into tool_calls, so a run with
        # fewer than two tool calls is flagged by an assertion rather than
        # only by an IndexError.
        assert_true(F.len(tool_calls) >= 2)

        buy_call = tool_calls[0]
        assert_true(buy_call["function"]["name"] == "buy_item")
        assert_true(buy_call.argument("item") == "Xbox")

        gift_call = tool_calls[1]
        assert_true(gift_call["function"]["name"] == "give_present")
        assert_true(gift_call.argument("person") == "John")
        assert_true(gift_call.argument("item") == "Xbox")


def test_ho_ho_ho():
    """A direct instruction to say "Ho ho ho!" and stop should be followed.

    The first assistant message must contain "Ho ho ho!" and the first tool
    call must be ``stop``.
    """
    messages, _ = agent.run_santa_agent("""Just reply with: "Ho ho ho!" and stop""")
    trace = TraceFactory.from_openai(messages)
    with trace.as_context():
        # Query the assistant messages once and reuse the result.
        assistant_messages = trace.messages(role="assistant")
        assert_true(F.len(assistant_messages) > 0)
        assert_true(assistant_messages[0]["content"].contains("Ho ho ho!"))

        # Guard the tool-call index the same way the messages are guarded, so
        # a run with no tool calls is flagged by an assertion rather than only
        # by an IndexError.
        tool_calls = trace.tool_calls()
        assert_true(F.len(tool_calls) > 0)
        assert_true(tool_calls[0]["function"]["name"] == "stop")


# @pytest.mark.parametrize("country", ["Finland", "Iceland"])
# def test_cities(country):
#     messages, _ = agent.run_santa_agent(f"""Write a Christmas song that mentions exactly 5 cities in {country}.""")
#     trace = TraceFactory.from_openai(messages)
#     with trace.as_context():
#         cities = trace.messages(role="assistant")[0]["content"].extract(f"cities in {country}")
#         assert_true(F.len(cities) == 5)