mislavb commited on
Commit
afbbf55
Β·
1 Parent(s): 5e91ab7

add initial files

Browse files
Files changed (4) hide show
  1. agent.py +112 -0
  2. app.py +56 -0
  3. requirements.txt +2 -0
  4. test_agent.py +19 -0
agent.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import openai
3
+ from gradio import ChatMessage
4
+
5
+ class SantaAgent:
6
+
7
+ def __init__(self, system_prompt: str):
8
+ self.system_prompt = system_prompt
9
+ self.client = openai.OpenAI()
10
+ self.tools = [
11
+ {
12
+ "type": "function",
13
+ "function": {
14
+ "name": "buy_item",
15
+ "description": "Buy an item from the store.",
16
+ "parameters": {
17
+ "type": "object",
18
+ "properties": {
19
+ "item": {
20
+ "type": "string",
21
+ "description": "The item to buy from the store."
22
+ }
23
+ },
24
+ "required": ["item"]
25
+ }
26
+ }
27
+ },
28
+ {
29
+ "type": "function",
30
+ "function": {
31
+ "name": "give_present",
32
+ "description": "Give a present to a person.",
33
+ "parameters": {
34
+ "type": "object",
35
+ "properties": {
36
+ "person": {
37
+ "type": "string",
38
+ "description": "The person to give the present to."
39
+ },
40
+ "item": {
41
+ "type": "string",
42
+ "description": "The item to give to the person."
43
+ }
44
+ },
45
+ "required": ["person", "item"]
46
+ }
47
+ }
48
+ },
49
+ {
50
+ "type": "function",
51
+ "function": {
52
+ "name": "stop",
53
+ "description": "Stop the agent."
54
+ }
55
+ }
56
+ ]
57
+
58
+ def buy_item(self, item: str):
59
+ """Buy an item from the store."""
60
+ return f"Bought {item} from the store."
61
+
62
+ def give_present(self, person: str, item: str):
63
+ """Give a present to a person."""
64
+ return f"Gave {item} to {person}."
65
+
66
+ def stop(self):
67
+ return "STOP"
68
+
69
+ def run_santa_agent(self, user_prompt: str):
70
+ """Run the Santa agent."""
71
+ messages = [
72
+ {"role": "system", "content": self.system_prompt},
73
+ {"role": "user", "content": user_prompt},
74
+ ]
75
+ gradio_messages = [
76
+ ChatMessage(role="system", content=self.system_prompt),
77
+ ChatMessage(role="user", content=user_prompt),
78
+ ]
79
+ while True:
80
+ response = self.client.chat.completions.create(
81
+ messages=messages,
82
+ model="gpt-4o-mini",
83
+ tools=self.tools,
84
+ tool_choice="auto",
85
+ )
86
+ messages.append(response.choices[0].message.to_dict())
87
+ content = response.choices[0].message.content
88
+ if content is not None:
89
+ gradio_messages.append(ChatMessage(role="assistant", content=content))
90
+ tool_calls = response.choices[0].message.tool_calls
91
+
92
+ should_stop = False
93
+ if tool_calls:
94
+ for tool_call in tool_calls:
95
+ arguments = json.loads(tool_call.function.arguments)
96
+ if tool_call.function.name == "buy_item":
97
+ item = arguments["item"]
98
+ gradio_messages.append(ChatMessage(role="assistant", content=f"buy_item({item})", metadata={"title": "πŸ”§ Tool Call: buy_item"}))
99
+ output = self.buy_item(item)
100
+ elif tool_call.function.name == "give_present":
101
+ person, item = arguments["person"], arguments["item"]
102
+ gradio_messages.append(ChatMessage(role="assistant", content=f"give_present({person}, {item})", metadata={"title": "πŸ”§ Tool Call: give_present"}))
103
+ output = self.give_present(person, item)
104
+ elif tool_call.function.name == "stop":
105
+ output = self.stop()
106
+ should_stop = True
107
+ messages.append({"role": "tool", "content": output, "tool_call_id": tool_call.id})
108
+ if not should_stop:
109
+ gradio_messages.append(ChatMessage(role="assistant", content=output, metadata={"title": f"πŸ”§ Tool Output: {tool_call.function.name}"}))
110
+ if should_stop or len(messages) > 10:
111
+ break
112
+ return messages, gradio_messages
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import invariant.testing.functional as F
3
+ from invariant.testing import TraceFactory, assert_true
4
+ from agent import SantaAgent
5
+ import gradio as gr
6
+
7
+ agent = SantaAgent("You are a Santa Claus. Buy presents and deliver them to the children.")
8
+
9
+
10
+ def run_agent(user_prompt, history, invariant_api_key):
11
+ prompt = "Deliver Xbox to John."
12
+ messages, gradio_messages = agent.run_santa_agent(prompt)
13
+ # messages = [
14
+ # {"role": "user", "content": "hi there"},
15
+ # {"role": "assistant", "content": "bye bye"},
16
+ # ]
17
+
18
+ agent_params = {"system_prompt": user_prompt}
19
+
20
+ # run command invariant test test_agent.py --agent-params '{"system_prompt": "you are santa"}'
21
+ import subprocess
22
+ out = subprocess.run([
23
+ "INVARIANT_API_KEY=" + invariant_api_key,
24
+ "invariant", "test", "test_agent.py",
25
+ "--agent-params", json.dumps(agent_params),
26
+ "--push", "--dataset_name", "santa_agent",
27
+ ], capture_output=True, text=True)
28
+ print(out.stdout)
29
+ print(out.stderr)
30
+
31
+ return gradio_messages, "", out.stdout
32
+
33
+
34
+
35
+ with gr.Blocks() as demo:
36
+ with gr.Row():
37
+ with gr.Column(scale=2):
38
+ chatbot = gr.Chatbot(
39
+ type="messages",
40
+ label="Santa Agent",
41
+ avatar_images=[
42
+ None,
43
+ "https://invariantlabs.ai/theme/images/logo.svg"
44
+ ],
45
+ )
46
+ with gr.Column(scale=1):
47
+ console = gr.TextArea(label="Console Output", interactive=False)
48
+ input = gr.Textbox(lines=1, label="System Prompt")
49
+ invariant_api_key = gr.Textbox(lines=1, label="Invariant API Key")
50
+ input.submit(run_agent, [input, chatbot, invariant_api_key], [chatbot, input, console])
51
+ input.submit(lambda: gr.update(visible=False), None, [input])
52
+
53
+ if __name__ == "__main__":
54
+ demo.launch()
55
+
56
+
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ invariant-ai
2
+ openai
test_agent.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from invariant.testing import TraceFactory, get_agent_param
2
+ from agent import SantaAgent
3
+
4
+ system_prompt = get_agent_param("system_prompt")
5
+
6
+ agent = SantaAgent(system_prompt)
7
+
8
+ def test_xbox_to_john():
9
+ prompt = "Deliver Xbox to John."
10
+ messages, _ = agent.run_santa_agent(prompt)
11
+ print("messages: ", messages)
12
+ trace = TraceFactory.from_openai(messages)
13
+ with trace.as_context():
14
+ tool_calls = trace.tool_calls()
15
+ assert tool_calls[0]["function"]["name"] == "buy_item"
16
+ assert tool_calls[0].argument("item") == "Xbox"
17
+ assert tool_calls[1]["function"]["name"] == "give_present"
18
+ assert tool_calls[1].argument("person") == "John"
19
+ assert tool_calls[1].argument("item") == "Xbox"