|
import invariant.testing.functional as F |
|
import pytest |
|
from invariant.testing import TraceFactory, get_agent_param, assert_true |
|
from agent import SantaAgent |
|
|
|
# The system prompt is read once at import time through invariant's
# get_agent_param, and a single SantaAgent built from it is shared by every
# test in this module.
system_prompt = get_agent_param("system_prompt")

agent = SantaAgent(system_prompt)

# Number of tests that have finished so far; incremented by the autouse
# fixture below. (A module-level `global` declaration is a no-op, so none is
# needed here — only the fixture, which rebinds the name, declares it.)
count = 0
|
|
|
@pytest.fixture(autouse=True)
def run_around_tests():
    """Wrap every test in this module and emit a progress marker afterwards.

    Nothing is done before the test body. pytest runs the code after ``yield``
    as teardown, at which point the module-level ``count`` is advanced by one
    and a marker line carrying the new value is printed and flushed to stdout
    (presumably scraped by an external runner — confirm with its consumer).
    """
    global count
    yield
    count = count + 1
    marker = f"\n__special_formatted_output__:{count}"
    # Flush so the marker is visible immediately to whatever reads stdout.
    print(marker, flush=True)
|
|
|
|
|
def test_xbox_to_john():
    """Delivering an Xbox to John should buy the Xbox, then give it to John.

    Expects the first tool call to be ``buy_item(item="Xbox")`` and the second
    to be ``give_present(person="John", item="Xbox")``.
    """
    messages, _ = agent.run_santa_agent("Deliver Xbox to John.")
    trace = TraceFactory.from_openai(messages)

    with trace.as_context():
        tool_calls = trace.tool_calls()

        # The checks below index the first two calls directly; assert there
        # are at least two so a too-short trace is reported explicitly.
        assert_true(F.len(tool_calls) >= 2)

        # First call: purchase the present.
        assert_true(tool_calls[0]["function"]["name"] == "buy_item")
        assert_true(tool_calls[0].argument("item") == "Xbox")

        # Second call: hand the purchased present to John.
        assert_true(tool_calls[1]["function"]["name"] == "give_present")
        assert_true(tool_calls[1].argument("person") == "John")
        assert_true(tool_calls[1].argument("item") == "Xbox")
|
|
|
|
|
def test_ho_ho_ho():
    """Asked to just say "Ho ho ho!" and stop, the agent should do exactly that.

    Expects the first assistant message to contain "Ho ho ho!" and the first
    tool call to be ``stop``.
    """
    messages, _ = agent.run_santa_agent("""Just reply with: "Ho ho ho!" and stop""")
    trace = TraceFactory.from_openai(messages)

    with trace.as_context():
        assistant_messages = trace.messages(role="assistant")

        assert_true(F.len(assistant_messages) > 0)
        assert_true(assistant_messages[0]["content"].contains("Ho ho ho!"))

        tool_calls = trace.tool_calls()
        # Guard the direct index below the same way the assistant messages
        # are guarded above.
        assert_true(F.len(tool_calls) > 0)
        assert_true(tool_calls[0]["function"]["name"] == "stop")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|