add initial files
Browse files- agent.py +112 -0
- app.py +56 -0
- requirements.txt +2 -0
- test_agent.py +19 -0
agent.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import openai
|
3 |
+
from gradio import ChatMessage
|
4 |
+
|
5 |
+
class SantaAgent:
|
6 |
+
|
7 |
+
def __init__(self, system_prompt: str):
|
8 |
+
self.system_prompt = system_prompt
|
9 |
+
self.client = openai.OpenAI()
|
10 |
+
self.tools = [
|
11 |
+
{
|
12 |
+
"type": "function",
|
13 |
+
"function": {
|
14 |
+
"name": "buy_item",
|
15 |
+
"description": "Buy an item from the store.",
|
16 |
+
"parameters": {
|
17 |
+
"type": "object",
|
18 |
+
"properties": {
|
19 |
+
"item": {
|
20 |
+
"type": "string",
|
21 |
+
"description": "The item to buy from the store."
|
22 |
+
}
|
23 |
+
},
|
24 |
+
"required": ["item"]
|
25 |
+
}
|
26 |
+
}
|
27 |
+
},
|
28 |
+
{
|
29 |
+
"type": "function",
|
30 |
+
"function": {
|
31 |
+
"name": "give_present",
|
32 |
+
"description": "Give a present to a person.",
|
33 |
+
"parameters": {
|
34 |
+
"type": "object",
|
35 |
+
"properties": {
|
36 |
+
"person": {
|
37 |
+
"type": "string",
|
38 |
+
"description": "The person to give the present to."
|
39 |
+
},
|
40 |
+
"item": {
|
41 |
+
"type": "string",
|
42 |
+
"description": "The item to give to the person."
|
43 |
+
}
|
44 |
+
},
|
45 |
+
"required": ["person", "item"]
|
46 |
+
}
|
47 |
+
}
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"type": "function",
|
51 |
+
"function": {
|
52 |
+
"name": "stop",
|
53 |
+
"description": "Stop the agent."
|
54 |
+
}
|
55 |
+
}
|
56 |
+
]
|
57 |
+
|
58 |
+
def buy_item(self, item: str):
|
59 |
+
"""Buy an item from the store."""
|
60 |
+
return f"Bought {item} from the store."
|
61 |
+
|
62 |
+
def give_present(self, person: str, item: str):
|
63 |
+
"""Give a present to a person."""
|
64 |
+
return f"Gave {item} to {person}."
|
65 |
+
|
66 |
+
def stop(self):
|
67 |
+
return "STOP"
|
68 |
+
|
69 |
+
def run_santa_agent(self, user_prompt: str):
|
70 |
+
"""Run the Santa agent."""
|
71 |
+
messages = [
|
72 |
+
{"role": "system", "content": self.system_prompt},
|
73 |
+
{"role": "user", "content": user_prompt},
|
74 |
+
]
|
75 |
+
gradio_messages = [
|
76 |
+
ChatMessage(role="system", content=self.system_prompt),
|
77 |
+
ChatMessage(role="user", content=user_prompt),
|
78 |
+
]
|
79 |
+
while True:
|
80 |
+
response = self.client.chat.completions.create(
|
81 |
+
messages=messages,
|
82 |
+
model="gpt-4o-mini",
|
83 |
+
tools=self.tools,
|
84 |
+
tool_choice="auto",
|
85 |
+
)
|
86 |
+
messages.append(response.choices[0].message.to_dict())
|
87 |
+
content = response.choices[0].message.content
|
88 |
+
if content is not None:
|
89 |
+
gradio_messages.append(ChatMessage(role="assistant", content=content))
|
90 |
+
tool_calls = response.choices[0].message.tool_calls
|
91 |
+
|
92 |
+
should_stop = False
|
93 |
+
if tool_calls:
|
94 |
+
for tool_call in tool_calls:
|
95 |
+
arguments = json.loads(tool_call.function.arguments)
|
96 |
+
if tool_call.function.name == "buy_item":
|
97 |
+
item = arguments["item"]
|
98 |
+
gradio_messages.append(ChatMessage(role="assistant", content=f"buy_item({item})", metadata={"title": "π§ Tool Call: buy_item"}))
|
99 |
+
output = self.buy_item(item)
|
100 |
+
elif tool_call.function.name == "give_present":
|
101 |
+
person, item = arguments["person"], arguments["item"]
|
102 |
+
gradio_messages.append(ChatMessage(role="assistant", content=f"give_present({person}, {item})", metadata={"title": "π§ Tool Call: give_present"}))
|
103 |
+
output = self.give_present(person, item)
|
104 |
+
elif tool_call.function.name == "stop":
|
105 |
+
output = self.stop()
|
106 |
+
should_stop = True
|
107 |
+
messages.append({"role": "tool", "content": output, "tool_call_id": tool_call.id})
|
108 |
+
if not should_stop:
|
109 |
+
gradio_messages.append(ChatMessage(role="assistant", content=output, metadata={"title": f"π§ Tool Output: {tool_call.function.name}"}))
|
110 |
+
if should_stop or len(messages) > 10:
|
111 |
+
break
|
112 |
+
return messages, gradio_messages
|
app.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import invariant.testing.functional as F
|
3 |
+
from invariant.testing import TraceFactory, assert_true
|
4 |
+
from agent import SantaAgent
|
5 |
+
import gradio as gr
|
6 |
+
|
7 |
+
agent = SantaAgent("You are a Santa Claus. Buy presents and deliver them to the children.")
|
8 |
+
|
9 |
+
|
10 |
+
def run_agent(user_prompt, history, invariant_api_key):
|
11 |
+
prompt = "Deliver Xbox to John."
|
12 |
+
messages, gradio_messages = agent.run_santa_agent(prompt)
|
13 |
+
# messages = [
|
14 |
+
# {"role": "user", "content": "hi there"},
|
15 |
+
# {"role": "assistant", "content": "bye bye"},
|
16 |
+
# ]
|
17 |
+
|
18 |
+
agent_params = {"system_prompt": user_prompt}
|
19 |
+
|
20 |
+
# run command invariant test test_agent.py --agent-params '{"system_prompt": "you are santa"}'
|
21 |
+
import subprocess
|
22 |
+
out = subprocess.run([
|
23 |
+
"INVARIANT_API_KEY=" + invariant_api_key,
|
24 |
+
"invariant", "test", "test_agent.py",
|
25 |
+
"--agent-params", json.dumps(agent_params),
|
26 |
+
"--push", "--dataset_name", "santa_agent",
|
27 |
+
], capture_output=True, text=True)
|
28 |
+
print(out.stdout)
|
29 |
+
print(out.stderr)
|
30 |
+
|
31 |
+
return gradio_messages, "", out.stdout
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
with gr.Blocks() as demo:
|
36 |
+
with gr.Row():
|
37 |
+
with gr.Column(scale=2):
|
38 |
+
chatbot = gr.Chatbot(
|
39 |
+
type="messages",
|
40 |
+
label="Santa Agent",
|
41 |
+
avatar_images=[
|
42 |
+
None,
|
43 |
+
"https://invariantlabs.ai/theme/images/logo.svg"
|
44 |
+
],
|
45 |
+
)
|
46 |
+
with gr.Column(scale=1):
|
47 |
+
console = gr.TextArea(label="Console Output", interactive=False)
|
48 |
+
input = gr.Textbox(lines=1, label="System Prompt")
|
49 |
+
invariant_api_key = gr.Textbox(lines=1, label="Invariant API Key")
|
50 |
+
input.submit(run_agent, [input, chatbot, invariant_api_key], [chatbot, input, console])
|
51 |
+
input.submit(lambda: gr.update(visible=False), None, [input])
|
52 |
+
|
53 |
+
if __name__ == "__main__":
|
54 |
+
demo.launch()
|
55 |
+
|
56 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
invariant-ai
|
2 |
+
openai
|
test_agent.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from invariant.testing import TraceFactory, get_agent_param
|
2 |
+
from agent import SantaAgent
|
3 |
+
|
4 |
+
system_prompt = get_agent_param("system_prompt")
|
5 |
+
|
6 |
+
agent = SantaAgent(system_prompt)
|
7 |
+
|
8 |
+
def test_xbox_to_john():
|
9 |
+
prompt = "Deliver Xbox to John."
|
10 |
+
messages, _ = agent.run_santa_agent(prompt)
|
11 |
+
print("messages: ", messages)
|
12 |
+
trace = TraceFactory.from_openai(messages)
|
13 |
+
with trace.as_context():
|
14 |
+
tool_calls = trace.tool_calls()
|
15 |
+
assert tool_calls[0]["function"]["name"] == "buy_item"
|
16 |
+
assert tool_calls[0].argument("item") == "Xbox"
|
17 |
+
assert tool_calls[1]["function"]["name"] == "give_present"
|
18 |
+
assert tool_calls[1].argument("person") == "John"
|
19 |
+
assert tool_calls[1].argument("item") == "Xbox"
|