wip
Browse files- agent.py +11 -0
- app.py +26 -14
- test_agent.py +15 -10
agent.py
CHANGED
@@ -66,6 +66,17 @@ class SantaAgent:
|
|
66 |
def stop(self):
|
67 |
return "STOP"
|
68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
def run_santa_agent(self, user_prompt: str):
|
70 |
"""Run the Santa agent."""
|
71 |
messages = [
|
|
|
66 |
def stop(self):
|
67 |
return "STOP"
|
68 |
|
69 |
+
def mock_run_santa_agent(self):
|
70 |
+
messages = [
|
71 |
+
{"role": "user", "content": "Hi there"},
|
72 |
+
{"role": "assistant", "content": "Bye bye"},
|
73 |
+
]
|
74 |
+
gradio_messages = [
|
75 |
+
ChatMessage(role="user", content="Hi there"),
|
76 |
+
ChatMessage(role="assistant", content="Bye bye"),
|
77 |
+
]
|
78 |
+
return messages, gradio_messages
|
79 |
+
|
80 |
def run_santa_agent(self, user_prompt: str):
|
81 |
"""Run the Santa agent."""
|
82 |
messages = [
|
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import json
|
2 |
import os
|
|
|
3 |
import invariant.testing.functional as F
|
4 |
from invariant.testing import TraceFactory, assert_true
|
5 |
from agent import SantaAgent
|
@@ -7,17 +8,15 @@ import gradio as gr
|
|
7 |
|
8 |
agent = SantaAgent("You are a Santa Claus. Buy presents and deliver them to the children.")
|
9 |
|
10 |
-
|
11 |
-
def run_agent(user_prompt, history, invariant_api_key):
|
12 |
prompt = "Deliver Xbox to John."
|
13 |
messages, gradio_messages = agent.run_santa_agent(prompt)
|
14 |
-
# messages = [
|
15 |
-
# {"role": "user", "content": "hi there"},
|
16 |
-
# {"role": "assistant", "content": "bye bye"},
|
17 |
-
# ]
|
18 |
|
19 |
-
|
|
|
20 |
|
|
|
21 |
|
22 |
env={
|
23 |
"INVARIANT_API_KEY": invariant_api_key,
|
@@ -33,19 +32,27 @@ def run_agent(user_prompt, history, invariant_api_key):
|
|
33 |
"--agent-params", json.dumps(agent_params),
|
34 |
"--push", "--dataset_name", "santa_agent",
|
35 |
], capture_output=True, text=True, env=env)
|
36 |
-
print(out.stdout)
|
37 |
-
print(out.stderr)
|
38 |
-
|
39 |
-
return gradio_messages, "", out.stdout
|
40 |
|
|
|
|
|
|
|
|
|
41 |
|
42 |
|
43 |
with gr.Blocks() as demo:
|
|
|
|
|
|
|
|
|
|
|
44 |
with gr.Row():
|
45 |
with gr.Column(scale=2):
|
46 |
chatbot = gr.Chatbot(
|
47 |
type="messages",
|
48 |
-
label="
|
|
|
|
|
|
|
49 |
avatar_images=[
|
50 |
None,
|
51 |
"https://invariantlabs.ai/theme/images/logo.svg"
|
@@ -53,12 +60,17 @@ with gr.Blocks() as demo:
|
|
53 |
)
|
54 |
with gr.Column(scale=1):
|
55 |
console = gr.TextArea(label="Console Output", interactive=False)
|
56 |
-
|
57 |
invariant_api_key = gr.Textbox(lines=1, label="Invariant API Key")
|
58 |
input.submit(run_agent, [input, chatbot, invariant_api_key], [chatbot, input, console])
|
59 |
input.submit(lambda: gr.update(visible=False), None, [input])
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
if __name__ == "__main__":
|
62 |
demo.launch()
|
63 |
|
64 |
-
|
|
|
1 |
import json
|
2 |
import os
|
3 |
+
import re
|
4 |
import invariant.testing.functional as F
|
5 |
from invariant.testing import TraceFactory, assert_true
|
6 |
from agent import SantaAgent
|
|
|
8 |
|
9 |
agent = SantaAgent("You are a Santa Claus. Buy presents and deliver them to the children.")
|
10 |
|
11 |
+
|
12 |
+
def run_agent(user_prompt, history, invariant_api_key, is_example=False):
|
13 |
prompt = "Deliver Xbox to John."
|
14 |
messages, gradio_messages = agent.run_santa_agent(prompt)
|
|
|
|
|
|
|
|
|
15 |
|
16 |
+
if not invariant_api_key.startswith("inv"):
|
17 |
+
return gradio_messages, "", "Please enter a valid Invariant API key to get the score!"
|
18 |
|
19 |
+
agent_params = {"system_prompt": user_prompt}
|
20 |
|
21 |
env={
|
22 |
"INVARIANT_API_KEY": invariant_api_key,
|
|
|
32 |
"--agent-params", json.dumps(agent_params),
|
33 |
"--push", "--dataset_name", "santa_agent",
|
34 |
], capture_output=True, text=True, env=env)
|
|
|
|
|
|
|
|
|
35 |
|
36 |
+
url = re.search(r"https://explorer.invariantlabs.ai/[\-_a-zA-Z0-9/]+", out.stdout).group(0)
|
37 |
+
|
38 |
+
message = "Please find your results at: " + url
|
39 |
+
return gradio_messages, "", message
|
40 |
|
41 |
|
42 |
with gr.Blocks() as demo:
|
43 |
+
gr.Markdown("""
|
44 |
+
## Prompt the Santa Agent
|
45 |
+
* Find a system prompt that delivers the presents to the children
|
46 |
+
""")
|
47 |
+
input = gr.Textbox(lines=1, label="""System Prompt""", value="You are a Santa Claus. Buy presents and deliver them to the children.")
|
48 |
with gr.Row():
|
49 |
with gr.Column(scale=2):
|
50 |
chatbot = gr.Chatbot(
|
51 |
type="messages",
|
52 |
+
label="Example interaction",
|
53 |
+
value=[
|
54 |
+
{"role": "user", "content": "Could you please deliver Xbox to John?"},
|
55 |
+
],
|
56 |
avatar_images=[
|
57 |
None,
|
58 |
"https://invariantlabs.ai/theme/images/logo.svg"
|
|
|
60 |
)
|
61 |
with gr.Column(scale=1):
|
62 |
console = gr.TextArea(label="Console Output", interactive=False)
|
63 |
+
|
64 |
invariant_api_key = gr.Textbox(lines=1, label="Invariant API Key")
|
65 |
input.submit(run_agent, [input, chatbot, invariant_api_key], [chatbot, input, console])
|
66 |
input.submit(lambda: gr.update(visible=False), None, [input])
|
67 |
|
68 |
+
# Submit button
|
69 |
+
submit = gr.Button("Submit")
|
70 |
+
submit.click(run_agent, [input, chatbot, invariant_api_key], [chatbot, input, console])
|
71 |
+
submit.click(lambda: gr.update(visible=False), None, [input])
|
72 |
+
|
73 |
+
|
74 |
if __name__ == "__main__":
|
75 |
demo.launch()
|
76 |
|
|
test_agent.py
CHANGED
@@ -1,19 +1,24 @@
|
|
1 |
-
from invariant.testing import TraceFactory, get_agent_param
|
2 |
from agent import SantaAgent
|
3 |
|
4 |
system_prompt = get_agent_param("system_prompt")
|
5 |
-
|
6 |
agent = SantaAgent(system_prompt)
|
7 |
|
8 |
def test_xbox_to_john():
|
9 |
-
|
10 |
-
messages, _ = agent.run_santa_agent(prompt)
|
11 |
-
print("messages: ", messages)
|
12 |
trace = TraceFactory.from_openai(messages)
|
13 |
with trace.as_context():
|
14 |
tool_calls = trace.tool_calls()
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from invariant.testing import TraceFactory, get_agent_param, assert_true
|
2 |
from agent import SantaAgent
|
3 |
|
4 |
system_prompt = get_agent_param("system_prompt")
|
|
|
5 |
agent = SantaAgent(system_prompt)
|
6 |
|
7 |
def test_xbox_to_john():
|
8 |
+
messages, _ = agent.run_santa_agent("Deliver Xbox to John.")
|
|
|
|
|
9 |
trace = TraceFactory.from_openai(messages)
|
10 |
with trace.as_context():
|
11 |
tool_calls = trace.tool_calls()
|
12 |
+
assert_true(tool_calls[0]["function"]["name"] == "buy_item")
|
13 |
+
assert_true(tool_calls[0].argument("item") == "Xbox")
|
14 |
+
assert_true(tool_calls[1]["function"]["name"] == "give_present")
|
15 |
+
assert_true(tool_calls[1].argument("person") == "John")
|
16 |
+
assert_true(tool_calls[1].argument("item") == "Xbox")
|
17 |
+
|
18 |
+
|
19 |
+
def test_ho_ho_ho():
|
20 |
+
messages, _ = agent.run_santa_agent("Ho ho ho!")
|
21 |
+
trace = TraceFactory.from_openai(messages)
|
22 |
+
with trace.as_context():
|
23 |
+
assert_true(trace.messages(role="assistant")[0]["content"].contains("Ho ho ho!"))
|
24 |
+
assert_true(trace.tool_calls()[0]["function"]["name"] == "stop")
|