import re
import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from unidecode import unidecode
from gradio_i18n import gettext, Translate
from datasets import load_dataset
from style import custom_css, solution_style, letter_style, definition_style
# Phi-3-style chat prompt used to query the fine-tuned rebus solver.
# Placeholders: {rebus} (the verbalized rebus) and {key} (the reading key).
template = """<|user|>
Risolvi gli indizi tra parentesi per ottenere una prima lettura, e usa la chiave di lettura per ottenere la soluzione del rebus.
Rebus: {rebus}
Chiave di lettura: {key}<|end|>
<|assistant|>"""
# EurekaRebus test examples: in-domain + out-of-domain test files, merged
# into a single split (indexed 1-based by the UI via parse_rebus).
# NOTE: downloads from the Hugging Face Hub at import time.
eureka5_test_data = load_dataset(
    'gsarti/eureka-rebus', 'llm_sft',
    data_files=["id_test.jsonl", "ood_test.jsonl"],
    split = "train",
    revision="1.0"
)
# Base URL for pre-computed model predictions hosted in the paper's repo.
OUTPUTS_BASE_URL = "https://raw.githubusercontent.com/gsarti/verbalized-rebus/main/outputs/"
# Per-model result CSVs: prompted frontier models plus three fine-tuned
# checkpoints (results at training step 5070). Each config name becomes a
# split of the loaded dataset.
model_outputs = load_dataset(
    "csv",
    data_files={
        "gpt4": OUTPUTS_BASE_URL + "prompted_models/gpt4o_results.csv",
        "claude3_5_sonnet": OUTPUTS_BASE_URL + "prompted_models/claude3_5_sonnet_results.csv",
        "llama3_70b": OUTPUTS_BASE_URL + "prompted_models/llama3_70b_results.csv",
        "qwen_72b": OUTPUTS_BASE_URL + "prompted_models/qwen_72b_results.csv",
        "phi3_mini": OUTPUTS_BASE_URL + "phi3_mini/phi3_mini_results_step_5070.csv",
        "gemma2": OUTPUTS_BASE_URL + "gemma2_2b/gemma2_2b_results_step_5070.csv",
        "llama3_1_8b": OUTPUTS_BASE_URL + "llama3.1_8b/llama3.1_8b_results_step_5070.csv"
    }
)
def extract(span_text: str, tag: str = "span") -> str:
    """Concatenate the inner text of every ``<tag ...>...</tag>`` occurrence.

    Args:
        span_text: HTML-ish string to scan.
        tag: Tag name to match (default ``"span"``); attributes are allowed.

    Returns:
        The joined inner contents, or ``""`` when no tag is present
        (joining an empty match list already yields the empty string).
    """
    inner_texts = re.findall(rf'<{tag}[^>]*>(.*?)<\/{tag}>', span_text)
    return "".join(inner_texts)
def parse_rebus(ex_idx: int) -> dict:
    """Parse one EurekaRebus test example into styled display fields.

    Args:
        ex_idx: 1-based example index into ``eureka5_test_data``.

    Returns:
        A dict with HTML-styled strings for the rebus, the reading key,
        the first-pass elements/solution pairs, and the final solution.

    Fix: the original reused ``i`` both for the user prompt string and as the
    loop index over ``fp_elements`` — renamed to ``prompt``/``idx`` so the
    prompt is never silently clobbered by an int; the loop also now reuses
    the already-bound ``el`` instead of re-indexing ``fp_elements[i]``.
    """
    # conversations[0] is the user prompt, conversations[1] the gold answer.
    prompt = eureka5_test_data[ex_idx - 1]["conversations"][0]["value"]
    answer = eureka5_test_data[ex_idx - 1]["conversations"][1]["value"]
    rebus = prompt.split("Rebus: ")[1].split("\n")[0]
    # Swap bracketed crossword definitions for a placeholder so the letter
    # styling below cannot touch definition text.
    rebus_letters = re.sub(r"\[.*?\]", "<<<>>>", rebus)
    rebus_letters = re.sub(r"([a-zA-Z]+)", rf"""{letter_style}\1""", rebus_letters)
    # Rebus shown with blanks in place of the (unsolved) definitions.
    fp_empty = rebus_letters.replace("<<<>>>", f"{definition_style}___")
    key = prompt.split("Chiave di lettura: ")[1].split("\n")[0]
    key_split = key
    key_highlighted = re.sub(r"(\d+)", rf"""{solution_style}\1""", key)
    # First-pass pairs: "- <definition or letters> = <resolved text>".
    fp_elements = re.findall(r"- (.*) = (.*)", answer)
    # Capture the raw bracketed definitions BEFORE styling mutates the pairs.
    definitions = [x[0] for x in fp_elements if x[0].startswith("[")]
    for idx, el in enumerate(fp_elements):
        if el[0].startswith("["):
            # Style the definition half only; the resolved word stays plain.
            fp_elements[idx] = (
                re.sub(r"\[(.*?)\]", rf"""{definition_style}[\1]""", el[0]),
                el[1],
            )
        else:
            # Verbatim letters: style both sides as letters.
            fp_elements[idx] = (
                f"{letter_style}{el[0]}",
                f"{letter_style}{el[1]}",
            )
    fp = re.findall(r"Prima lettura: (.*)", answer)[0]
    # Solution pairs: "<word length> = <word>".
    s_elements = re.findall(r"(\d+) = (.*)", answer)
    s = re.findall(r"Soluzione: (.*)", answer)[0]
    # Re-insert the original definitions (left to right) into the styled rebus.
    for d in definitions:
        rebus_letters = rebus_letters.replace("<<<>>>", d, 1)
    rebus_highlighted = re.sub(r"\[(.*?)\]", rf"""{definition_style}[\1]""", rebus_letters)
    return {
        "rebus": rebus_highlighted,
        "key": key_highlighted,
        "key_split": key_split,
        "fp_elements": fp_elements,
        "fp": fp,
        "fp_empty": fp_empty,
        "s_elements": s_elements,
        "s": s
    }
#tokenizer = AutoTokenizer.from_pretrained("gsarti/phi3-mini-rebus-solver-fp16")
#model = AutoModelForCausalLM.from_pretrained("gsarti/phi3-mini-rebus-solver-fp16")
@spaces.GPU
def solve_verbalized_rebus(example, history):
    """Build the solver prompt for a user-submitted rebus (chat callback).

    Args:
        example: User message; expected to contain "Rebus: ..." and a key
            line ("Chiave di lettura: ..." or "Chiave risolutiva: ...").
        history: Chat history (unused, required by gr.ChatInterface).

    Returns:
        The formatted model prompt (model inference is currently disabled).

    Fix: the original called ``template.format(input=example)``, which always
    raised ``KeyError`` because the template's placeholders are ``{rebus}``
    and ``{key}`` — the rebus and key are now parsed out of the message and
    passed under the correct field names.
    """
    rebus_match = re.search(r"Rebus: (.*)", example)
    key_match = re.search(r"Chiave (?:di lettura|risolutiva): (.*)", example)
    rebus = rebus_match.group(1) if rebus_match else example
    key = key_match.group(1) if key_match else ""
    prompt = template.format(rebus=rebus, key=key)
    #inputs = tokenizer(prompt, return_tensors="pt")["input_ids"]
    #outputs = model.generate(input_ids = inputs, max_new_tokens = 500, use_cache = True)
    #model_generations = tokenizer.batch_decode(outputs)
    #return model_generations[0]
    return prompt
#demo = gr.ChatInterface(fn=solve_verbalized_rebus, examples=["Rebus: [Materiale espulso dai vulcani] R O [Strumento del calzolaio] [Si trovano ai lati del bacino] C I [Si ingrassano con la polenta] E I N [Contiene scorte di cibi] B [Isola in francese]\nChiave risolutiva: 1 ' 5 6 5 3 3 1 14"], title="Verbalized Rebus Solver")
#demo.launch()
with gr.Blocks(css=custom_css) as demo:
lang = gr.Dropdown([("English", "en"), ("Italian", "it")], value="it", label="Select language:", interactive=True)
with Translate("translations.yaml", lang, placeholder_langs=["en", "it"]):
gr.Markdown(gettext("Title"))
gr.Markdown(gettext("Intro"))
with gr.Tab(gettext("GuessingGame")):
with gr.Row():
with gr.Column():
example_id = gr.Number(1, label=gettext("CurrentExample"), minimum=1, maximum=2000, step=1, interactive=True)
with gr.Column():
show_length_hints = gr.Checkbox(False, label=gettext("ShowLengthHints"), interactive=True)
@gr.render(inputs=[example_id, show_length_hints], triggers=[demo.load, example_id.change, show_length_hints.change, lang.change])
def show_example(example_number, show_length_hints):
parsed_rebus = parse_rebus(example_number)
gr.Markdown(gettext("Instructions"))
gr.Markdown(gettext("Rebus") + f"{parsed_rebus['rebus']}"),
gr.Markdown(gettext("Key") + f"{parsed_rebus['key']}")
gr.Markdown("
")
with gr.Row():
answers: list[gr.Textbox] = []
with gr.Column(scale=2):
gr.Markdown(gettext("ProceedToResolution"))
for el_key, el_value in parsed_rebus['fp_elements']:
with gr.Row():
with gr.Column(scale=0.2, min_width=250):
gr.Markdown(f"
{el_key} =
") if el_key.startswith('({len(el_value)} lettere)") with gr.Column(scale=0.2, min_width=150): if el_key.startswith('") with gr.Column(scale=3): key_value = gr.Markdown(parsed_rebus['key_split'], visible=False) fp_empty = gr.Markdown(parsed_rebus['fp_empty'], visible=False) fp = gr.Markdown(gettext("FirstPass") + f"{parsed_rebus['fp_empty']}