ClickbaitFighter / app_zero.py
Iker's picture
Fix
d68525f
import datetime
import os
from collections import OrderedDict
from typing import Any
import gradio as gr
import spaces
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
GenerationConfig,
LogitsProcessorList,
TextStreamer,
)
from cache_system import CacheHandler
from download_url import download_text_and_title
from prompts import (
summarize_clickbait_large_prompt,
summarize_clickbait_short_prompt,
summarize_prompt,
)
from utils import StopAfterTokenIsGenerated
auth_token = os.environ.get("TOKEN") or True
total_runs = 0
tokenizer = AutoTokenizer.from_pretrained("Iker/ClickbaitFighter-10B-pro")
model = AutoModelForCausalLM.from_pretrained(
"Iker/ClickbaitFighter-10B-pro",
torch_dtype=torch.bfloat16,
device_map="auto",
# quantization_config=BitsAndBytesConfig(
# load_in_4bit=True,
# bnb_4bit_compute_dtype=torch.bfloat16,
# bnb_4bit_use_double_quant=True,
# ),
# attn_implementation="flash_attention_2",
)
generation_config = GenerationConfig(
max_new_tokens=256, # Los resúmenes son cortos, no necesitamos más tokens
min_new_tokens=1, # No queremos resúmenes vacíos
do_sample=True, # Un poquito mejor que greedy sampling
num_beams=1,
use_cache=True, # Eficiencia
top_k=40,
top_p=0.1,
repetition_penalty=1.1, # Ayuda a evitar que el modelo entre en bucles
encoder_repetition_penalty=1.1, # Favorecemos que el modelo cite el texto original
temperature=0.15, # temperature baja para evitar que el modelo genere texto muy creativo.
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.eos_token_id,
)
stop_words = [
"<s>",
"</s>",
"\\n",
"[/INST]",
"[INST]",
"### User:",
"### Assistant:",
"###",
"<start_of_turn>",
"<end_of_turn>",
"<end_of_turn>\\n",
"<eos>",
"<|im_end|>",
]
stop_criteria = LogitsProcessorList(
[
StopAfterTokenIsGenerated(
stops=[
torch.tensor(tokenizer.encode(stop_word, add_special_tokens=False))
for stop_word in stop_words.copy()
],
eos_token_id=tokenizer.eos_token_id,
)
]
)
class HuggingFaceDatasetSaver_custom(gr.HuggingFaceDatasetSaver):
def _deserialize_components(
self,
data_dir,
flag_data: list[Any],
flag_option: str = "",
username: str = "",
) -> tuple[dict[Any, Any], list[Any]]:
"""Deserialize components and return the corresponding row for the flagged sample.
Images/audio are saved to disk as individual files.
"""
# Generate the row corresponding to the flagged sample
features = OrderedDict()
row = []
for component, sample in zip(self.components, flag_data):
label = component.label or ""
features[label] = {"dtype": "string", "_type": "Value"}
row.append(sample)
features["flag"] = {"dtype": "string", "_type": "Value"}
features["username"] = {"dtype": "string", "_type": "Value"}
row.append(flag_option)
row.append(username)
return features, row
def finish_generation(text: str) -> str:
return f"{text}\n\n⬇️ Ayuda a mejorar la herramienta marcando si el resumen es correcto o no.⬇️"
@spaces.GPU
def run_model(mode, title, text):
if mode == 0:
prompt = summarize_prompt(title, text)
elif mode == 50:
prompt = summarize_clickbait_large_prompt(title, text)
elif mode == 100:
prompt = summarize_clickbait_short_prompt(title, text)
else:
raise ValueError("Mode not supported")
formatted_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
tokenize=False,
add_generation_prompt=True,
)
model_inputs = tokenizer(
[formatted_prompt], return_tensors="pt", add_special_tokens=False
)
streamer = TextStreamer(
tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True
)
model_output = model.generate(
**model_inputs.to(model.device),
streamer=streamer,
generation_config=generation_config,
logits_processor=stop_criteria,
)
# yield streamer # Does not work properly on Zero environment
temp = tokenizer.batch_decode(
model_output[:, model_inputs["input_ids"].shape[-1] :],
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)[0]
return temp
def generate_text(
url: str, mode: int, progress=gr.Progress(track_tqdm=False)
) -> (str, str):
global cache_handler
global total_runs
total_runs += 1
print(f"Total runs: {total_runs}. Last run: {datetime.datetime.now()}")
url = url.strip()
if url.startswith("https://twitter.com/") or url.startswith("https://x.com/"):
yield (
"🤖 Vaya, parece que has introducido la url de un tweet. No puedo acceder a tweets, tienes que introducir la URL de una noticia.",
"❌❌❌ Si el tweet contiene una noticia, dame la URL de la noticia ❌❌❌",
"Error",
)
return (
"🤖 Vaya, parece que has introducido la url de un tweet. No puedo acceder a tweets, tienes que introducir la URL de una noticia.",
"❌❌❌ Si el tweet contiene una noticia, dame la URL de la noticia ❌❌❌",
"Error",
)
# 1) Download the article
progress(0, desc="🤖 Accediendo a la noticia")
# First, check if the URL is in the cache
title, text, temp = cache_handler.get_from_cache(url, mode)
if title is not None and text is not None and temp is not None:
temp = finish_generation(temp)
yield title, temp, text
return title, temp, text
else:
try:
title, text, url = download_text_and_title(url)
except Exception as e:
print(e)
title = None
text = None
if title is None or text is None:
yield (
"🤖 No he podido acceder a la notica, asegurate que la URL es correcta y que es posible acceder a la noticia desde un navegador.",
"❌❌❌ Inténtalo de nuevo ❌❌❌",
"Error",
)
return (
"🤖 No he podido acceder a la notica, asegurate que la URL es correcta y que es posible acceder a la noticia desde un navegador.",
"❌❌❌ Inténtalo de nuevo ❌❌❌",
"Error",
)
# Test if the redirected and clean url is in the cache
_, _, temp = cache_handler.get_from_cache(url, mode, second_try=True)
if temp is not None:
temp = finish_generation(temp)
yield title, temp, text
return title, temp, text
progress(0.5, desc="🤖 Leyendo noticia")
try:
temp = run_model(mode, title, text)
except Exception as e:
print(e)
yield (
"🤖 El servidor no se encuentra disponible.",
"❌❌❌ Inténtalo de nuevo más tarde ❌❌❌",
"Error",
)
return (
"🤖 El servidor no se encuentra disponible.",
"❌❌❌ Inténtalo de nuevo más tarde ❌❌❌",
"Error",
)
cache_handler.add_to_cache(
url=url, title=title, text=text, summary_type=mode, summary=temp
)
temp = finish_generation(temp)
yield title, temp, text
hits, misses, cache_len = cache_handler.get_cache_stats()
print(
f"Hits: {hits}, misses: {misses}, cache length: {cache_len}. Percent hits: {round(hits/(hits+misses)*100,2)}%."
)
return title, temp, text
cache_handler = CacheHandler(max_cache_size=1000)
hf_writer = HuggingFaceDatasetSaver_custom(
auth_token, "Iker/Clickbait-News", private=True, separate_dirs=False
)
demo = gr.Interface(
generate_text,
inputs=[
gr.Textbox(
label="🌐 URL de la noticia",
info="Introduce la URL de la noticia que deseas resumir.",
value="https://ikergarcia1996.github.io/Iker-Garcia-Ferrero/",
interactive=True,
),
gr.Slider(
minimum=0,
maximum=100,
step=50,
value=50,
label="🎚️ Nivel de resumen",
info="""¿Hasta qué punto quieres resumir la noticia?
Si solo deseas un resumen, selecciona 0.
Si buscas un resumen y desmontar el clickbait, elige 50.
Para obtener solo la respuesta al clickbait, selecciona 100""",
interactive=True,
),
],
outputs=[
gr.Textbox(
label="📰 Titular de la noticia",
interactive=False,
placeholder="Aquí aparecerá el título de la noticia",
),
gr.Textbox(
label="🗒️ Resumen",
interactive=False,
placeholder="Aquí aparecerá el resumen de la noticia.",
),
gr.Textbox(
label="Noticia completa",
visible=False,
render=False,
interactive=False,
placeholder="Aquí aparecerá el resumen de la noticia.",
),
],
# title="⚔️ Clickbait Fighter! ⚔️",
thumbnail="https://huggingface.co/spaces/Iker/ClickbaitFighter/resolve/main/logo2.png",
theme="JohnSmith9982/small_and_pretty",
description="""
<table>
<tr>
<td style="width:100%"><img src="https://huggingface.co/spaces/Iker/ClickbaitFighter/resolve/main/head.png" align="right" width="100%"> </td>
</tr>
</table>
<p align="justify">Esta Inteligencia Artificial es capaz de generar un resumen de una sola frase que revela la verdad detrás de un titular sensacionalista o clickbait. Solo tienes que introducir la URL de la noticia. La IA accederá a la noticia, la leerá y en cuestión de segundos generará un resumen de una sola frase que revele la verdad detrás del titular.</p>
🎚 Ajusta el nivel de resumen con el control deslizante. Cuanto maś alto, más corto será el resumen.
⌚ La IA se encuentra corriendo en un hardware bastante modesto, debería tardar menos de 30 segundos en generar el resumen, pero si muchos usuarios usan la app a la vez, tendrás que esperar tu turno.
💸 Este es un projecto sin ánimo de lucro, no se genera ningún tipo de ingreso con esta app. Los datos, la IA y el código se publicarán para su uso en la investigación académica. No puedes usar esta app para ningún uso comercial.
🧪 El modelo se encuentra en fase de desarrollo, si quieres ayudar a mejorarlo puedes usar los botones 👍 y 👎 para valorar el resumen. ¡Gracias por tu ayuda!""",
article="Esta Inteligencia Artificial ha sido generada por Iker García-Ferrero. Puedes saber más sobre mi trabajo en mi [página web](https://ikergarcia1996.github.io/Iker-Garcia-Ferrero/) o mi perfil de [X](https://twitter.com/iker_garciaf). Puedes ponerte en contacto conmigo a través de correo electrónico (ver web) y X.",
cache_examples=False,
allow_flagging="manual",
flagging_options=[("👍", "correct"), ("👎", "incorrect")],
flagging_callback=hf_writer,
concurrency_limit=20,
)
demo.queue(max_size=None)
demo.launch(share=False)