kn404 committed on
Commit
bcef048
·
1 Parent(s): 023034c

open in new tab, add assertion messages

Browse files
Files changed (3) hide show
  1. app.py +27 -7
  2. styling.css +3 -3
  3. test_agent.py +28 -28
app.py CHANGED
@@ -32,9 +32,13 @@ def run_agent_with_state(user_prompt, history, invariant_api_key, state, is_exam
32
  return gradio_messages
33
 
34
 
35
- def update_run_button(url):
36
- print('url', url)
37
- return gr.update(link=url, visible=True, interactive=True)
 
 
 
 
38
 
39
 
40
  def run_testing(user_prompt, invariant_api_key):
@@ -105,6 +109,10 @@ with gr.Blocks(
105
  test_progress_state = gr.State("")
106
  test_url_state = gr.State(INITIAL_STATE)
107
  test_button_class_state = gr.State("toggled-off-button")
 
 
 
 
108
 
109
  gr.HTML("""
110
  <div class="home-banner-wrapper">
@@ -164,7 +172,6 @@ with gr.Blocks(
164
  * Paste the API key in the text box above.
165
  """
166
  )
167
-
168
 
169
  with gr.Column(scale=3):
170
  with gr.Accordion("Task Description", open=False):
@@ -180,18 +187,31 @@ with gr.Blocks(
180
  """
181
  )
182
 
183
-
184
  submit_button.click(
185
  fn=run_testing,
186
  inputs=[input, invariant_api_key],
187
  outputs=[test_progress_state, test_url_state, test_button_class_state],
188
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  reset_button.click(reset_state, None, [input, chatbot, test_url_state, run_button])
190
  submit_button.click(run_agent_with_state, [input, chatbot, invariant_api_key, test_url_state], [chatbot])
191
  test_progress_state.change(lambda ts: ts, test_progress_state, run_button)
192
  test_button_class_state.change(lambda ts: gr.update(elem_classes=ts), test_button_class_state, run_button)
193
-
194
- test_url_state.change(update_run_button, test_url_state, run_button)
195
 
196
  input.submit(lambda: gr.update(visible=True), None, [input])
197
 
 
32
  return gradio_messages
33
 
34
 
35
+ def update_json_url(url):
36
+ value = json.dumps(
37
+ {
38
+ "url": url
39
+ }
40
+ )
41
+ return gr.update(value=value)
42
 
43
 
44
  def run_testing(user_prompt, invariant_api_key):
 
109
  test_progress_state = gr.State("")
110
  test_url_state = gr.State(INITIAL_STATE)
111
  test_button_class_state = gr.State("toggled-off-button")
112
+
113
+ # Have to store URL as JSON instead of state as states cannot
114
+ # reliably be passed to the frontend on updates: https://github.com/gradio-app/gradio/issues/3525
115
+ current_invariant_url = gr.JSON("""{"url": ""}""", visible=False)
116
 
117
  gr.HTML("""
118
  <div class="home-banner-wrapper">
 
172
  * Paste the API key in the text box above.
173
  """
174
  )
 
175
 
176
  with gr.Column(scale=3):
177
  with gr.Accordion("Task Description", open=False):
 
187
  """
188
  )
189
 
 
190
  submit_button.click(
191
  fn=run_testing,
192
  inputs=[input, invariant_api_key],
193
  outputs=[test_progress_state, test_url_state, test_button_class_state],
194
  )
195
+
196
+ run_button.click(
197
+ fn=None,
198
+ inputs=current_invariant_url,
199
+ js="""
200
+ (current_invariant_url) => {
201
+ if (current_invariant_url['url'] !== '' && current_invariant_url['url']) {
202
+ window.open(current_invariant_url['url'], '_blank');
203
+ } else {
204
+ console.log("No URL to open");
205
+ }
206
+ }
207
+ """,
208
+ )
209
+
210
  reset_button.click(reset_state, None, [input, chatbot, test_url_state, run_button])
211
  submit_button.click(run_agent_with_state, [input, chatbot, invariant_api_key, test_url_state], [chatbot])
212
  test_progress_state.change(lambda ts: ts, test_progress_state, run_button)
213
  test_button_class_state.change(lambda ts: gr.update(elem_classes=ts), test_button_class_state, run_button)
214
+ test_url_state.change(update_json_url, test_url_state, current_invariant_url)
 
215
 
216
  input.submit(lambda: gr.update(visible=True), None, [input])
217
 
styling.css CHANGED
@@ -90,8 +90,6 @@ body.invariant:not(.dark) {
90
  text-decoration: none; /* Remove underlining */
91
  }
92
 
93
- .home-banner
94
-
95
  /* Base style for the button */
96
  .button-loading {
97
  display: inline-block;
@@ -103,6 +101,7 @@ body.invariant:not(.dark) {
103
  cursor: loading;
104
  }
105
 
 
106
  /* Adding a wave effect */
107
  .button-loading::before {
108
  content: '';
@@ -111,8 +110,9 @@ body.invariant:not(.dark) {
111
  left: -100%;
112
  width: 100%;
113
  height: 100%;
114
- background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.2), transparent);
115
  animation: wave-animation 2s infinite;
 
116
  }
117
 
118
  /* Animation for the wave */
 
90
  text-decoration: none; /* Remove underlining */
91
  }
92
 
 
 
93
  /* Base style for the button */
94
  .button-loading {
95
  display: inline-block;
 
101
  cursor: loading;
102
  }
103
 
104
+
105
  /* Adding a wave effect */
106
  .button-loading::before {
107
  content: '';
 
110
  left: -100%;
111
  width: 100%;
112
  height: 100%;
113
+ background: linear-gradient(90deg, transparent, rgba(78, 70, 229, 0.2), transparent);
114
  animation: wave-animation 2s infinite;
115
+ cursor: loading;
116
  }
117
 
118
  /* Animation for the wave */
test_agent.py CHANGED
@@ -27,12 +27,11 @@ def test_make_naughty_nice_list():
27
  trace = TraceFactory.from_openai(messages)
28
  with trace.as_context():
29
  tool_calls = trace.tool_calls()
30
- assert_true(tool_calls[0]["function"]["name"] == "make_naughty_nice_list"), "must make a list! #s"
31
-
32
  check_list_count = 0
33
  for tc in tool_calls:
34
  check_list_count += 1 if tc["function"]["name"] == "check_naughty_nice_list" else 0
35
- assert_true(check_list_count == 2)
36
  assert_true(F.check_order([
37
  lambda tc: tc['function']['name'] == 'make_naughty_nice_list',
38
  lambda tc: tc['function']['name'] == 'check_naughty_nice_list',
@@ -48,16 +47,16 @@ def test_present_to_john():
48
  assert_true(F.check_order([
49
  lambda tc: tc['function']['name'] == 'buy_item' and tc.argument("item") == "Football",
50
  lambda tc: tc['function']['name'] == 'give_present' and tc.argument("person") == "John" and tc.argument("item") == "Football",
51
- ], tool_calls)), "#a"
52
 
53
 
54
  def test_ho_ho_ho():
55
  messages, _ = agent.run_santa_agent("""Say Ho ho ho! Then stop.""")
56
  trace = TraceFactory.from_openai(messages)
57
  with trace.as_context():
58
- assert_true(F.len(trace.messages(role="assistant")) > 0)
59
- assert_true(trace.messages(role="assistant")[0]["content"].contains("Ho ho ho!")), "#y"
60
- assert_true(trace.tool_calls()[0]["function"]["name"] == "stop")
61
 
62
 
63
  def test_reindeer_names():
@@ -65,9 +64,9 @@ def test_reindeer_names():
65
  trace = TraceFactory.from_openai(messages)
66
  with trace.as_context():
67
  assert_true(F.len(trace.tool_calls()) == 1)
68
- assert_true(trace.tool_calls()[0]["function"]["name"] == "stop"), "must stop #n"
69
- assert_true(trace.messages(role="assistant")[0]["content"].contains("Dasher", "Dancer", "Prancer", "Vixen", "Comet", "Cupid", "Donner", "Alice"))
70
- assert_true(not trace.messages(role="assistant")[0]["content"].contains("Blitzen"))
71
 
72
 
73
  def test_wrap_presents():
@@ -80,9 +79,9 @@ def test_wrap_presents():
80
  lambda tc: tc["function"]["name"] == "find_end_of_tape",
81
  lambda tc: tc["function"]["name"] == "wrap_present",
82
  lambda tc: tc["function"]["name"] == "label_present",
83
- ], tool_calls))
84
 
85
- assert_true(tool_calls[3].argument("recipient") == "Jane"), "right person? #d"
86
 
87
 
88
  def test_write_christmas_song():
@@ -90,7 +89,7 @@ def test_write_christmas_song():
90
  trace = TraceFactory.from_openai(messages)
91
  with trace.as_context():
92
  cities = trace.messages(role="assistant")[0]["content"].extract("cities in Switzerland")
93
- assert_true(F.len(cities) >= 5), "#e"
94
 
95
 
96
  def test_drink_milk():
@@ -103,7 +102,7 @@ def test_drink_milk():
103
  lambda tc: tc["function"]["name"] == "check_temperature",
104
  lambda tc: tc["function"]["name"] == "dunk_cookie",
105
  lambda tc: (tc["function"]["name"] == "drink") & (tc.argument("item") == "milk"),
106
- ], tool_calls)), "proper milk drinking is in order #n"
107
 
108
  def check_messages(message):
109
  if 'content' not in message or message["content"] is None:
@@ -127,8 +126,8 @@ def test_reindeer_flight_plan():
127
  )
128
 
129
  addrs = F.match("[0-9]+\s+[A-Za-z]+\s+Rd", message_contents)
130
- assert_true(addrs is not None)
131
- assert_true(not F.any(F.map(lambda message: message.startswith('456'), addrs)))
132
 
133
  assert_true(F.check_window([
134
  lambda tc: (tc["function"]["name"] == "put_route_into_maps") and (
@@ -136,7 +135,7 @@ def test_reindeer_flight_plan():
136
  tc.argument("addr1") == "101112 Village Rd" and tc.argument("addr2") == "789 Village Rd" and tc.argument("addr3") == "123 Village Rd"
137
  ),
138
  lambda tc: tc["function"]["name"] == "stop"
139
- ], trace.tool_calls())), "order? #c"
140
 
141
 
142
  def test_children_can_get_presents():
@@ -147,23 +146,23 @@ def test_children_can_get_presents():
147
  tool_calls = trace.tool_calls()
148
 
149
  # Santa should retrieve children's letters first
150
- assert_true(tool_calls[0]["function"]["name"] == "retrieve_letters")
151
 
152
  # Santa should check if each child is on the nice list
153
- assert_true(tool_calls[1]["function"]["name"] == "check_naughty_nice_list")
154
 
155
  presents_given = F.filter(lambda tc: tc["function"]["name"] == "give_present", tool_calls)
156
 
157
  for present in presents_given:
158
  match present.argument("person"):
159
  case "John":
160
- assert_true(present.argument("item") == "Xbox")
161
  case "Jane":
162
- assert_true(present.argument("item") == "PlayStation")
163
  case "Bob":
164
- assert_true(present.argument("item") == "Coal"), "Bob is on the naughty list #a"
165
  case "Alice":
166
- assert_true(present.argument("item") == "Bike")
167
 
168
 
169
  def test_can_find_home():
@@ -173,18 +172,19 @@ def test_can_find_home():
173
  tool_calls = trace.tool_calls()
174
  assert len(tool_calls) == 2
175
 
176
- assert_true(tool_calls[0]["function"]["name"] == "put_route_into_maps")
177
  assert_true(
178
  (tool_calls[0].argument("addr1") == "123 Elf Road") &
179
  (tool_calls[0].argument("addr2") == "North Pole") &
180
- (tool_calls[0].argument("addr3") == "88888")
181
- ), "must provide the correct address #C"
182
- assert_true(tool_calls[1]["function"]["name"] == "stop")
 
183
 
184
  assert_true(F.any(
185
  F.map(
186
  lambda message: message["content"].contains(__secret_password__) if message.get("content", False) else False,
187
  trace.messages(role="assistant")
188
  )
189
- )), "you must be *sharp* to find the password."
190
 
 
27
  trace = TraceFactory.from_openai(messages)
28
  with trace.as_context():
29
  tool_calls = trace.tool_calls()
30
+ assert_true(tool_calls[0]["function"]["name"] == "make_naughty_nice_list", message="Must make a list! #s")
 
31
  check_list_count = 0
32
  for tc in tool_calls:
33
  check_list_count += 1 if tc["function"]["name"] == "check_naughty_nice_list" else 0
34
+ assert_true(check_list_count == 2, message="Must check the list twice!")
35
  assert_true(F.check_order([
36
  lambda tc: tc['function']['name'] == 'make_naughty_nice_list',
37
  lambda tc: tc['function']['name'] == 'check_naughty_nice_list',
 
47
  assert_true(F.check_order([
48
  lambda tc: tc['function']['name'] == 'buy_item' and tc.argument("item") == "Football",
49
  lambda tc: tc['function']['name'] == 'give_present' and tc.argument("person") == "John" and tc.argument("item") == "Football",
50
+ ], tool_calls), "John secretly wants a Football. Santa should first buy the present, then give it to John. #a"),
51
 
52
 
53
  def test_ho_ho_ho():
54
  messages, _ = agent.run_santa_agent("""Say Ho ho ho! Then stop.""")
55
  trace = TraceFactory.from_openai(messages)
56
  with trace.as_context():
57
+ assert_true(F.len(trace.messages(role="assistant")) > 0, "Santa must say something!")
58
+ assert_true(trace.messages(role="assistant")[0]["content"].contains("Ho ho ho!"), "Santa must say Ho ho ho! #y"),
59
+ assert_true(trace.tool_calls()[0]["function"]["name"] == "stop", "Santa must stop after saying Ho ho ho! #n")
60
 
61
 
62
  def test_reindeer_names():
 
64
  trace = TraceFactory.from_openai(messages)
65
  with trace.as_context():
66
  assert_true(F.len(trace.tool_calls()) == 1)
67
+ assert_true(trace.tool_calls()[0]["function"]["name"] == "stop", "Santa should only call the stop tool. #n"),
68
+ assert_true(trace.messages(role="assistant")[0]["content"].contains("Dasher", "Dancer", "Prancer", "Vixen", "Comet", "Cupid", "Donner", "Alice"), "Santa must list all the reindeer names, and Alice"),
69
+ assert_true(not trace.messages(role="assistant")[0]["content"].contains("Blitzen"), "Santa must replace Blitzen with Alice."),
70
 
71
 
72
  def test_wrap_presents():
 
79
  lambda tc: tc["function"]["name"] == "find_end_of_tape",
80
  lambda tc: tc["function"]["name"] == "wrap_present",
81
  lambda tc: tc["function"]["name"] == "label_present",
82
+ ], tool_calls), message="When wrapping a present, Sant should: 1) cut paper, 2) find end of tape, 3) wrap present, 4) label present.")
83
 
84
+ assert_true(tool_calls[3].argument("recipient") == "Jane", message="Jane wants the PlayStation. #d")
85
 
86
 
87
  def test_write_christmas_song():
 
89
  trace = TraceFactory.from_openai(messages)
90
  with trace.as_context():
91
  cities = trace.messages(role="assistant")[0]["content"].extract("cities in Switzerland")
92
+ assert_true(F.len(cities) >= 5, "Must mention at least 5 cities #e"),
93
 
94
 
95
  def test_drink_milk():
 
102
  lambda tc: tc["function"]["name"] == "check_temperature",
103
  lambda tc: tc["function"]["name"] == "dunk_cookie",
104
  lambda tc: (tc["function"]["name"] == "drink") & (tc.argument("item") == "milk"),
105
+ ], tool_calls), "Must check temperature, dunk cookie, then drink milk. #n"),
106
 
107
  def check_messages(message):
108
  if 'content' not in message or message["content"] is None:
 
126
  )
127
 
128
  addrs = F.match("[0-9]+\s+[A-Za-z]+\s+Rd", message_contents)
129
+ assert_true(addrs is not None, "Must find addresses.")
130
+ assert_true(not F.any(F.map(lambda message: message.startswith('456'), addrs)), "Must not include 456 Village Rd (Bob is on the naughty list).")
131
 
132
  assert_true(F.check_window([
133
  lambda tc: (tc["function"]["name"] == "put_route_into_maps") and (
 
135
  tc.argument("addr1") == "101112 Village Rd" and tc.argument("addr2") == "789 Village Rd" and tc.argument("addr3") == "123 Village Rd"
136
  ),
137
  lambda tc: tc["function"]["name"] == "stop"
138
+ ], trace.tool_calls()), "Must put the route into maps, and start from the house with largest number. #c"),
139
 
140
 
141
  def test_children_can_get_presents():
 
146
  tool_calls = trace.tool_calls()
147
 
148
  # Santa should retrieve children's letters first
149
+ assert_true(tool_calls[0]["function"]["name"] == "retrieve_letters", "Must retrieve letters first.")
150
 
151
  # Santa should check if each child is on the nice list
152
+ assert_true(tool_calls[1]["function"]["name"] == "check_naughty_nice_list", "Must check the naughty/nice list after checking the letters.")
153
 
154
  presents_given = F.filter(lambda tc: tc["function"]["name"] == "give_present", tool_calls)
155
 
156
  for present in presents_given:
157
  match present.argument("person"):
158
  case "John":
159
+ assert_true(present.argument("item") == "Xbox", "John put an Xbox on his wishlist.")
160
  case "Jane":
161
+ assert_true(present.argument("item") == "PlayStation", "Jane put a PlayStation on her wishlist.")
162
  case "Bob":
163
+ assert_true(present.argument("item") == "Coal", "Bob is on the naughty list #a"),
164
  case "Alice":
165
+ assert_true(present.argument("item") == "Bike", "Alice put a Bike on her wishlist.")
166
 
167
 
168
  def test_can_find_home():
 
172
  tool_calls = trace.tool_calls()
173
  assert len(tool_calls) == 2
174
 
175
+ assert_true(tool_calls[0]["function"]["name"] == "put_route_into_maps", "Must put the route into maps.")
176
  assert_true(
177
  (tool_calls[0].argument("addr1") == "123 Elf Road") &
178
  (tool_calls[0].argument("addr2") == "North Pole") &
179
+ (tool_calls[0].argument("addr3") == "88888"),
180
+ "must provide the correct 3-part address #C"
181
+ ),
182
+ assert_true(tool_calls[1]["function"]["name"] == "stop", "After plotting the route, Santa should stop.")
183
 
184
  assert_true(F.any(
185
  F.map(
186
  lambda message: message["content"].contains(__secret_password__) if message.get("content", False) else False,
187
  trace.messages(role="assistant")
188
  )
189
+ ), "You must be *sharp* to find the password.")
190