Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import gradio as gr
|
2 |
-
|
3 |
from dataclasses import dataclass
|
4 |
from concurrent.futures import ThreadPoolExecutor, TimeoutError
|
5 |
|
@@ -14,16 +13,13 @@ import time
|
|
14 |
from typing import Tuple, Dict, Any, List
|
15 |
from sympy import N, simplify
|
16 |
from sympy.parsing.latex import parse_latex
|
17 |
-
|
18 |
|
19 |
import base64
|
20 |
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
api_key=os.environ.get("HF_TOKEN"),
|
25 |
-
)
|
26 |
-
|
27 |
|
28 |
@dataclass
|
29 |
class Config:
|
@@ -59,7 +55,6 @@ class Config:
|
|
59 |
# Push solutions to the Hub
|
60 |
push_to_hub: bool = False
|
61 |
|
62 |
-
|
63 |
class PythonREPL:
|
64 |
def __init__(self, timeout=5):
|
65 |
self.timeout = timeout
|
@@ -79,13 +74,16 @@ class PythonREPL:
|
|
79 |
with open(temp_file_path, "w") as f:
|
80 |
f.write(query)
|
81 |
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
89 |
|
90 |
if result.returncode == 0:
|
91 |
output = result.stdout
|
@@ -121,17 +119,16 @@ class PythonREPL:
|
|
121 |
except TimeoutError:
|
122 |
return False, f"Timed out after {self.timeout} seconds."
|
123 |
|
124 |
-
|
125 |
def execute_completion(
|
126 |
executor: PythonREPL,
|
127 |
completion: str,
|
128 |
return_status: bool = False,
|
129 |
last_code_block: bool = False,
|
130 |
) -> str | Tuple[str, bool]:
|
131 |
-
#
|
132 |
executions = re.findall(r"```python(.*?)```", completion, re.DOTALL)
|
133 |
|
134 |
-
if len(executions) == 0: # directly return
|
135 |
return completion, False if return_status else completion
|
136 |
else:
|
137 |
if last_code_block:
|
@@ -159,7 +156,7 @@ def execute_completion(
|
|
159 |
success, output = executor(code)
|
160 |
except TimeoutError as e:
|
161 |
print("time out")
|
162 |
-
output = e
|
163 |
|
164 |
if not success and not return_status:
|
165 |
output = ""
|
@@ -175,7 +172,6 @@ def execute_completion(
|
|
175 |
else:
|
176 |
return output
|
177 |
|
178 |
-
|
179 |
def postprocess_completion(
|
180 |
text: str, return_status: bool = False, last_code_block=False, timeout=5
|
181 |
) -> str | Tuple[str, bool]:
|
@@ -186,11 +182,9 @@ def postprocess_completion(
|
|
186 |
|
187 |
return result
|
188 |
|
189 |
-
|
190 |
def apply_template(example: Dict[str, Any], prompt: str) -> Dict[str, Any]:
|
191 |
return prompt.format(example["prompt"], "{}")
|
192 |
|
193 |
-
|
194 |
def last_boxed_only_string(string):
|
195 |
"""
|
196 |
Extracts the last LaTeX boxed or framed expression from a string.
|
@@ -227,7 +221,6 @@ def last_boxed_only_string(string):
|
|
227 |
|
228 |
return retval
|
229 |
|
230 |
-
|
231 |
def remove_boxed(s):
|
232 |
"""
|
233 |
Removes the LaTeX boxed command, returning the content inside the braces.
|
@@ -247,7 +240,6 @@ def remove_boxed(s):
|
|
247 |
except Exception:
|
248 |
return None
|
249 |
|
250 |
-
|
251 |
def extract_boxed_answer(pred_str, strip_double_curly_brace=False):
|
252 |
"""
|
253 |
Extracts the answer from a LaTeX boxed expression within
|
@@ -268,12 +260,11 @@ def extract_boxed_answer(pred_str, strip_double_curly_brace=False):
|
|
268 |
if answer is None:
|
269 |
return None
|
270 |
if strip_double_curly_brace:
|
271 |
-
match = re.match("^\{(.*)\}$", answer) # noqa: W605
|
272 |
if match:
|
273 |
answer = match.group(1)
|
274 |
return answer
|
275 |
|
276 |
-
|
277 |
def normalize_final_answer(final_answer: str) -> str:
|
278 |
"""
|
279 |
Normalizes a final answer string by removing or replacing various LaTeX
|
@@ -286,9 +277,9 @@ def normalize_final_answer(final_answer: str) -> str:
|
|
286 |
|
287 |
match = re.search(r"(.*?)Problem:", final_answer, flags=re.S)
|
288 |
if match:
|
289 |
-
final_answer = match.group(1) #
|
|
|
290 |
"""Normalize a final answer to a quantitative reasoning question."""
|
291 |
-
# final_answer = final_answer.split('=')[-1]
|
292 |
SUBSTITUTIONS = [
|
293 |
("an ", ""),
|
294 |
("a ", ""),
|
@@ -379,8 +370,8 @@ def normalize_final_answer(final_answer: str) -> str:
|
|
379 |
if "rac" in final_answer and "\\frac" not in final_answer:
|
380 |
final_answer = final_answer.replace("rac", "\\frac")
|
381 |
|
382 |
-
final_answer = re.sub(r"(frac)([^{])(.)", "frac{
|
383 |
-
final_answer = re.sub(r"(sqrt)([^{])", "sqrt{
|
384 |
final_answer = final_answer.replace("$", "")
|
385 |
|
386 |
if final_answer.replace(",", "").isdigit():
|
@@ -388,18 +379,14 @@ def normalize_final_answer(final_answer: str) -> str:
|
|
388 |
|
389 |
return final_answer
|
390 |
|
391 |
-
|
392 |
def naive_parse(answer: str) -> str:
|
393 |
"""
|
394 |
Extracts and returns the numeric digits from the input string, processing them in reverse order
|
395 |
until a non-numeric character is encountered after encountering the first numeric character.
|
396 |
-
|
397 |
Args:
|
398 |
answer (str): The input string to parse.
|
399 |
-
|
400 |
Returns:
|
401 |
str: A string consisting of the numeric digits extracted from the input, in their original order.
|
402 |
-
|
403 |
Example:
|
404 |
>>> naive_parse("abc123def")
|
405 |
'123'
|
@@ -422,7 +409,6 @@ def naive_parse(answer: str) -> str:
|
|
422 |
out = reversed(out)
|
423 |
return "".join(out)
|
424 |
|
425 |
-
|
426 |
def validate_answer_is_numeric(x: str | int | float) -> int:
|
427 |
FLOAT_TOLERANCE = 0.2
|
428 |
try:
|
@@ -434,7 +420,6 @@ def validate_answer_is_numeric(x: str | int | float) -> int:
|
|
434 |
x = -1
|
435 |
return x
|
436 |
|
437 |
-
|
438 |
def filter_answers(answers: List[str]) -> List[int]:
|
439 |
formatted_answers = [validate_answer_is_numeric(a) for a in answers]
|
440 |
|
@@ -446,13 +431,12 @@ def filter_answers(answers: List[str]) -> List[int]:
|
|
446 |
formatted_answers = [a for a in formatted_answers if a <= 999]
|
447 |
return formatted_answers
|
448 |
|
449 |
-
|
450 |
def check_sympy_equivalence(ref_answer: str, model_answer: str) -> bool:
|
451 |
def do_answers_match(ref_answer: str, model_answer: str) -> bool:
|
452 |
ref_sympy = parse_latex(ref_answer)
|
453 |
model_sympy = parse_latex(model_answer)
|
454 |
diff = simplify(ref_sympy - model_sympy)
|
455 |
-
return True if -1e-12 < N(diff) < 1e-12 or diff.is_zero else False
|
456 |
|
457 |
try:
|
458 |
result = do_answers_match(ref_answer, model_answer)
|
@@ -461,7 +445,6 @@ def check_sympy_equivalence(ref_answer: str, model_answer: str) -> bool:
|
|
461 |
print(e)
|
462 |
return False
|
463 |
|
464 |
-
|
465 |
def check_string_match(ref_answer: str, model_answer: str) -> bool:
|
466 |
try:
|
467 |
return ref_answer == model_answer
|
@@ -469,7 +452,6 @@ def check_string_match(ref_answer: str, model_answer: str) -> bool:
|
|
469 |
print(e)
|
470 |
return False
|
471 |
|
472 |
-
|
473 |
def check_answer(ref_answer: str, model_answer: str) -> bool:
|
474 |
# check if strings are the same
|
475 |
correct = check_string_match(ref_answer, model_answer)
|
@@ -483,9 +465,9 @@ def check_answer(ref_answer: str, model_answer: str) -> bool:
|
|
483 |
|
484 |
return False
|
485 |
|
486 |
-
|
487 |
debug = False
|
488 |
-
model_id = "
|
489 |
revision = "main"
|
490 |
system_prompt = "{}"
|
491 |
validation_set = "kaggle-validation-set-medium"
|
@@ -522,52 +504,43 @@ config = Config(
|
|
522 |
)
|
523 |
print(f"=== Running submission with config ===\n\n{config}")
|
524 |
|
525 |
-
|
526 |
-
def generate(message, temperature):
|
527 |
"""
|
528 |
Generates a chat completion response by streaming data from the client chat model.
|
529 |
-
|
530 |
This function streams the response from the client chat model and yields the content
|
531 |
of the response chunk by chunk. If an error occurs, it yields the error message.
|
532 |
-
|
533 |
Parameters:
|
534 |
-
|
535 |
-
temperature (float): The sampling temperature to use.
|
536 |
-
|
537 |
Yields:
|
538 |
tuple: A tuple containing the content of the response and a boolean flag indicating if an error occurred.
|
539 |
If no error occurred, the boolean flag will be False and the content will be the response text.
|
540 |
If an error occurred, the boolean flag will be True and the content will be the error message.
|
541 |
"""
|
542 |
-
|
543 |
-
|
544 |
-
|
545 |
-
|
546 |
-
|
547 |
-
|
548 |
-
|
549 |
-
|
550 |
-
|
551 |
-
|
552 |
-
|
553 |
-
|
554 |
-
|
555 |
-
|
556 |
-
|
557 |
-
|
558 |
-
|
559 |
-
|
560 |
-
|
561 |
-
|
562 |
break
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
yield content, False
|
567 |
-
except Exception as e:
|
568 |
-
print(f"func: generate error occurred\njson:{chune_json}\nerror:{e}")
|
569 |
-
yield "", True
|
570 |
-
|
571 |
|
572 |
def get_majority_text(data):
|
573 |
from collections import Counter
|
@@ -584,7 +557,6 @@ def get_majority_text(data):
|
|
584 |
# Return the corresponding text in gen_texts
|
585 |
return data["gen_texts"][majority_index]
|
586 |
|
587 |
-
|
588 |
def extract_solution(text):
|
589 |
# Split the text at "### Solution:"
|
590 |
parts = text.split("### Solution:", 1)
|
@@ -595,7 +567,6 @@ def extract_solution(text):
|
|
595 |
# Return an empty string if "### Solution:" is not found
|
596 |
return ""
|
597 |
|
598 |
-
|
599 |
def process_code(
|
600 |
example: Dict[str, Any],
|
601 |
config: Config,
|
@@ -607,19 +578,19 @@ def process_code(
|
|
607 |
|
608 |
if num_python_blocks == 0:
|
609 |
if restart_on_fail:
|
610 |
-
print("
|
611 |
-
#
|
612 |
-
example["gen_texts"] =
|
613 |
else:
|
614 |
-
print("
|
615 |
example["should_prune"] = True
|
616 |
example["has_code"] = False
|
617 |
return example
|
618 |
|
619 |
-
if gen_text
|
620 |
num_output_blocks = len(re.findall(r"```output(.*?)```", gen_text, re.DOTALL))
|
621 |
if num_output_blocks == 0:
|
622 |
-
print("
|
623 |
example["should_prune"] = True
|
624 |
return example
|
625 |
|
@@ -639,70 +610,69 @@ def process_code(
|
|
639 |
return example
|
640 |
|
641 |
if last_step:
|
642 |
-
#
|
643 |
return example
|
644 |
|
645 |
-
if gen_text
|
646 |
-
#
|
647 |
-
print("
|
648 |
if restart_on_fail:
|
649 |
-
example["gen_texts"] =
|
650 |
else:
|
651 |
example["should_prune"] = True
|
652 |
return example
|
653 |
|
654 |
code_result, status = postprocess_completion(gen_text, return_status=True, last_code_block=True)
|
655 |
-
#
|
656 |
TRUNCATION_LIMIT = 200
|
657 |
if len(code_result) > TRUNCATION_LIMIT:
|
658 |
code_result = code_result[:TRUNCATION_LIMIT] + " ... (output truncated)"
|
659 |
-
example["gen_texts"] = gen_text + f"{code_result}\n
|
660 |
|
661 |
return example
|
662 |
|
663 |
-
|
664 |
def solve_problem(problem, temperature, progress=gr.Progress()):
|
665 |
"""
|
666 |
yield token: string, stop: bool
|
667 |
"""
|
668 |
-
|
669 |
-
|
|
|
670 |
|
671 |
sample = {
|
672 |
-
"problem":
|
673 |
-
"ground_truth": "unknown",
|
674 |
"text": "## Solution:\n",
|
675 |
-
"gen_texts": "## Solution:\n",
|
676 |
"should_prune": False,
|
677 |
-
"problem_index": -1,
|
678 |
"model_answers": "-1",
|
679 |
"has_code": True,
|
680 |
-
"corrects": False,
|
681 |
}
|
682 |
|
683 |
for step in progress.tqdm(
|
684 |
range(config.num_generations), desc="Generating candidates"
|
685 |
-
):
|
686 |
|
687 |
-
|
688 |
|
689 |
messages = [
|
690 |
-
{"role": "
|
691 |
-
{"role": "
|
692 |
]
|
693 |
|
694 |
-
for
|
695 |
-
if
|
696 |
-
|
697 |
-
yield
|
698 |
-
|
699 |
if error:
|
700 |
-
yield
|
701 |
return
|
702 |
|
703 |
-
sample["gen_texts"] =
|
704 |
|
705 |
-
#
|
706 |
sample = process_code(
|
707 |
sample,
|
708 |
config=config,
|
@@ -711,55 +681,52 @@ def solve_problem(problem, temperature, progress=gr.Progress()):
|
|
711 |
)
|
712 |
sample["gen_texts"] = sample["gen_texts"] + "\n"
|
713 |
|
714 |
-
|
|
|
715 |
|
716 |
-
|
717 |
-
|
718 |
-
|
719 |
-
|
720 |
|
721 |
if sample["should_prune"]:
|
722 |
break
|
723 |
|
724 |
yield sample["gen_texts"], True
|
725 |
|
726 |
-
|
727 |
example_data = datasets.load_dataset(
|
728 |
"AI-MO/kaggle-validation-set-medium-extended",
|
729 |
split="train",
|
730 |
use_auth_token=os.environ.get("HF_DATASET_TOKEN", None),
|
731 |
)
|
732 |
|
733 |
-
|
734 |
-
|
735 |
-
|
736 |
-
|
|
|
737 |
|
738 |
latex_delimiters = [
|
739 |
{"left": "[", "right": "]", "display": True},
|
740 |
]
|
741 |
|
742 |
-
|
743 |
def get_random_problem():
|
744 |
example = random.choice(list(example_data))
|
745 |
problem = example["problem"]
|
746 |
return problem
|
747 |
|
748 |
-
|
749 |
def update_example_problem():
|
750 |
problem_example_text = get_random_problem()
|
751 |
return problem_example_text, problem_example_text
|
752 |
|
753 |
-
|
754 |
def clear():
|
755 |
problem_example_text = get_random_problem()
|
756 |
return "", 0.1, "", problem_example_text, problem_example_text
|
757 |
|
758 |
-
|
759 |
def preprocess_output(text):
|
760 |
return text.replace(r"\(", r"\\(").replace(r"\)", r"\\)")
|
761 |
|
762 |
-
|
763 |
with gr.Blocks(css=css, title="Math Olympiad Solver") as demo:
|
764 |
running_done = False
|
765 |
btn_list = []
|
@@ -772,31 +739,31 @@ with gr.Blocks(css=css, title="Math Olympiad Solver") as demo:
|
|
772 |
|
773 |
with gr.Row(elem_classes="sub-title"):
|
774 |
gr.HTML(
|
775 |
-
"<div>Demo of the <a href='https://huggingface.co/AI-MO/
|
776 |
elem_classes="sub-title-content",
|
777 |
)
|
778 |
|
779 |
with gr.Row(elem_classes="main-area"):
|
780 |
with gr.Column(scale=1, elem_classes="left"):
|
781 |
-
with gr.Row(elem_classes="
|
782 |
-
with gr.Blocks(elem_classes="
|
783 |
-
gr.HTML("Problem
|
784 |
|
785 |
with gr.Blocks(elem_classes="action-container"):
|
786 |
another_btn = gr.Button(
|
787 |
-
"",
|
788 |
-
elem_classes="
|
789 |
-
icon
|
790 |
)
|
791 |
-
copy_btn = gr.Button("Copy", elem_classes="
|
792 |
|
793 |
problem_example = gr.HTML(
|
794 |
problem_example_text,
|
795 |
-
elem_classes="
|
796 |
)
|
797 |
|
798 |
-
with gr.Row(elem_classes="
|
799 |
-
inp = gr.Textbox(placeholder="
|
800 |
problem_markdown = gr.Markdown(
|
801 |
visible=False,
|
802 |
latex_delimiters=[
|
@@ -807,17 +774,15 @@ with gr.Blocks(css=css, title="Math Olympiad Solver") as demo:
|
|
807 |
)
|
808 |
|
809 |
inp.change(fn=lambda text: text, inputs=[inp], outputs=[problem_markdown])
|
810 |
-
problem_input_ele_list.
|
811 |
-
problem_input_ele_list.append(problem_markdown)
|
812 |
|
813 |
with gr.Accordion("Advanced Options", open=False):
|
814 |
-
|
815 |
|
816 |
with gr.Row() as btn_area:
|
817 |
btn_clear = gr.Button("Clear", elem_classes="clear-btn")
|
818 |
btn_run = gr.Button("Run", elem_classes="run-btn")
|
819 |
-
btn_list.
|
820 |
-
btn_list.append(btn_run)
|
821 |
|
822 |
with gr.Column(scale=1, elem_classes="right"):
|
823 |
gr.HTML("Solution", elem_classes="solution-title-content")
|
@@ -842,15 +807,15 @@ with gr.Blocks(css=css, title="Math Olympiad Solver") as demo:
|
|
842 |
running_done = True
|
843 |
except Exception as e:
|
844 |
running_done = True
|
845 |
-
|
846 |
|
847 |
def mount_run_btn(btn):
|
848 |
-
btn.click(fn=solve_problem_wrapper, inputs=[inp,
|
849 |
btn.click(get_running_btns, None, outputs=btn_list)
|
850 |
btn.click(get_run_after_problem_input, None, outputs=problem_input_ele_list)
|
851 |
|
852 |
def get_run_after_problem_input():
|
853 |
-
return gr.Textbox(placeholder="
|
854 |
visible=True,
|
855 |
latex_delimiters=[
|
856 |
{"left": "[", "right": "]", "display": True},
|
@@ -860,7 +825,7 @@ with gr.Blocks(css=css, title="Math Olympiad Solver") as demo:
|
|
860 |
)
|
861 |
|
862 |
def get_init_problem_input():
|
863 |
-
return gr.Textbox(placeholder="
|
864 |
visible=False,
|
865 |
latex_delimiters=[
|
866 |
{"left": "[", "right": "]", "display": True},
|
@@ -890,14 +855,14 @@ with gr.Blocks(css=css, title="Math Olympiad Solver") as demo:
|
|
890 |
|
891 |
time.sleep(1)
|
892 |
|
893 |
-
copy_btn.click(fn=lambda
|
894 |
|
895 |
btn_clear.click(
|
896 |
fn=clear,
|
897 |
inputs=[],
|
898 |
outputs=[
|
899 |
inp,
|
900 |
-
|
901 |
out,
|
902 |
problem_example,
|
903 |
problem_example_text_hidden,
|
@@ -927,4 +892,4 @@ with gr.Blocks(css=css, title="Math Olympiad Solver") as demo:
|
|
927 |
)
|
928 |
|
929 |
if __name__ == "__main__":
|
930 |
-
demo.queue(default_concurrency_limit=5).launch()
|
|
|
1 |
import gradio as gr
|
|
|
2 |
from dataclasses import dataclass
|
3 |
from concurrent.futures import ThreadPoolExecutor, TimeoutError
|
4 |
|
|
|
13 |
from typing import Tuple, Dict, Any, List
|
14 |
from sympy import N, simplify
|
15 |
from sympy.parsing.latex import parse_latex
|
16 |
+
import openai
|
17 |
|
18 |
import base64
|
19 |
|
20 |
+
# Initialize OpenAI client to use local API
|
21 |
+
openai.api_base = os.environ.get("SERVER_URL", "http://0.0.0.0:6061")
|
22 |
+
openai.api_key = os.environ.get("HF_TOKEN", "") # If no key needed, set empty string
|
|
|
|
|
|
|
23 |
|
24 |
@dataclass
|
25 |
class Config:
|
|
|
55 |
# Push solutions to the Hub
|
56 |
push_to_hub: bool = False
|
57 |
|
|
|
58 |
class PythonREPL:
|
59 |
def __init__(self, timeout=5):
|
60 |
self.timeout = timeout
|
|
|
74 |
with open(temp_file_path, "w") as f:
|
75 |
f.write(query)
|
76 |
|
77 |
+
try:
|
78 |
+
result = subprocess.run(
|
79 |
+
["python3", temp_file_path],
|
80 |
+
capture_output=True,
|
81 |
+
check=False,
|
82 |
+
text=True,
|
83 |
+
timeout=self.timeout,
|
84 |
+
)
|
85 |
+
except subprocess.TimeoutExpired:
|
86 |
+
return False, f"Timed out after {self.timeout} seconds."
|
87 |
|
88 |
if result.returncode == 0:
|
89 |
output = result.stdout
|
|
|
119 |
except TimeoutError:
|
120 |
return False, f"Timed out after {self.timeout} seconds."
|
121 |
|
|
|
122 |
def execute_completion(
|
123 |
executor: PythonREPL,
|
124 |
completion: str,
|
125 |
return_status: bool = False,
|
126 |
last_code_block: bool = False,
|
127 |
) -> str | Tuple[str, bool]:
|
128 |
+
# Extract python code blocks enclosed in triple backticks with language 'python'
|
129 |
executions = re.findall(r"```python(.*?)```", completion, re.DOTALL)
|
130 |
|
131 |
+
if len(executions) == 0: # directly return COT result
|
132 |
return completion, False if return_status else completion
|
133 |
else:
|
134 |
if last_code_block:
|
|
|
156 |
success, output = executor(code)
|
157 |
except TimeoutError as e:
|
158 |
print("time out")
|
159 |
+
output = str(e)
|
160 |
|
161 |
if not success and not return_status:
|
162 |
output = ""
|
|
|
172 |
else:
|
173 |
return output
|
174 |
|
|
|
175 |
def postprocess_completion(
|
176 |
text: str, return_status: bool = False, last_code_block=False, timeout=5
|
177 |
) -> str | Tuple[str, bool]:
|
|
|
182 |
|
183 |
return result
|
184 |
|
|
|
185 |
def apply_template(example: Dict[str, Any], prompt: str) -> Dict[str, Any]:
|
186 |
return prompt.format(example["prompt"], "{}")
|
187 |
|
|
|
188 |
def last_boxed_only_string(string):
|
189 |
"""
|
190 |
Extracts the last LaTeX boxed or framed expression from a string.
|
|
|
221 |
|
222 |
return retval
|
223 |
|
|
|
224 |
def remove_boxed(s):
|
225 |
"""
|
226 |
Removes the LaTeX boxed command, returning the content inside the braces.
|
|
|
240 |
except Exception:
|
241 |
return None
|
242 |
|
|
|
243 |
def extract_boxed_answer(pred_str, strip_double_curly_brace=False):
|
244 |
"""
|
245 |
Extracts the answer from a LaTeX boxed expression within
|
|
|
260 |
if answer is None:
|
261 |
return None
|
262 |
if strip_double_curly_brace:
|
263 |
+
match = re.match(r"^\{(.*)\}$", answer) # noqa: W605
|
264 |
if match:
|
265 |
answer = match.group(1)
|
266 |
return answer
|
267 |
|
|
|
268 |
def normalize_final_answer(final_answer: str) -> str:
|
269 |
"""
|
270 |
Normalizes a final answer string by removing or replacing various LaTeX
|
|
|
277 |
|
278 |
match = re.search(r"(.*?)Problem:", final_answer, flags=re.S)
|
279 |
if match:
|
280 |
+
final_answer = match.group(1) # Return all text before 'Problem'
|
281 |
+
|
282 |
"""Normalize a final answer to a quantitative reasoning question."""
|
|
|
283 |
SUBSTITUTIONS = [
|
284 |
("an ", ""),
|
285 |
("a ", ""),
|
|
|
370 |
if "rac" in final_answer and "\\frac" not in final_answer:
|
371 |
final_answer = final_answer.replace("rac", "\\frac")
|
372 |
|
373 |
+
final_answer = re.sub(r"(frac)([^{])(.)", r"frac{\2}{\3}", final_answer)
|
374 |
+
final_answer = re.sub(r"(sqrt)([^{])", r"sqrt{\2}", final_answer)
|
375 |
final_answer = final_answer.replace("$", "")
|
376 |
|
377 |
if final_answer.replace(",", "").isdigit():
|
|
|
379 |
|
380 |
return final_answer
|
381 |
|
|
|
382 |
def naive_parse(answer: str) -> str:
|
383 |
"""
|
384 |
Extracts and returns the numeric digits from the input string, processing them in reverse order
|
385 |
until a non-numeric character is encountered after encountering the first numeric character.
|
|
|
386 |
Args:
|
387 |
answer (str): The input string to parse.
|
|
|
388 |
Returns:
|
389 |
str: A string consisting of the numeric digits extracted from the input, in their original order.
|
|
|
390 |
Example:
|
391 |
>>> naive_parse("abc123def")
|
392 |
'123'
|
|
|
409 |
out = reversed(out)
|
410 |
return "".join(out)
|
411 |
|
|
|
412 |
def validate_answer_is_numeric(x: str | int | float) -> int:
|
413 |
FLOAT_TOLERANCE = 0.2
|
414 |
try:
|
|
|
420 |
x = -1
|
421 |
return x
|
422 |
|
|
|
423 |
def filter_answers(answers: List[str]) -> List[int]:
|
424 |
formatted_answers = [validate_answer_is_numeric(a) for a in answers]
|
425 |
|
|
|
431 |
formatted_answers = [a for a in formatted_answers if a <= 999]
|
432 |
return formatted_answers
|
433 |
|
|
|
434 |
def check_sympy_equivalence(ref_answer: str, model_answer: str) -> bool:
|
435 |
def do_answers_match(ref_answer: str, model_answer: str) -> bool:
|
436 |
ref_sympy = parse_latex(ref_answer)
|
437 |
model_sympy = parse_latex(model_answer)
|
438 |
diff = simplify(ref_sympy - model_sympy)
|
439 |
+
return True if (-1e-12 < N(diff) < 1e-12) or diff.is_zero else False
|
440 |
|
441 |
try:
|
442 |
result = do_answers_match(ref_answer, model_answer)
|
|
|
445 |
print(e)
|
446 |
return False
|
447 |
|
|
|
448 |
def check_string_match(ref_answer: str, model_answer: str) -> bool:
|
449 |
try:
|
450 |
return ref_answer == model_answer
|
|
|
452 |
print(e)
|
453 |
return False
|
454 |
|
|
|
455 |
def check_answer(ref_answer: str, model_answer: str) -> bool:
|
456 |
# check if strings are the same
|
457 |
correct = check_string_match(ref_answer, model_answer)
|
|
|
465 |
|
466 |
return False
|
467 |
|
468 |
+
# Configuration Parameters
|
469 |
debug = False
|
470 |
+
model_id = "qwen2-7b-math-q8_0" # Update model ID
|
471 |
revision = "main"
|
472 |
system_prompt = "{}"
|
473 |
validation_set = "kaggle-validation-set-medium"
|
|
|
504 |
)
|
505 |
print(f"=== Running submission with config ===\n\n{config}")
|
506 |
|
507 |
+
def generate(messages, temperature):
|
|
|
508 |
"""
|
509 |
Generates a chat completion response by streaming data from the client chat model.
|
|
|
510 |
This function streams the response from the client chat model and yields the content
|
511 |
of the response chunk by chunk. If an error occurs, it yields the error message.
|
|
|
512 |
Parameters:
|
513 |
+
messages (list of dict): The list of message dicts for the chat model.
|
514 |
+
temperature (float): The sampling temperature to use.
|
|
|
515 |
Yields:
|
516 |
tuple: A tuple containing the content of the response and a boolean flag indicating if an error occurred.
|
517 |
If no error occurred, the boolean flag will be False and the content will be the response text.
|
518 |
If an error occurred, the boolean flag will be True and the content will be the error message.
|
519 |
"""
|
520 |
+
try:
|
521 |
+
response = openai.ChatCompletion.create(
|
522 |
+
model=config.model_id,
|
523 |
+
messages=messages,
|
524 |
+
stream=True,
|
525 |
+
max_tokens=1024,
|
526 |
+
temperature=temperature,
|
527 |
+
)
|
528 |
+
except Exception as e:
|
529 |
+
yield str(e), True
|
530 |
+
return
|
531 |
+
|
532 |
+
for chunk in response:
|
533 |
+
if 'choices' in chunk:
|
534 |
+
choice = chunk['choices'][0]
|
535 |
+
if 'delta' in choice:
|
536 |
+
content = choice['delta'].get('content', '')
|
537 |
+
if content:
|
538 |
+
yield content, False
|
539 |
+
if choice.get('finish_reason') is not None:
|
540 |
break
|
541 |
+
elif 'error' in chunk:
|
542 |
+
yield chunk['error']['message'], True
|
543 |
+
break
|
|
|
|
|
|
|
|
|
|
|
544 |
|
545 |
def get_majority_text(data):
|
546 |
from collections import Counter
|
|
|
557 |
# Return the corresponding text in gen_texts
|
558 |
return data["gen_texts"][majority_index]
|
559 |
|
|
|
560 |
def extract_solution(text):
|
561 |
# Split the text at "### Solution:"
|
562 |
parts = text.split("### Solution:", 1)
|
|
|
567 |
# Return an empty string if "### Solution:" is not found
|
568 |
return ""
|
569 |
|
|
|
570 |
def process_code(
|
571 |
example: Dict[str, Any],
|
572 |
config: Config,
|
|
|
578 |
|
579 |
if num_python_blocks == 0:
|
580 |
if restart_on_fail:
|
581 |
+
print("No code has been generated. Restarting generation.")
|
582 |
+
# Reset the text to the original
|
583 |
+
example["gen_texts"] = "## Solution:\n"
|
584 |
else:
|
585 |
+
print("No code has been generated. Stopping.")
|
586 |
example["should_prune"] = True
|
587 |
example["has_code"] = False
|
588 |
return example
|
589 |
|
590 |
+
if not gen_text.endswith("```output\n") and ("answer is" in gen_text[-100:] or "\\boxed" in gen_text[-100:]):
|
591 |
num_output_blocks = len(re.findall(r"```output(.*?)```", gen_text, re.DOTALL))
|
592 |
if num_output_blocks == 0:
|
593 |
+
print("The model hallucinated the code answer.")
|
594 |
example["should_prune"] = True
|
595 |
return example
|
596 |
|
|
|
610 |
return example
|
611 |
|
612 |
if last_step:
|
613 |
+
# No point in continuing if we are at the last step
|
614 |
return example
|
615 |
|
616 |
+
if not gen_text.endswith("```output\n"):
|
617 |
+
# Something else has gone wrong with the generation
|
618 |
+
print("Warning: Output block not found: ", gen_text[-40:])
|
619 |
if restart_on_fail:
|
620 |
+
example["gen_texts"] = "## Solution:\n"
|
621 |
else:
|
622 |
example["should_prune"] = True
|
623 |
return example
|
624 |
|
625 |
code_result, status = postprocess_completion(gen_text, return_status=True, last_code_block=True)
|
626 |
+
# Add the code result for the next round of generation
|
627 |
TRUNCATION_LIMIT = 200
|
628 |
if len(code_result) > TRUNCATION_LIMIT:
|
629 |
code_result = code_result[:TRUNCATION_LIMIT] + " ... (output truncated)"
|
630 |
+
example["gen_texts"] = gen_text + f"```\n{code_result}\n```\n"
|
631 |
|
632 |
return example
|
633 |
|
|
|
634 |
def solve_problem(problem, temperature, progress=gr.Progress()):
|
635 |
"""
|
636 |
yield token: string, stop: bool
|
637 |
"""
|
638 |
+
# Apply the system prompt template
|
639 |
+
problem_formatted = config.system_prompt.format(problem)
|
640 |
+
print(f"Problem: {problem_formatted}")
|
641 |
|
642 |
sample = {
|
643 |
+
"problem": problem_formatted,
|
644 |
+
"ground_truth": "unknown",
|
645 |
"text": "## Solution:\n",
|
646 |
+
"gen_texts": "## Solution:\n",
|
647 |
"should_prune": False,
|
648 |
+
"problem_index": -1,
|
649 |
"model_answers": "-1",
|
650 |
"has_code": True,
|
651 |
+
"corrects": False,
|
652 |
}
|
653 |
|
654 |
for step in progress.tqdm(
|
655 |
range(config.num_generations), desc="Generating candidates"
|
656 |
+
):
|
657 |
|
658 |
+
step_response = sample["gen_texts"]
|
659 |
|
660 |
messages = [
|
661 |
+
{"role": "system", "content": config.system_prompt.format(problem)},
|
662 |
+
{"role": "user", "content": sample["gen_texts"]},
|
663 |
]
|
664 |
|
665 |
+
for response_message, error in generate(messages, temperature):
|
666 |
+
if response_message:
|
667 |
+
step_response += response_message
|
668 |
+
yield preprocess_output(step_response)
|
|
|
669 |
if error:
|
670 |
+
yield step_response, True
|
671 |
return
|
672 |
|
673 |
+
sample["gen_texts"] = step_response
|
674 |
|
675 |
+
# Process the generated code
|
676 |
sample = process_code(
|
677 |
sample,
|
678 |
config=config,
|
|
|
681 |
)
|
682 |
sample["gen_texts"] = sample["gen_texts"] + "\n"
|
683 |
|
684 |
+
# Extract any run code response
|
685 |
+
run_code_response = sample["gen_texts"].replace(step_response, "")
|
686 |
|
687 |
+
# Append the run code response if it exists
|
688 |
+
if run_code_response.strip():
|
689 |
+
step_response += run_code_response
|
690 |
+
yield preprocess_output(run_code_response)
|
691 |
|
692 |
if sample["should_prune"]:
|
693 |
break
|
694 |
|
695 |
yield sample["gen_texts"], True
|
696 |
|
697 |
+
# Load the dataset
|
698 |
example_data = datasets.load_dataset(
|
699 |
"AI-MO/kaggle-validation-set-medium-extended",
|
700 |
split="train",
|
701 |
use_auth_token=os.environ.get("HF_DATASET_TOKEN", None),
|
702 |
)
|
703 |
|
704 |
+
# Load CSS if available
|
705 |
+
css = ""
|
706 |
+
if os.path.exists("app.css"):
|
707 |
+
with open("app.css", "r") as f:
|
708 |
+
css = f.read()
|
709 |
|
710 |
latex_delimiters = [
|
711 |
{"left": "[", "right": "]", "display": True},
|
712 |
]
|
713 |
|
|
|
714 |
def get_random_problem():
|
715 |
example = random.choice(list(example_data))
|
716 |
problem = example["problem"]
|
717 |
return problem
|
718 |
|
|
|
719 |
def update_example_problem():
|
720 |
problem_example_text = get_random_problem()
|
721 |
return problem_example_text, problem_example_text
|
722 |
|
|
|
723 |
def clear():
|
724 |
problem_example_text = get_random_problem()
|
725 |
return "", 0.1, "", problem_example_text, problem_example_text
|
726 |
|
|
|
727 |
def preprocess_output(text):
|
728 |
return text.replace(r"\(", r"\\(").replace(r"\)", r"\\)")
|
729 |
|
|
|
730 |
with gr.Blocks(css=css, title="Math Olympiad Solver") as demo:
|
731 |
running_done = False
|
732 |
btn_list = []
|
|
|
739 |
|
740 |
with gr.Row(elem_classes="sub-title"):
|
741 |
gr.HTML(
|
742 |
+
"<div>Demo of the <a href='https://huggingface.co/AI-MO/qwen2-7b-math-q8_0'>qwen2-7b-math-q8_0</a>. Example data are drawn randomly from AMC12, year 2022-2023.</div>",
|
743 |
elem_classes="sub-title-content",
|
744 |
)
|
745 |
|
746 |
with gr.Row(elem_classes="main-area"):
|
747 |
with gr.Column(scale=1, elem_classes="left"):
|
748 |
+
with gr.Row(elem_classes="problem-example-container"):
|
749 |
+
with gr.Blocks(elem_classes="problem-example-title"):
|
750 |
+
gr.HTML("Problem Example", elem_classes="problem-example-title-content")
|
751 |
|
752 |
with gr.Blocks(elem_classes="action-container"):
|
753 |
another_btn = gr.Button(
|
754 |
+
"Another Problem",
|
755 |
+
elem_classes="problem-example-another",
|
756 |
+
# Removed icon path to prevent errors
|
757 |
)
|
758 |
+
copy_btn = gr.Button("Copy", elem_classes="problem-example-copy")
|
759 |
|
760 |
problem_example = gr.HTML(
|
761 |
problem_example_text,
|
762 |
+
elem_classes="problem-example-content",
|
763 |
)
|
764 |
|
765 |
+
with gr.Row(elem_classes="problem-input-container"):
|
766 |
+
inp = gr.Textbox(placeholder="Enter your problem here...", label="Problem Input", lines=5)
|
767 |
problem_markdown = gr.Markdown(
|
768 |
visible=False,
|
769 |
latex_delimiters=[
|
|
|
774 |
)
|
775 |
|
776 |
inp.change(fn=lambda text: text, inputs=[inp], outputs=[problem_markdown])
|
777 |
+
problem_input_ele_list.extend([inp, problem_markdown])
|
|
|
778 |
|
779 |
with gr.Accordion("Advanced Options", open=False):
|
780 |
+
temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, step=0.1, label="Temperature")
|
781 |
|
782 |
with gr.Row() as btn_area:
|
783 |
btn_clear = gr.Button("Clear", elem_classes="clear-btn")
|
784 |
btn_run = gr.Button("Run", elem_classes="run-btn")
|
785 |
+
btn_list.extend([btn_clear, btn_run])
|
|
|
786 |
|
787 |
with gr.Column(scale=1, elem_classes="right"):
|
788 |
gr.HTML("Solution", elem_classes="solution-title-content")
|
|
|
807 |
running_done = True
|
808 |
except Exception as e:
|
809 |
running_done = True
|
810 |
+
yield str(e)
|
811 |
|
812 |
def mount_run_btn(btn):
|
813 |
+
btn.click(fn=solve_problem_wrapper, inputs=[inp, temperature_slider], outputs=out)
|
814 |
btn.click(get_running_btns, None, outputs=btn_list)
|
815 |
btn.click(get_run_after_problem_input, None, outputs=problem_input_ele_list)
|
816 |
|
817 |
def get_run_after_problem_input():
|
818 |
+
return gr.Textbox(placeholder="Enter your problem here...", label="Problem Input", lines=5, visible=False), gr.Markdown(
|
819 |
visible=True,
|
820 |
latex_delimiters=[
|
821 |
{"left": "[", "right": "]", "display": True},
|
|
|
825 |
)
|
826 |
|
827 |
def get_init_problem_input():
|
828 |
+
return gr.Textbox(placeholder="Enter your problem here...", label="Problem Input", lines=5, visible=True), gr.Markdown(
|
829 |
visible=False,
|
830 |
latex_delimiters=[
|
831 |
{"left": "[", "right": "]", "display": True},
|
|
|
855 |
|
856 |
time.sleep(1)
|
857 |
|
858 |
+
copy_btn.click(fn=lambda _: gr.update(value=problem_example_text, interactive=True), inputs=None, outputs=inp)
|
859 |
|
860 |
btn_clear.click(
|
861 |
fn=clear,
|
862 |
inputs=[],
|
863 |
outputs=[
|
864 |
inp,
|
865 |
+
temperature_slider,
|
866 |
out,
|
867 |
problem_example,
|
868 |
problem_example_text_hidden,
|
|
|
892 |
)
|
893 |
|
894 |
if __name__ == "__main__":
|
895 |
+
demo.queue(default_concurrency_limit=5).launch(share=True)
|