import os
import gradio as gr
from gradio.components import Textbox, Button, Slider, Checkbox
from AinaTheme import theme
from urllib.error import HTTPError
from rag import RAG
from utils import setup
MAX_NEW_TOKENS = 700
SHOW_MODEL_PARAMETERS_IN_UI = os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="True") == "True"
setup()
rag = RAG(
hf_token=os.getenv("HF_TOKEN"),
embeddings_model=os.getenv("EMBEDDINGS"),
model_name=os.getenv("MODEL"),
)
def generate(prompt, model_parameters):
try:
output, context, source = rag.get_response(prompt, model_parameters)
return output, context, source
except HTTPError as err:
if err.code == 400:
gr.Warning(
"The inference endpoint is only available Monday through Friday, from 08:00 to 20:00 CET."
)
except:
gr.Warning(
"Inference endpoint is not available right now. Please try again later."
)
def submit_input(input_, num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature):
if input_.strip() == "":
gr.Warning("Not possible to inference an empty input")
return None
model_parameters = {
"NUM_CHUNKS": num_chunks,
"max_new_tokens": max_new_tokens,
"repetition_penalty": repetition_penalty,
"top_k": top_k,
"top_p": top_p,
"do_sample": do_sample,
"temperature": temperature
}
output, context, source = generate(input_, model_parameters)
sources_markup = ""
for url in source:
sources_markup += f'{url}
'
return output.strip(), sources_markup, context
def change_interactive(text):
if len(text) == 0:
return gr.update(interactive=True), gr.update(interactive=False)
return gr.update(interactive=True), gr.update(interactive=True)
def clear():
return (
None,
None,
None,
None,
gr.Slider(value=2.0),
gr.Slider(value=MAX_NEW_TOKENS),
gr.Slider(value=1.0),
gr.Slider(value=50),
gr.Slider(value=0.99),
gr.Checkbox(value=False),
gr.Slider(value=0.35),
)
def gradio_app():
with gr.Blocks(theme=theme) as demo:
with gr.Row():
with gr.Column(scale=0.1):
gr.Image("rag_image.jpg", elem_id="flor-banner", scale=1, height=256, width=256, show_label=False, show_download_button = False, show_share_button = False)
with gr.Column():
gr.Markdown(
"""# Demo de Retrieval-Augmented Generation per documents legals
🔍 **Retrieval-Augmented Generation** (RAG) és una tecnologia de IA que permet interrogar un repositori de documents amb preguntes
en llenguatge natural, i combina tècniques de recuperació d'informació avançades amb models generatius per redactar una resposta
fent servir només la informació existent en els documents del repositori.
🎯 **Objectiu:** Aquest és un primer demostrador amb la normativa vigent publicada al Diari Oficial de la Generalitat de Catalunya, en el
repositori del EADOP (Entitat Autònoma del Diari Oficial i de Publicacions). Aquesta primera versió explora prop de 2000 documents en català,
i genera la resposta fent servir el model Flor6.3b entrenat amb el dataset de QA generativa projecte-aina/RAG_Multilingual.
⚠️ **Advertencies**: Primera versió experimental. El contingut generat per aquest model no està supervisat i pot ser incorrecte.
Si us plau, tingueu-ho en compte quan exploreu aquest recurs.
"""
)
with gr.Row(equal_height=True):
with gr.Column(variant="panel"):
input_ = Textbox(
lines=11,
label="Input",
placeholder="Quina és la finalitat del Servei Meteorològic de Catalunya?",
# value = "Quina és la finalitat del Servei Meteorològic de Catalunya?"
)
with gr.Row(variant="panel"):
clear_btn = Button(
"Clear",
)
submit_btn = Button("Submit", variant="primary", interactive=False)
with gr.Row(variant="panel"):
with gr.Accordion("Model parameters", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
num_chunks = Slider(
minimum=1,
maximum=6,
step=1,
value=2,
label="Number of chunks"
)
max_new_tokens = Slider(
minimum=50,
maximum=2000,
step=1,
value=MAX_NEW_TOKENS,
label="Max tokens"
)
repetition_penalty = Slider(
minimum=0.1,
maximum=2.0,
step=0.1,
value=1.0,
label="Repetition penalty"
)
top_k = Slider(
minimum=1,
maximum=100,
step=1,
value=50,
label="Top k"
)
top_p = Slider(
minimum=0.01,
maximum=0.99,
value=0.99,
label="Top p"
)
do_sample = Checkbox(
value=False,
label="Do sample"
)
temperature = Slider(
minimum=0.1,
maximum=1,
value=0.35,
label="Temperature"
)
parameters_compontents = [num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature]
with gr.Column(variant="panel"):
output = Textbox(
lines=10,
label="Output",
interactive=False,
show_copy_button=True
)
with gr.Accordion("Sources and context:", open=False):
source_context = gr.Markdown(
label="Sources",
show_label=False,
)
with gr.Accordion("See full context evaluation:", open=False):
context_evaluation = gr.Markdown(
label="Full context",
show_label=False,
# interactive=False,
# autoscroll=False,
# show_copy_button=True
)
input_.change(
fn=change_interactive,
inputs=[input_],
outputs=[clear_btn, submit_btn],
api_name=False,
)
input_.change(
fn=None,
inputs=[input_],
api_name=False,
js="""(i, m) => {
document.getElementById('inputlenght').textContent = i.length + ' '
document.getElementById('inputlenght').style.color = (i.length > m) ? "#ef4444" : "";
}""",
)
clear_btn.click(
fn=clear,
inputs=[],
outputs=[input_, output, source_context, context_evaluation] + parameters_compontents,
queue=False,
api_name=False
)
submit_btn.click(
fn=submit_input,
inputs=[input_]+ parameters_compontents,
outputs=[output, source_context, context_evaluation],
api_name="get-results"
)
with gr.Row():
with gr.Column(scale=0.5):
gr.Examples(
examples=[
["""Què és l'EADOP (Entitat Autònoma del Diari Oficial i de Publicacions)?"""],
],
inputs=input_,
outputs=[output, source_context, context_evaluation],
fn=submit_input,
)
gr.Examples(
examples=[
["""Què diu el decret sobre la senyalització de les begudes alcohòliques i el tabac a Catalunya?"""],
],
inputs=input_,
outputs=[output, source_context, context_evaluation],
fn=submit_input,
)
gr.Examples(
examples=[
["""Com es pot inscriure una persona al Registre de catalans i catalanes residents a l'exterior?"""],
],
inputs=input_,
outputs=[output, source_context, context_evaluation],
fn=submit_input,
)
gr.Examples(
examples=[
["""Quina és la finalitat del Servei Meterològic de Catalunya ?"""],
],
inputs=input_,
outputs=[output, source_context, context_evaluation],
fn=submit_input,
)
demo.launch(show_api=True)
if __name__ == "__main__":
gradio_app()