nisten committed on
Commit d25ef7d · verified · 1 Parent(s): 8e396f7

Update app.py

Files changed (1)
  1. app.py +116 -151
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import gradio as gr
2
-
3
  from dataclasses import dataclass
4
  from concurrent.futures import ThreadPoolExecutor, TimeoutError
5
 
@@ -14,16 +13,13 @@ import time
14
  from typing import Tuple, Dict, Any, List
15
  from sympy import N, simplify
16
  from sympy.parsing.latex import parse_latex
17
- from openai import OpenAI
18
 
19
  import base64
20
 
21
-
22
- client = OpenAI(
23
- base_url=os.environ.get("SERVER_URL"),
24
- api_key=os.environ.get("HF_TOKEN"),
25
- )
26
-
27
 
28
  @dataclass
29
  class Config:
@@ -59,7 +55,6 @@ class Config:
59
  # Push solutions to the Hub
60
  push_to_hub: bool = False
61
 
62
-
63
  class PythonREPL:
64
  def __init__(self, timeout=5):
65
  self.timeout = timeout
@@ -79,13 +74,16 @@ class PythonREPL:
79
  with open(temp_file_path, "w") as f:
80
  f.write(query)
81
 
82
- result = subprocess.run(
83
- ["python3", temp_file_path],
84
- capture_output=True,
85
- check=False,
86
- text=True,
87
- timeout=self.timeout,
88
- )
89
 
90
  if result.returncode == 0:
91
  output = result.stdout
@@ -121,17 +119,16 @@ class PythonREPL:
121
  except TimeoutError:
122
  return False, f"Timed out after {self.timeout} seconds."
123
 
124
-
125
  def execute_completion(
126
  executor: PythonREPL,
127
  completion: str,
128
  return_status: bool = False,
129
  last_code_block: bool = False,
130
  ) -> str | Tuple[str, bool]:
131
- # executions = ["!" + code for code in re.findall(r"```bash(.*?)```", completion, re.DOTALL) if "!" not in code]
132
  executions = re.findall(r"```python(.*?)```", completion, re.DOTALL)
133
 
134
- if len(executions) == 0: # directly return cot result
135
  return (completion, False) if return_status else completion
136
  else:
137
  if last_code_block:
@@ -159,7 +156,7 @@ def execute_completion(
159
  success, output = executor(code)
160
  except TimeoutError as e:
161
  print("time out")
162
- output = e
163
 
164
  if not success and not return_status:
165
  output = ""
@@ -175,7 +172,6 @@ def execute_completion(
175
  else:
176
  return output
177
 
178
-
179
  def postprocess_completion(
180
  text: str, return_status: bool = False, last_code_block=False, timeout=5
181
  ) -> str | Tuple[str, bool]:
@@ -186,11 +182,9 @@ def postprocess_completion(
186
 
187
  return result
188
 
189
-
190
  def apply_template(example: Dict[str, Any], prompt: str) -> Dict[str, Any]:
191
  return prompt.format(example["prompt"], "{}")
192
 
193
-
194
  def last_boxed_only_string(string):
195
  """
196
  Extracts the last LaTeX boxed or framed expression from a string.
@@ -227,7 +221,6 @@ def last_boxed_only_string(string):
227
 
228
  return retval
229
 
230
-
231
  def remove_boxed(s):
232
  """
233
  Removes the LaTeX boxed command, returning the content inside the braces.
@@ -247,7 +240,6 @@ def remove_boxed(s):
247
  except Exception:
248
  return None
249
 
250
-
251
  def extract_boxed_answer(pred_str, strip_double_curly_brace=False):
252
  """
253
  Extracts the answer from a LaTeX boxed expression within
@@ -268,12 +260,11 @@ def extract_boxed_answer(pred_str, strip_double_curly_brace=False):
268
  if answer is None:
269
  return None
270
  if strip_double_curly_brace:
271
- match = re.match("^\{(.*)\}$", answer) # noqa: W605
272
  if match:
273
  answer = match.group(1)
274
  return answer
275
 
276
-
277
  def normalize_final_answer(final_answer: str) -> str:
278
  """
279
  Normalizes a final answer string by removing or replacing various LaTeX
@@ -286,9 +277,9 @@ def normalize_final_answer(final_answer: str) -> str:
286
 
287
  match = re.search(r"(.*?)Problem:", final_answer, flags=re.S)
288
  if match:
289
- final_answer = match.group(1) # return the first matched part, i.e. all text before "Problem:"
 
290
  """Normalize a final answer to a quantitative reasoning question."""
291
- # final_answer = final_answer.split('=')[-1]
292
  SUBSTITUTIONS = [
293
  ("an ", ""),
294
  ("a ", ""),
@@ -379,8 +370,8 @@ def normalize_final_answer(final_answer: str) -> str:
379
  if "rac" in final_answer and "\\frac" not in final_answer:
380
  final_answer = final_answer.replace("rac", "\\frac")
381
 
382
- final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
383
- final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
384
  final_answer = final_answer.replace("$", "")
385
 
386
  if final_answer.replace(",", "").isdigit():
@@ -388,18 +379,14 @@ def normalize_final_answer(final_answer: str) -> str:
388
 
389
  return final_answer
390
 
391
-
392
  def naive_parse(answer: str) -> str:
393
  """
394
  Extracts and returns the numeric digits from the input string, processing them in reverse order
395
  until a non-numeric character is encountered after encountering the first numeric character.
396
-
397
  Args:
398
  answer (str): The input string to parse.
399
-
400
  Returns:
401
  str: A string consisting of the numeric digits extracted from the input, in their original order.
402
-
403
  Example:
404
  >>> naive_parse("abc123def")
405
  '123'
@@ -422,7 +409,6 @@ def naive_parse(answer: str) -> str:
422
  out = reversed(out)
423
  return "".join(out)
424
 
425
-
426
  def validate_answer_is_numeric(x: str | int | float) -> int:
427
  FLOAT_TOLERANCE = 0.2
428
  try:
@@ -434,7 +420,6 @@ def validate_answer_is_numeric(x: str | int | float) -> int:
434
  x = -1
435
  return x
436
 
437
-
438
  def filter_answers(answers: List[str]) -> List[int]:
439
  formatted_answers = [validate_answer_is_numeric(a) for a in answers]
440
 
@@ -446,13 +431,12 @@ def filter_answers(answers: List[str]) -> List[int]:
446
  formatted_answers = [a for a in formatted_answers if a <= 999]
447
  return formatted_answers
448
 
449
-
450
  def check_sympy_equivalence(ref_answer: str, model_answer: str) -> bool:
451
  def do_answers_match(ref_answer: str, model_answer: str) -> bool:
452
  ref_sympy = parse_latex(ref_answer)
453
  model_sympy = parse_latex(model_answer)
454
  diff = simplify(ref_sympy - model_sympy)
455
- return True if -1e-12 < N(diff) < 1e-12 or diff.is_zero else False
456
 
457
  try:
458
  result = do_answers_match(ref_answer, model_answer)
@@ -461,7 +445,6 @@ def check_sympy_equivalence(ref_answer: str, model_answer: str) -> bool:
461
  print(e)
462
  return False
463
 
464
-
465
  def check_string_match(ref_answer: str, model_answer: str) -> bool:
466
  try:
467
  return ref_answer == model_answer
@@ -469,7 +452,6 @@ def check_string_match(ref_answer: str, model_answer: str) -> bool:
469
  print(e)
470
  return False
471
 
472
-
473
  def check_answer(ref_answer: str, model_answer: str) -> bool:
474
  # check if strings are the same
475
  correct = check_string_match(ref_answer, model_answer)
@@ -483,9 +465,9 @@ def check_answer(ref_answer: str, model_answer: str) -> bool:
483
 
484
  return False
485
 
486
-
487
  debug = False
488
- model_id = "Numina-Math-7B"
489
  revision = "main"
490
  system_prompt = "{}"
491
  validation_set = "kaggle-validation-set-medium"
@@ -522,52 +504,43 @@ config = Config(
522
  )
523
  print(f"=== Running submission with config ===\n\n{config}")
524
 
525
-
526
- def generate(message, temperature):
527
  """
528
  Generates a chat completion response by streaming data from the client chat model.
529
-
530
  This function streams the response from the client chat model and yields the content
531
  of the response chunk by chunk. If an error occurs, it yields the error message.
532
-
533
  Parameters:
534
- message (str): The input message to be sent to the chat model.
535
- temperature (float): The sampling temperature to use. Higher values mean the model will take more risks.
536
-
537
  Yields:
538
  tuple: A tuple containing the content of the response and a boolean flag indicating if an error occurred.
539
  If no error occurred, the boolean flag will be False and the content will be the response text.
540
  If an error occurred, the boolean flag will be True and the content will be the error message.
541
  """
542
- stream = client.chat.completions.create(
543
- model="tgi",
544
- messages=message,
545
- stream=True,
546
- max_tokens=1024,
547
- stop=["```output\n"],
548
- temperature=temperature,
549
- timeout=30,
550
- )
551
-
552
- response = stream.response
553
-
554
- # The reason why the library method is not used here is that if an error occurs,
555
- # the returned data will not be a stream, and using the official library will result in an error.
556
- for chunk in response.iter_bytes():
557
- chunk = chunk.decode("utf-8")
558
- chune_json = json.loads(chunk.replace("data:", ""))
559
- try:
560
- if "error" in chune_json and chune_json["error"]:
561
- yield chune_json["error"], True
562
  break
563
-
564
- content = chune_json["choices"][0]["delta"]["content"]
565
- if content is not None:
566
- yield content, False
567
- except Exception as e:
568
- print(f"func: generate error occurred\njson:{chune_json}\nerror:{e}")
569
- yield "", True
570
-
571
 
572
  def get_majority_text(data):
573
  from collections import Counter
@@ -584,7 +557,6 @@ def get_majority_text(data):
584
  # Return the corresponding text in gen_texts
585
  return data["gen_texts"][majority_index]
586
 
587
-
588
  def extract_solution(text):
589
  # Split the text at "### Solution:"
590
  parts = text.split("### Solution:", 1)
@@ -595,7 +567,6 @@ def extract_solution(text):
595
  # Return an empty string if "### Solution:" is not found
596
  return ""
597
 
598
-
599
  def process_code(
600
  example: Dict[str, Any],
601
  config: Config,
@@ -607,19 +578,19 @@ def process_code(
607
 
608
  if num_python_blocks == 0:
609
  if restart_on_fail:
610
- print("no code has ever been generated, RESTARTING")
611
- # reset the text to the original
612
- example["gen_texts"] = example["text"]
613
  else:
614
- print("no code has ever been generated, STOP")
615
  example["should_prune"] = True
616
  example["has_code"] = False
617
  return example
618
 
619
- if gen_text[-10:] != "```output\n" and ("answer is" in gen_text[-100:] or "\\boxed" in gen_text[-100:]):
620
  num_output_blocks = len(re.findall(r"```output(.*?)```", gen_text, re.DOTALL))
621
  if num_output_blocks == 0:
622
- print("the model hallucinated the code answer")
623
  example["should_prune"] = True
624
  return example
625
 
@@ -639,70 +610,69 @@ def process_code(
639
  return example
640
 
641
  if last_step:
642
- # no point in continuing if we are at the last step
643
  return example
644
 
645
- if gen_text[-10:] != "```output\n":
646
- # something else has gone wrong with the generation
647
- print("warning: output block not found: ", gen_text[-40:])
648
  if restart_on_fail:
649
- example["gen_texts"] = example["text"]
650
  else:
651
  example["should_prune"] = True
652
  return example
653
 
654
  code_result, status = postprocess_completion(gen_text, return_status=True, last_code_block=True)
655
- # add the code result for the next round of generation
656
  TRUNCATION_LIMIT = 200
657
  if len(code_result) > TRUNCATION_LIMIT:
658
  code_result = code_result[:TRUNCATION_LIMIT] + " ... (output truncated)"
659
- example["gen_texts"] = gen_text + f"{code_result}\n```"
660
 
661
  return example
662
 
663
-
664
  def solve_problem(problem, temperature, progress=gr.Progress()):
665
  """
666
  yield token: string, stop: bool
667
  """
668
- problem = apply_template({"prompt": problem}, prompt=config.system_prompt)
669
- print(f"Problem: {problem}")
 
670
 
671
  sample = {
672
- "problem": problem, # not used for the submission TODO Remove
673
- "ground_truth": "unknown", # not used for the submission TODO Remove
674
  "text": "## Solution:\n",
675
- "gen_texts": "## Solution:\n", # used to store all the generated text
676
  "should_prune": False,
677
- "problem_index": -1, # not used for the submission TODO Remove
678
  "model_answers": "-1",
679
  "has_code": True,
680
- "corrects": False, # not used for the submission TODO Remove
681
  }
682
 
683
  for step in progress.tqdm(
684
  range(config.num_generations), desc="Generating candidates"
685
- ): # Depth of the tree (e.g. 6 steps = 5 code blocks)
686
 
687
- step_reponse = sample["gen_texts"]
688
 
689
  messages = [
690
- {"role": "user", "content": sample["problem"]},
691
- {"role": "assistant", "content": sample["gen_texts"]},
692
  ]
693
 
694
- for reponse_message, error in generate(messages, temperature):
695
- if reponse_message is not None:
696
- step_reponse += reponse_message
697
- yield step_reponse, False
698
-
699
  if error:
700
- yield step_reponse, True
701
  return
702
 
703
- sample["gen_texts"] = step_reponse
704
 
705
- # TODO: Maybe it should just return the result of running the code
706
  sample = process_code(
707
  sample,
708
  config=config,
@@ -711,55 +681,52 @@ def solve_problem(problem, temperature, progress=gr.Progress()):
711
  )
712
  sample["gen_texts"] = sample["gen_texts"] + "\n"
713
 
714
- run_code_reponse = sample["gen_texts"].replace(step_reponse, "")
 
715
 
716
- for output_mseeage in run_code_reponse:
717
- if output_mseeage is not None:
718
- step_reponse += output_mseeage
719
- yield step_reponse, False
720
 
721
  if sample["should_prune"]:
722
  break
723
 
724
  yield sample["gen_texts"], True
725
 
726
-
727
  example_data = datasets.load_dataset(
728
  "AI-MO/kaggle-validation-set-medium-extended",
729
  split="train",
730
  use_auth_token=os.environ.get("HF_DATASET_TOKEN", None),
731
  )
732
 
733
-
734
- with open("app.css", "r") as f:
735
- css = f.read()
736
-
 
737
 
738
  latex_delimiters = [
739
  {"left": "[", "right": "]", "display": True},
740
  ]
741
 
742
-
743
  def get_random_problem():
744
  example = random.choice(list(example_data))
745
  problem = example["problem"]
746
  return problem
747
 
748
-
749
  def update_example_problem():
750
  problem_example_text = get_random_problem()
751
  return problem_example_text, problem_example_text
752
 
753
-
754
  def clear():
755
  problem_example_text = get_random_problem()
756
  return "", 0.1, "", problem_example_text, problem_example_text
757
 
758
-
759
  def preprocess_output(text):
760
  return text.replace(r"\(", r"\\(").replace(r"\)", r"\\)")
761
 
762
-
763
  with gr.Blocks(css=css, title="Math Olympiad Solver") as demo:
764
  running_done = False
765
  btn_list = []
@@ -772,31 +739,31 @@ with gr.Blocks(css=css, title="Math Olympiad Solver") as demo:
772
 
773
  with gr.Row(elem_classes="sub-title"):
774
  gr.HTML(
775
- "<div>Demo of the <a href='https://huggingface.co/AI-MO/NuminaMath-7B-TIR'>Numina-Math-7B-TIR</a>. Example data are drawn randomly from AMC12, year 2022-2023.</div>",
776
  elem_classes="sub-title-content",
777
  )
778
 
779
  with gr.Row(elem_classes="main-area"):
780
  with gr.Column(scale=1, elem_classes="left"):
781
- with gr.Row(elem_classes="probelm-example-container"):
782
- with gr.Blocks(elem_classes="probelm-example-title"):
783
- gr.HTML("Problem example", elem_classes="probelm-example-title-content")
784
 
785
  with gr.Blocks(elem_classes="action-container"):
786
  another_btn = gr.Button(
787
- "",
788
- elem_classes="probelm-example-another",
789
- icon="./static/images/reset.png",
790
  )
791
- copy_btn = gr.Button("Copy", elem_classes="probelm-example-copy")
792
 
793
  problem_example = gr.HTML(
794
  problem_example_text,
795
- elem_classes="probelm-example-content",
796
  )
797
 
798
- with gr.Row(elem_classes="probelm-input-container"):
799
- inp = gr.Textbox(placeholder="Problem", label="Problem input", lines=5, visible=True)
800
  problem_markdown = gr.Markdown(
801
  visible=False,
802
  latex_delimiters=[
@@ -807,17 +774,15 @@ with gr.Blocks(css=css, title="Math Olympiad Solver") as demo:
807
  )
808
 
809
  inp.change(fn=lambda text: text, inputs=[inp], outputs=[problem_markdown])
810
- problem_input_ele_list.append(inp)
811
- problem_input_ele_list.append(problem_markdown)
812
 
813
  with gr.Accordion("Advanced Options", open=False):
814
- temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.1, label="Temperature")
815
 
816
  with gr.Row() as btn_area:
817
  btn_clear = gr.Button("Clear", elem_classes="clear-btn")
818
  btn_run = gr.Button("Run", elem_classes="run-btn")
819
- btn_list.append(btn_clear)
820
- btn_list.append(btn_run)
821
 
822
  with gr.Column(scale=1, elem_classes="right"):
823
  gr.HTML("Solution", elem_classes="solution-title-content")
@@ -842,15 +807,15 @@ with gr.Blocks(css=css, title="Math Olympiad Solver") as demo:
842
  running_done = True
843
  except Exception as e:
844
  running_done = True
845
- raise e
846
 
847
  def mount_run_btn(btn):
848
- btn.click(fn=solve_problem_wrapper, inputs=[inp, temperature], outputs=out)
849
  btn.click(get_running_btns, None, outputs=btn_list)
850
  btn.click(get_run_after_problem_input, None, outputs=problem_input_ele_list)
851
 
852
  def get_run_after_problem_input():
853
- return gr.Textbox(placeholder="Problem", label="Problem input", lines=5, visible=False), gr.Markdown(
854
  visible=True,
855
  latex_delimiters=[
856
  {"left": "[", "right": "]", "display": True},
@@ -860,7 +825,7 @@ with gr.Blocks(css=css, title="Math Olympiad Solver") as demo:
860
  )
861
 
862
  def get_init_problem_input():
863
- return gr.Textbox(placeholder="Problem", label="Problem input", lines=5, visible=True), gr.Markdown(
864
  visible=False,
865
  latex_delimiters=[
866
  {"left": "[", "right": "]", "display": True},
@@ -890,14 +855,14 @@ with gr.Blocks(css=css, title="Math Olympiad Solver") as demo:
890
 
891
  time.sleep(1)
892
 
893
- copy_btn.click(fn=lambda example: example, inputs=[problem_example_text_hidden], outputs=[inp])
894
 
895
  btn_clear.click(
896
  fn=clear,
897
  inputs=[],
898
  outputs=[
899
  inp,
900
- temperature,
901
  out,
902
  problem_example,
903
  problem_example_text_hidden,
@@ -927,4 +892,4 @@ with gr.Blocks(css=css, title="Math Olympiad Solver") as demo:
927
  )
928
 
929
  if __name__ == "__main__":
930
- demo.queue(default_concurrency_limit=5).launch()
 
1
  import gradio as gr
 
2
  from dataclasses import dataclass
3
  from concurrent.futures import ThreadPoolExecutor, TimeoutError
4
 
 
13
  from typing import Tuple, Dict, Any, List
14
  from sympy import N, simplify
15
  from sympy.parsing.latex import parse_latex
16
+ import openai
17
 
18
  import base64
19
 
20
+ # Initialize OpenAI client to use local API
21
+ openai.api_base = os.environ.get("SERVER_URL", "http://0.0.0.0:6061")
22
+ openai.api_key = os.environ.get("HF_TOKEN", "") # If no key needed, set empty string
23
 
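For orientation, a minimal sanity check of the endpoint configured above. This is a standalone sketch, assuming the legacy openai<1.0 SDK (which exposes module-level `api_base`/`api_key`) and a server that speaks the OpenAI chat-completions protocol; the model name is a placeholder.

```python
import os
import openai

openai.api_base = os.environ.get("SERVER_URL", "http://0.0.0.0:6061")
openai.api_key = os.environ.get("HF_TOKEN", "")

# Single non-streaming request to confirm the server responds.
resp = openai.ChatCompletion.create(
    model="qwen2-7b-math-q8_0",  # placeholder; must match the model id the server exposes
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
    max_tokens=16,
)
print(resp["choices"][0]["message"]["content"])
```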
24
  @dataclass
25
  class Config:
 
55
  # Push solutions to the Hub
56
  push_to_hub: bool = False
57
 
 
58
  class PythonREPL:
59
  def __init__(self, timeout=5):
60
  self.timeout = timeout
 
74
  with open(temp_file_path, "w") as f:
75
  f.write(query)
76
 
77
+ try:
78
+ result = subprocess.run(
79
+ ["python3", temp_file_path],
80
+ capture_output=True,
81
+ check=False,
82
+ text=True,
83
+ timeout=self.timeout,
84
+ )
85
+ except subprocess.TimeoutExpired:
86
+ return False, f"Timed out after {self.timeout} seconds."
87
 
88
  if result.returncode == 0:
89
  output = result.stdout
 
119
  except TimeoutError:
120
  return False, f"Timed out after {self.timeout} seconds."
121
 
 
122
  def execute_completion(
123
  executor: PythonREPL,
124
  completion: str,
125
  return_status: bool = False,
126
  last_code_block: bool = False,
127
  ) -> str | Tuple[str, bool]:
128
+ # Extract python code blocks enclosed in triple backticks with language 'python'
129
  executions = re.findall(r"```python(.*?)```", completion, re.DOTALL)
130
 
131
+ if len(executions) == 0: # directly return the CoT result
132
  return (completion, False) if return_status else completion
133
  else:
134
  if last_code_block:
 
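As a standalone illustration of the extraction above (not part of the commit), `re.DOTALL` lets the capture group span the newlines inside the fence:

```python
import re

completion = "We compute it in code.\n```python\nprint(21 + 21)\n```\nSo the answer is 42."
executions = re.findall(r"```python(.*?)```", completion, re.DOTALL)
print(executions)  # ['\nprint(21 + 21)\n'] -- one runnable block
```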
156
  success, output = executor(code)
157
  except TimeoutError as e:
158
  print("time out")
159
+ output = str(e)
160
 
161
  if not success and not return_status:
162
  output = ""
 
172
  else:
173
  return output
174
 
 
175
  def postprocess_completion(
176
  text: str, return_status: bool = False, last_code_block=False, timeout=5
177
  ) -> str | Tuple[str, bool]:
 
182
 
183
  return result
184
 
 
185
  def apply_template(example: Dict[str, Any], prompt: str) -> Dict[str, Any]:
186
  return prompt.format(example["prompt"], "{}")
187
 
 
188
  def last_boxed_only_string(string):
189
  """
190
  Extracts the last LaTeX boxed or framed expression from a string.
 
221
 
222
  return retval
223
 
 
224
  def remove_boxed(s):
225
  """
226
  Removes the LaTeX boxed command, returning the content inside the braces.
 
240
  except Exception:
241
  return None
242
 
 
243
  def extract_boxed_answer(pred_str, strip_double_curly_brace=False):
244
  """
245
  Extracts the answer from a LaTeX boxed expression within
 
260
  if answer is None:
261
  return None
262
  if strip_double_curly_brace:
263
+ match = re.match(r"^\{(.*)\}$", answer)
264
  if match:
265
  answer = match.group(1)
266
  return answer
267
 
 
268
  def normalize_final_answer(final_answer: str) -> str:
269
  """
270
  Normalizes a final answer string by removing or replacing various LaTeX
 
277
 
278
  match = re.search(r"(.*?)Problem:", final_answer, flags=re.S)
279
  if match:
280
+ final_answer = match.group(1) # Return all text before 'Problem'
281
+
282
  """Normalize a final answer to a quantitative reasoning question."""
 
283
  SUBSTITUTIONS = [
284
  ("an ", ""),
285
  ("a ", ""),
 
370
  if "rac" in final_answer and "\\frac" not in final_answer:
371
  final_answer = final_answer.replace("rac", "\\frac")
372
 
373
+ final_answer = re.sub(r"(frac)([^{])(.)", r"frac{\2}{\3}", final_answer)
374
+ final_answer = re.sub(r"(sqrt)([^{])", r"sqrt{\2}", final_answer)
375
  final_answer = final_answer.replace("$", "")
376
 
377
  if final_answer.replace(",", "").isdigit():
 
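A quick standalone demonstration of the two substitutions above: brace-less shorthand such as `\frac12` or `\sqrt2`, which models often emit, is rewritten into well-formed LaTeX.

```python
import re

ans = r"\frac12 + \sqrt2"
ans = re.sub(r"(frac)([^{])(.)", r"frac{\2}{\3}", ans)
ans = re.sub(r"(sqrt)([^{])", r"sqrt{\2}", ans)
print(ans)  # \frac{1}{2} + \sqrt{2}
```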
379
 
380
  return final_answer
381
 
 
382
  def naive_parse(answer: str) -> str:
383
  """
384
  Extracts and returns the numeric digits from the input string, processing them in reverse order
385
  until a non-numeric character is encountered after encountering the first numeric character.
 
386
  Args:
387
  answer (str): The input string to parse.
 
388
  Returns:
389
  str: A string consisting of the numeric digits extracted from the input, in their original order.
 
390
  Example:
391
  >>> naive_parse("abc123def")
392
  '123'
 
409
  out = reversed(out)
410
  return "".join(out)
411
 
 
412
  def validate_answer_is_numeric(x: str | int | float) -> int:
413
  FLOAT_TOLERANCE = 0.2
414
  try:
 
420
  x = -1
421
  return x
422
 
 
423
  def filter_answers(answers: List[str]) -> List[int]:
424
  formatted_answers = [validate_answer_is_numeric(a) for a in answers]
425
 
 
431
  formatted_answers = [a for a in formatted_answers if a <= 999]
432
  return formatted_answers
433
 
 
434
  def check_sympy_equivalence(ref_answer: str, model_answer: str) -> bool:
435
  def do_answers_match(ref_answer: str, model_answer: str) -> bool:
436
  ref_sympy = parse_latex(ref_answer)
437
  model_sympy = parse_latex(model_answer)
438
  diff = simplify(ref_sympy - model_sympy)
439
+ return bool(-1e-12 < N(diff) < 1e-12 or diff.is_zero)
440
 
441
  try:
442
  result = do_answers_match(ref_answer, model_answer)
 
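For context, this is how the parse-and-simplify check behaves on two equivalent answers (a standalone sketch; `parse_latex` additionally requires the `antlr4-python3-runtime` package):

```python
from sympy import N, simplify
from sympy.parsing.latex import parse_latex

ref_sympy = parse_latex(r"\frac{1}{2}")
model_sympy = parse_latex(r"0.5")
diff = simplify(ref_sympy - model_sympy)
print(-1e-12 < N(diff) < 1e-12)  # True -> the answers match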
445
  print(e)
446
  return False
447
 
 
448
  def check_string_match(ref_answer: str, model_answer: str) -> bool:
449
  try:
450
  return ref_answer == model_answer
 
452
  print(e)
453
  return False
454
 
 
455
  def check_answer(ref_answer: str, model_answer: str) -> bool:
456
  # check if strings are the same
457
  correct = check_string_match(ref_answer, model_answer)
 
465
 
466
  return False
467
 
468
+ # Configuration Parameters
469
  debug = False
470
+ model_id = "qwen2-7b-math-q8_0" # Update model ID
471
  revision = "main"
472
  system_prompt = "{}"
473
  validation_set = "kaggle-validation-set-medium"
 
504
  )
505
  print(f"=== Running submission with config ===\n\n{config}")
506
 
507
+ def generate(messages, temperature):
 
508
  """
509
  Generates a chat completion response by streaming data from the client chat model.
 
510
  This function streams the response from the client chat model and yields the content
511
  of the response chunk by chunk. If an error occurs, it yields the error message.
 
512
  Parameters:
513
+ messages (list of dict): The list of message dicts for the chat model.
514
+ temperature (float): The sampling temperature to use.
 
515
  Yields:
516
  tuple: A tuple containing the content of the response and a boolean flag indicating if an error occurred.
517
  If no error occurred, the boolean flag will be False and the content will be the response text.
518
  If an error occurred, the boolean flag will be True and the content will be the error message.
519
  """
520
+ try:
521
+ response = openai.ChatCompletion.create(
522
+ model=config.model_id,
523
+ messages=messages,
524
+ stream=True,
525
+ max_tokens=1024,
526
+ temperature=temperature,
+ stop=["```output\n"], # stop at the output sentinel so the generated code can be executed
527
+ )
528
+ except Exception as e:
529
+ yield str(e), True
530
+ return
531
+
532
+ for chunk in response:
533
+ if 'choices' in chunk:
534
+ choice = chunk['choices'][0]
535
+ if 'delta' in choice:
536
+ content = choice['delta'].get('content', '')
537
+ if content:
538
+ yield content, False
539
+ if choice.get('finish_reason') is not None:
540
  break
541
+ elif 'error' in chunk:
542
+ yield chunk['error']['message'], True
543
+ break
544
 
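A minimal consumer of `generate` as defined above (hypothetical prompt; assumes the server is reachable), following its `(content, error)` contract:

```python
messages = [{"role": "user", "content": "What is 2 + 2?"}]
answer = ""
for content, error in generate(messages, temperature=0.1):
    if error:
        print(f"Generation failed: {content}")
        break
    answer += content
print(answer)
```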
545
  def get_majority_text(data):
546
  from collections import Counter
 
557
  # Return the corresponding text in gen_texts
558
  return data["gen_texts"][majority_index]
559
 
 
560
  def extract_solution(text):
561
  # Split the text at "### Solution:"
562
  parts = text.split("### Solution:", 1)
 
567
  # Return an empty string if "### Solution:" is not found
568
  return ""
569
 
 
570
  def process_code(
571
  example: Dict[str, Any],
572
  config: Config,
 
578
 
579
  if num_python_blocks == 0:
580
  if restart_on_fail:
581
+ print("No code has been generated. Restarting generation.")
582
+ # Reset the text to the original
583
+ example["gen_texts"] = "## Solution:\n"
584
  else:
585
+ print("No code has been generated. Stopping.")
586
  example["should_prune"] = True
587
  example["has_code"] = False
588
  return example
589
 
590
+ if not gen_text.endswith("```output\n") and ("answer is" in gen_text[-100:] or "\\boxed" in gen_text[-100:]):
591
  num_output_blocks = len(re.findall(r"```output(.*?)```", gen_text, re.DOTALL))
592
  if num_output_blocks == 0:
593
+ print("The model hallucinated the code answer.")
594
  example["should_prune"] = True
595
  return example
596
 
 
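To make the sentinel checks above concrete: a tool-integrated generation that is ready for code execution ends exactly with the opening of an output block (illustrative string only):

```python
gen_text = (
    "## Solution:\n"
    "We compute the sum with Python.\n"
    "```python\nprint(21 + 21)\n```\n"
    "```output\n"
)
print(gen_text.endswith("```output\n"))  # True -> run the code, then append its output
```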
610
  return example
611
 
612
  if last_step:
613
+ # No point in continuing if we are at the last step
614
  return example
615
 
616
+ if not gen_text.endswith("```output\n"):
617
+ # Something else has gone wrong with the generation
618
+ print("Warning: Output block not found: ", gen_text[-40:])
619
  if restart_on_fail:
620
+ example["gen_texts"] = "## Solution:\n"
621
  else:
622
  example["should_prune"] = True
623
  return example
624
 
625
  code_result, status = postprocess_completion(gen_text, return_status=True, last_code_block=True)
626
+ # Add the code result for the next round of generation
627
  TRUNCATION_LIMIT = 200
628
  if len(code_result) > TRUNCATION_LIMIT:
629
  code_result = code_result[:TRUNCATION_LIMIT] + " ... (output truncated)"
630
+ example["gen_texts"] = gen_text + f"{code_result}\n```\n"
631
 
632
  return example
633
 
 
634
  def solve_problem(problem, temperature, progress=gr.Progress()):
635
  """
636
  yield token: string, stop: bool
637
  """
638
+ # Apply the system prompt template
639
+ problem_formatted = config.system_prompt.format(problem)
640
+ print(f"Problem: {problem_formatted}")
641
 
642
  sample = {
643
+ "problem": problem_formatted,
644
+ "ground_truth": "unknown",
645
  "text": "## Solution:\n",
646
+ "gen_texts": "## Solution:\n",
647
  "should_prune": False,
648
+ "problem_index": -1,
649
  "model_answers": "-1",
650
  "has_code": True,
651
+ "corrects": False,
652
  }
653
 
654
  for step in progress.tqdm(
655
  range(config.num_generations), desc="Generating candidates"
656
+ ):
657
 
658
+ step_response = sample["gen_texts"]
659
 
660
  messages = [
661
+ {"role": "system", "content": config.system_prompt.format(problem)},
662
+ {"role": "user", "content": sample["gen_texts"]},
663
  ]
664
 
665
+ for response_message, error in generate(messages, temperature):
666
+ if response_message:
667
+ step_response += response_message
668
+ yield preprocess_output(step_response), False
 
669
  if error:
670
+ yield step_response, True
671
  return
672
 
673
+ sample["gen_texts"] = step_response
674
 
675
+ # Process the generated code
676
  sample = process_code(
677
  sample,
678
  config=config,
 
681
  )
682
  sample["gen_texts"] = sample["gen_texts"] + "\n"
683
 
684
+ # Extract any run code response
685
+ run_code_response = sample["gen_texts"].replace(step_response, "")
686
 
687
+ # Append the run code response if it exists
688
+ if run_code_response.strip():
689
+ step_response += run_code_response
690
+ yield preprocess_output(step_response), False
691
 
692
  if sample["should_prune"]:
693
  break
694
 
695
  yield sample["gen_texts"], True
696
 
697
+ # Load the dataset
698
  example_data = datasets.load_dataset(
699
  "AI-MO/kaggle-validation-set-medium-extended",
700
  split="train",
701
  use_auth_token=os.environ.get("HF_DATASET_TOKEN", None),
702
  )
703
 
704
+ # Load CSS if available
705
+ css = ""
706
+ if os.path.exists("app.css"):
707
+ with open("app.css", "r") as f:
708
+ css = f.read()
709
 
710
  latex_delimiters = [
711
  {"left": "[", "right": "]", "display": True},
712
  ]
713
 
 
714
  def get_random_problem():
715
  example = random.choice(list(example_data))
716
  problem = example["problem"]
717
  return problem
718
 
 
719
  def update_example_problem():
720
  problem_example_text = get_random_problem()
721
  return problem_example_text, problem_example_text
722
 
 
723
  def clear():
724
  problem_example_text = get_random_problem()
725
  return "", 0.8, "", problem_example_text, problem_example_text
726
 
 
727
  def preprocess_output(text):
728
  return text.replace(r"\(", r"\\(").replace(r"\)", r"\\)")
729
 
 
730
  with gr.Blocks(css=css, title="Math Olympiad Solver") as demo:
731
  running_done = False
732
  btn_list = []
 
739
 
740
  with gr.Row(elem_classes="sub-title"):
741
  gr.HTML(
742
+ "<div>Demo of the <a href='https://huggingface.co/AI-MO/qwen2-7b-math-q8_0'>qwen2-7b-math-q8_0</a>. Example data are drawn randomly from AMC12, year 2022-2023.</div>",
743
  elem_classes="sub-title-content",
744
  )
745
 
746
  with gr.Row(elem_classes="main-area"):
747
  with gr.Column(scale=1, elem_classes="left"):
748
+ with gr.Row(elem_classes="problem-example-container"):
749
+ with gr.Blocks(elem_classes="problem-example-title"):
750
+ gr.HTML("Problem Example", elem_classes="problem-example-title-content")
751
 
752
  with gr.Blocks(elem_classes="action-container"):
753
  another_btn = gr.Button(
754
+ "Another Problem",
755
+ elem_classes="problem-example-another",
756
+ # Removed icon path to prevent errors
757
  )
758
+ copy_btn = gr.Button("Copy", elem_classes="problem-example-copy")
759
 
760
  problem_example = gr.HTML(
761
  problem_example_text,
762
+ elem_classes="problem-example-content",
763
  )
764
 
765
+ with gr.Row(elem_classes="problem-input-container"):
766
+ inp = gr.Textbox(placeholder="Enter your problem here...", label="Problem Input", lines=5)
767
  problem_markdown = gr.Markdown(
768
  visible=False,
769
  latex_delimiters=[
 
774
  )
775
 
776
  inp.change(fn=lambda text: text, inputs=[inp], outputs=[problem_markdown])
777
+ problem_input_ele_list.extend([inp, problem_markdown])
 
778
 
779
  with gr.Accordion("Advanced Options", open=False):
780
+ temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, step=0.1, label="Temperature")
781
 
782
  with gr.Row() as btn_area:
783
  btn_clear = gr.Button("Clear", elem_classes="clear-btn")
784
  btn_run = gr.Button("Run", elem_classes="run-btn")
785
+ btn_list.extend([btn_clear, btn_run])
 
786
 
787
  with gr.Column(scale=1, elem_classes="right"):
788
  gr.HTML("Solution", elem_classes="solution-title-content")
 
807
  running_done = True
808
  except Exception as e:
809
  running_done = True
810
+ yield str(e)
811
 
812
  def mount_run_btn(btn):
813
+ btn.click(fn=solve_problem_wrapper, inputs=[inp, temperature_slider], outputs=out)
814
  btn.click(get_running_btns, None, outputs=btn_list)
815
  btn.click(get_run_after_problem_input, None, outputs=problem_input_ele_list)
816
 
817
  def get_run_after_problem_input():
818
+ return gr.Textbox(placeholder="Enter your problem here...", label="Problem Input", lines=5, visible=False), gr.Markdown(
819
  visible=True,
820
  latex_delimiters=[
821
  {"left": "[", "right": "]", "display": True},
 
825
  )
826
 
827
  def get_init_problem_input():
828
+ return gr.Textbox(placeholder="Enter your problem here...", label="Problem Input", lines=5, visible=True), gr.Markdown(
829
  visible=False,
830
  latex_delimiters=[
831
  {"left": "[", "right": "]", "display": True},
 
855
 
856
  time.sleep(1)
857
 
858
+ copy_btn.click(fn=lambda example: example, inputs=[problem_example_text_hidden], outputs=[inp])
859
 
860
  btn_clear.click(
861
  fn=clear,
862
  inputs=[],
863
  outputs=[
864
  inp,
865
+ temperature_slider,
866
  out,
867
  problem_example,
868
  problem_example_text_hidden,
 
892
  )
893
 
894
  if __name__ == "__main__":
895
+ demo.queue(default_concurrency_limit=5).launch(share=True)