open in new tab, add assertion messages
Browse files
- app.py +27 -7
- styling.css +3 -3
- test_agent.py +28 -28
app.py
CHANGED
@@ -32,9 +32,13 @@ def run_agent_with_state(user_prompt, history, invariant_api_key, state, is_exam
|
|
32 |
return gradio_messages
|
33 |
|
34 |
|
35 |
-
def
|
36 |
-
|
37 |
-
|
|
|
|
|
|
|
|
|
38 |
|
39 |
|
40 |
def run_testing(user_prompt, invariant_api_key):
|
@@ -105,6 +109,10 @@ with gr.Blocks(
|
|
105 |
test_progress_state = gr.State("")
|
106 |
test_url_state = gr.State(INITIAL_STATE)
|
107 |
test_button_class_state = gr.State("toggled-off-button")
|
|
|
|
|
|
|
|
|
108 |
|
109 |
gr.HTML("""
|
110 |
<div class="home-banner-wrapper">
|
@@ -164,7 +172,6 @@ with gr.Blocks(
|
|
164 |
* Paste the API key in the text box above.
|
165 |
"""
|
166 |
)
|
167 |
-
|
168 |
|
169 |
with gr.Column(scale=3):
|
170 |
with gr.Accordion("Task Description", open=False):
|
@@ -180,18 +187,31 @@ with gr.Blocks(
|
|
180 |
"""
|
181 |
)
|
182 |
|
183 |
-
|
184 |
submit_button.click(
|
185 |
fn=run_testing,
|
186 |
inputs=[input, invariant_api_key],
|
187 |
outputs=[test_progress_state, test_url_state, test_button_class_state],
|
188 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
reset_button.click(reset_state, None, [input, chatbot, test_url_state, run_button])
|
190 |
submit_button.click(run_agent_with_state, [input, chatbot, invariant_api_key, test_url_state], [chatbot])
|
191 |
test_progress_state.change(lambda ts: ts, test_progress_state, run_button)
|
192 |
test_button_class_state.change(lambda ts: gr.update(elem_classes=ts), test_button_class_state, run_button)
|
193 |
-
|
194 |
-
test_url_state.change(update_run_button, test_url_state, run_button)
|
195 |
|
196 |
input.submit(lambda: gr.update(visible=True), None, [input])
|
197 |
|
|
|
32 |
return gradio_messages
|
33 |
|
34 |
|
35 |
+
def update_json_url(url):
|
36 |
+
value = json.dumps(
|
37 |
+
{
|
38 |
+
"url": url
|
39 |
+
}
|
40 |
+
)
|
41 |
+
return gr.update(value=value)
|
42 |
|
43 |
|
44 |
def run_testing(user_prompt, invariant_api_key):
|
|
|
109 |
test_progress_state = gr.State("")
|
110 |
test_url_state = gr.State(INITIAL_STATE)
|
111 |
test_button_class_state = gr.State("toggled-off-button")
|
112 |
+
|
113 |
+
# Have to store URL as JSON instead of state as states cannot
|
114 |
+
# reliably be passed to the frontend on updates: https://github.com/gradio-app/gradio/issues/3525
|
115 |
+
current_invariant_url = gr.JSON("""{"url": ""}""", visible=False)
|
116 |
|
117 |
gr.HTML("""
|
118 |
<div class="home-banner-wrapper">
|
|
|
172 |
* Paste the API key in the text box above.
|
173 |
"""
|
174 |
)
|
|
|
175 |
|
176 |
with gr.Column(scale=3):
|
177 |
with gr.Accordion("Task Description", open=False):
|
|
|
187 |
"""
|
188 |
)
|
189 |
|
|
|
190 |
submit_button.click(
|
191 |
fn=run_testing,
|
192 |
inputs=[input, invariant_api_key],
|
193 |
outputs=[test_progress_state, test_url_state, test_button_class_state],
|
194 |
)
|
195 |
+
|
196 |
+
run_button.click(
|
197 |
+
fn=None,
|
198 |
+
inputs=current_invariant_url,
|
199 |
+
js="""
|
200 |
+
(current_invariant_url) => {
|
201 |
+
if (current_invariant_url['url'] !== '' && current_invariant_url['url']) {
|
202 |
+
window.open(current_invariant_url['url'], '_blank');
|
203 |
+
} else {
|
204 |
+
console.log("No URL to open");
|
205 |
+
}
|
206 |
+
}
|
207 |
+
""",
|
208 |
+
)
|
209 |
+
|
210 |
reset_button.click(reset_state, None, [input, chatbot, test_url_state, run_button])
|
211 |
submit_button.click(run_agent_with_state, [input, chatbot, invariant_api_key, test_url_state], [chatbot])
|
212 |
test_progress_state.change(lambda ts: ts, test_progress_state, run_button)
|
213 |
test_button_class_state.change(lambda ts: gr.update(elem_classes=ts), test_button_class_state, run_button)
|
214 |
+
test_url_state.change(update_json_url, test_url_state, current_invariant_url)
|
|
|
215 |
|
216 |
input.submit(lambda: gr.update(visible=True), None, [input])
|
217 |
|
styling.css
CHANGED
@@ -90,8 +90,6 @@ body.invariant:not(.dark) {
|
|
90 |
text-decoration: none; /* Remove underlining */
|
91 |
}
|
92 |
|
93 |
-
.home-banner
|
94 |
-
|
95 |
/* Base style for the button */
|
96 |
.button-loading {
|
97 |
display: inline-block;
|
@@ -103,6 +101,7 @@ body.invariant:not(.dark) {
|
|
103 |
cursor: loading;
|
104 |
}
|
105 |
|
|
|
106 |
/* Adding a wave effect */
|
107 |
.button-loading::before {
|
108 |
content: '';
|
@@ -111,8 +110,9 @@ body.invariant:not(.dark) {
|
|
111 |
left: -100%;
|
112 |
width: 100%;
|
113 |
height: 100%;
|
114 |
-
background: linear-gradient(90deg, transparent, rgba(
|
115 |
animation: wave-animation 2s infinite;
|
|
|
116 |
}
|
117 |
|
118 |
/* Animation for the wave */
|
|
|
90 |
text-decoration: none; /* Remove underlining */
|
91 |
}
|
92 |
|
|
|
|
|
93 |
/* Base style for the button */
|
94 |
.button-loading {
|
95 |
display: inline-block;
|
|
|
101 |
cursor: loading;
|
102 |
}
|
103 |
|
104 |
+
|
105 |
/* Adding a wave effect */
|
106 |
.button-loading::before {
|
107 |
content: '';
|
|
|
110 |
left: -100%;
|
111 |
width: 100%;
|
112 |
height: 100%;
|
113 |
+
background: linear-gradient(90deg, transparent, rgba(78, 70, 229, 0.2), transparent);
|
114 |
animation: wave-animation 2s infinite;
|
115 |
+
cursor: loading;
|
116 |
}
|
117 |
|
118 |
/* Animation for the wave */
|
test_agent.py
CHANGED
@@ -27,12 +27,11 @@ def test_make_naughty_nice_list():
|
|
27 |
trace = TraceFactory.from_openai(messages)
|
28 |
with trace.as_context():
|
29 |
tool_calls = trace.tool_calls()
|
30 |
-
assert_true(tool_calls[0]["function"]["name"] == "make_naughty_nice_list"
|
31 |
-
|
32 |
check_list_count = 0
|
33 |
for tc in tool_calls:
|
34 |
check_list_count += 1 if tc["function"]["name"] == "check_naughty_nice_list" else 0
|
35 |
-
assert_true(check_list_count == 2)
|
36 |
assert_true(F.check_order([
|
37 |
lambda tc: tc['function']['name'] == 'make_naughty_nice_list',
|
38 |
lambda tc: tc['function']['name'] == 'check_naughty_nice_list',
|
@@ -48,16 +47,16 @@ def test_present_to_john():
|
|
48 |
assert_true(F.check_order([
|
49 |
lambda tc: tc['function']['name'] == 'buy_item' and tc.argument("item") == "Football",
|
50 |
lambda tc: tc['function']['name'] == 'give_present' and tc.argument("person") == "John" and tc.argument("item") == "Football",
|
51 |
-
], tool_calls)
|
52 |
|
53 |
|
54 |
def test_ho_ho_ho():
|
55 |
messages, _ = agent.run_santa_agent("""Say Ho ho ho! Then stop.""")
|
56 |
trace = TraceFactory.from_openai(messages)
|
57 |
with trace.as_context():
|
58 |
-
assert_true(F.len(trace.messages(role="assistant")) > 0)
|
59 |
-
assert_true(trace.messages(role="assistant")[0]["content"].contains("Ho ho ho!")
|
60 |
-
assert_true(trace.tool_calls()[0]["function"]["name"] == "stop")
|
61 |
|
62 |
|
63 |
def test_reindeer_names():
|
@@ -65,9 +64,9 @@ def test_reindeer_names():
|
|
65 |
trace = TraceFactory.from_openai(messages)
|
66 |
with trace.as_context():
|
67 |
assert_true(F.len(trace.tool_calls()) == 1)
|
68 |
-
assert_true(trace.tool_calls()[0]["function"]["name"] == "stop"
|
69 |
-
assert_true(trace.messages(role="assistant")[0]["content"].contains("Dasher", "Dancer", "Prancer", "Vixen", "Comet", "Cupid", "Donner", "Alice"))
|
70 |
-
assert_true(not trace.messages(role="assistant")[0]["content"].contains("Blitzen"))
|
71 |
|
72 |
|
73 |
def test_wrap_presents():
|
@@ -80,9 +79,9 @@ def test_wrap_presents():
|
|
80 |
lambda tc: tc["function"]["name"] == "find_end_of_tape",
|
81 |
lambda tc: tc["function"]["name"] == "wrap_present",
|
82 |
lambda tc: tc["function"]["name"] == "label_present",
|
83 |
-
], tool_calls))
|
84 |
|
85 |
-
assert_true(tool_calls[3].argument("recipient") == "Jane"
|
86 |
|
87 |
|
88 |
def test_write_christmas_song():
|
@@ -90,7 +89,7 @@ def test_write_christmas_song():
|
|
90 |
trace = TraceFactory.from_openai(messages)
|
91 |
with trace.as_context():
|
92 |
cities = trace.messages(role="assistant")[0]["content"].extract("cities in Switzerland")
|
93 |
-
assert_true(F.len(cities) >= 5
|
94 |
|
95 |
|
96 |
def test_drink_milk():
|
@@ -103,7 +102,7 @@ def test_drink_milk():
|
|
103 |
lambda tc: tc["function"]["name"] == "check_temperature",
|
104 |
lambda tc: tc["function"]["name"] == "dunk_cookie",
|
105 |
lambda tc: (tc["function"]["name"] == "drink") & (tc.argument("item") == "milk"),
|
106 |
-
], tool_calls)
|
107 |
|
108 |
def check_messages(message):
|
109 |
if 'content' not in message or message["content"] is None:
|
@@ -127,8 +126,8 @@ def test_reindeer_flight_plan():
|
|
127 |
)
|
128 |
|
129 |
addrs = F.match("[0-9]+\s+[A-Za-z]+\s+Rd", message_contents)
|
130 |
-
assert_true(addrs is not None)
|
131 |
-
assert_true(not F.any(F.map(lambda message: message.startswith('456'), addrs)))
|
132 |
|
133 |
assert_true(F.check_window([
|
134 |
lambda tc: (tc["function"]["name"] == "put_route_into_maps") and (
|
@@ -136,7 +135,7 @@ def test_reindeer_flight_plan():
|
|
136 |
tc.argument("addr1") == "101112 Village Rd" and tc.argument("addr2") == "789 Village Rd" and tc.argument("addr3") == "123 Village Rd"
|
137 |
),
|
138 |
lambda tc: tc["function"]["name"] == "stop"
|
139 |
-
], trace.tool_calls())
|
140 |
|
141 |
|
142 |
def test_children_can_get_presents():
|
@@ -147,23 +146,23 @@ def test_children_can_get_presents():
|
|
147 |
tool_calls = trace.tool_calls()
|
148 |
|
149 |
# Santa should retrieve children's letters first
|
150 |
-
assert_true(tool_calls[0]["function"]["name"] == "retrieve_letters")
|
151 |
|
152 |
# Santa should check if each child is on the nice list
|
153 |
-
assert_true(tool_calls[1]["function"]["name"] == "check_naughty_nice_list")
|
154 |
|
155 |
presents_given = F.filter(lambda tc: tc["function"]["name"] == "give_present", tool_calls)
|
156 |
|
157 |
for present in presents_given:
|
158 |
match present.argument("person"):
|
159 |
case "John":
|
160 |
-
assert_true(present.argument("item") == "Xbox")
|
161 |
case "Jane":
|
162 |
-
assert_true(present.argument("item") == "PlayStation")
|
163 |
case "Bob":
|
164 |
-
assert_true(present.argument("item") == "Coal"
|
165 |
case "Alice":
|
166 |
-
assert_true(present.argument("item") == "Bike")
|
167 |
|
168 |
|
169 |
def test_can_find_home():
|
@@ -173,18 +172,19 @@ def test_can_find_home():
|
|
173 |
tool_calls = trace.tool_calls()
|
174 |
assert len(tool_calls) == 2
|
175 |
|
176 |
-
assert_true(tool_calls[0]["function"]["name"] == "put_route_into_maps")
|
177 |
assert_true(
|
178 |
(tool_calls[0].argument("addr1") == "123 Elf Road") &
|
179 |
(tool_calls[0].argument("addr2") == "North Pole") &
|
180 |
-
(tool_calls[0].argument("addr3") == "88888")
|
181 |
-
|
182 |
-
|
|
|
183 |
|
184 |
assert_true(F.any(
|
185 |
F.map(
|
186 |
lambda message: message["content"].contains(__secret_password__) if message.get("content", False) else False,
|
187 |
trace.messages(role="assistant")
|
188 |
)
|
189 |
-
)
|
190 |
|
|
|
27 |
trace = TraceFactory.from_openai(messages)
|
28 |
with trace.as_context():
|
29 |
tool_calls = trace.tool_calls()
|
30 |
+
assert_true(tool_calls[0]["function"]["name"] == "make_naughty_nice_list", message="Must make a list! #s")
|
|
|
31 |
check_list_count = 0
|
32 |
for tc in tool_calls:
|
33 |
check_list_count += 1 if tc["function"]["name"] == "check_naughty_nice_list" else 0
|
34 |
+
assert_true(check_list_count == 2, message="Must check the list twice!")
|
35 |
assert_true(F.check_order([
|
36 |
lambda tc: tc['function']['name'] == 'make_naughty_nice_list',
|
37 |
lambda tc: tc['function']['name'] == 'check_naughty_nice_list',
|
|
|
47 |
assert_true(F.check_order([
|
48 |
lambda tc: tc['function']['name'] == 'buy_item' and tc.argument("item") == "Football",
|
49 |
lambda tc: tc['function']['name'] == 'give_present' and tc.argument("person") == "John" and tc.argument("item") == "Football",
|
50 |
+
], tool_calls), "John secretly wants a Football. Santa should first buy the present, then give it to John. #a"),
|
51 |
|
52 |
|
53 |
def test_ho_ho_ho():
|
54 |
messages, _ = agent.run_santa_agent("""Say Ho ho ho! Then stop.""")
|
55 |
trace = TraceFactory.from_openai(messages)
|
56 |
with trace.as_context():
|
57 |
+
assert_true(F.len(trace.messages(role="assistant")) > 0, "Santa must say something!")
|
58 |
+
assert_true(trace.messages(role="assistant")[0]["content"].contains("Ho ho ho!"), "Santa must say Ho ho ho! #y"),
|
59 |
+
assert_true(trace.tool_calls()[0]["function"]["name"] == "stop", "Santa must stop after saying Ho ho ho! #n")
|
60 |
|
61 |
|
62 |
def test_reindeer_names():
|
|
|
64 |
trace = TraceFactory.from_openai(messages)
|
65 |
with trace.as_context():
|
66 |
assert_true(F.len(trace.tool_calls()) == 1)
|
67 |
+
assert_true(trace.tool_calls()[0]["function"]["name"] == "stop", "Santa should only call the stop tool. #n"),
|
68 |
+
assert_true(trace.messages(role="assistant")[0]["content"].contains("Dasher", "Dancer", "Prancer", "Vixen", "Comet", "Cupid", "Donner", "Alice"), "Santa must list all the reindeer names, and Alice"),
|
69 |
+
assert_true(not trace.messages(role="assistant")[0]["content"].contains("Blitzen"), "Santa must replace Blitzen with Alice."),
|
70 |
|
71 |
|
72 |
def test_wrap_presents():
|
|
|
79 |
lambda tc: tc["function"]["name"] == "find_end_of_tape",
|
80 |
lambda tc: tc["function"]["name"] == "wrap_present",
|
81 |
lambda tc: tc["function"]["name"] == "label_present",
|
82 |
+
], tool_calls), message="When wrapping a present, Sant should: 1) cut paper, 2) find end of tape, 3) wrap present, 4) label present.")
|
83 |
|
84 |
+
assert_true(tool_calls[3].argument("recipient") == "Jane", message="Jane wants the PlayStation. #d")
|
85 |
|
86 |
|
87 |
def test_write_christmas_song():
|
|
|
89 |
trace = TraceFactory.from_openai(messages)
|
90 |
with trace.as_context():
|
91 |
cities = trace.messages(role="assistant")[0]["content"].extract("cities in Switzerland")
|
92 |
+
assert_true(F.len(cities) >= 5, "Must mention at least 5 cities #e"),
|
93 |
|
94 |
|
95 |
def test_drink_milk():
|
|
|
102 |
lambda tc: tc["function"]["name"] == "check_temperature",
|
103 |
lambda tc: tc["function"]["name"] == "dunk_cookie",
|
104 |
lambda tc: (tc["function"]["name"] == "drink") & (tc.argument("item") == "milk"),
|
105 |
+
], tool_calls), "Must check temperature, dunk cookie, then drink milk. #n"),
|
106 |
|
107 |
def check_messages(message):
|
108 |
if 'content' not in message or message["content"] is None:
|
|
|
126 |
)
|
127 |
|
128 |
addrs = F.match("[0-9]+\s+[A-Za-z]+\s+Rd", message_contents)
|
129 |
+
assert_true(addrs is not None, "Must find addresses.")
|
130 |
+
assert_true(not F.any(F.map(lambda message: message.startswith('456'), addrs)), "Must not include 456 Village Rd (Bob is on the naughty list).")
|
131 |
|
132 |
assert_true(F.check_window([
|
133 |
lambda tc: (tc["function"]["name"] == "put_route_into_maps") and (
|
|
|
135 |
tc.argument("addr1") == "101112 Village Rd" and tc.argument("addr2") == "789 Village Rd" and tc.argument("addr3") == "123 Village Rd"
|
136 |
),
|
137 |
lambda tc: tc["function"]["name"] == "stop"
|
138 |
+
], trace.tool_calls()), "Must put the route into maps, and start from the house with largest number. #c"),
|
139 |
|
140 |
|
141 |
def test_children_can_get_presents():
|
|
|
146 |
tool_calls = trace.tool_calls()
|
147 |
|
148 |
# Santa should retrieve children's letters first
|
149 |
+
assert_true(tool_calls[0]["function"]["name"] == "retrieve_letters", "Must retrieve letters first.")
|
150 |
|
151 |
# Santa should check if each child is on the nice list
|
152 |
+
assert_true(tool_calls[1]["function"]["name"] == "check_naughty_nice_list", "Must check the naughty/nice list after checking the letters.")
|
153 |
|
154 |
presents_given = F.filter(lambda tc: tc["function"]["name"] == "give_present", tool_calls)
|
155 |
|
156 |
for present in presents_given:
|
157 |
match present.argument("person"):
|
158 |
case "John":
|
159 |
+
assert_true(present.argument("item") == "Xbox", "John put an Xbox on his wishlist.")
|
160 |
case "Jane":
|
161 |
+
assert_true(present.argument("item") == "PlayStation", "Jane put a PlayStation on her wishlist.")
|
162 |
case "Bob":
|
163 |
+
assert_true(present.argument("item") == "Coal", "Bob is on the naughty list #a"),
|
164 |
case "Alice":
|
165 |
+
assert_true(present.argument("item") == "Bike", "Alice put a Bike on her wishlist.")
|
166 |
|
167 |
|
168 |
def test_can_find_home():
|
|
|
172 |
tool_calls = trace.tool_calls()
|
173 |
assert len(tool_calls) == 2
|
174 |
|
175 |
+
assert_true(tool_calls[0]["function"]["name"] == "put_route_into_maps", "Must put the route into maps.")
|
176 |
assert_true(
|
177 |
(tool_calls[0].argument("addr1") == "123 Elf Road") &
|
178 |
(tool_calls[0].argument("addr2") == "North Pole") &
|
179 |
+
(tool_calls[0].argument("addr3") == "88888"),
|
180 |
+
"must provide the correct 3-part address #C"
|
181 |
+
),
|
182 |
+
assert_true(tool_calls[1]["function"]["name"] == "stop", "After plotting the route, Santa should stop.")
|
183 |
|
184 |
assert_true(F.any(
|
185 |
F.map(
|
186 |
lambda message: message["content"].contains(__secret_password__) if message.get("content", False) else False,
|
187 |
trace.messages(role="assistant")
|
188 |
)
|
189 |
+
), "You must be *sharp* to find the password.")
|
190 |
|