# Hugging Face Spaces status banner (page-scrape residue): Running on CPU Upgrade
import logging | |
import re | |
from pathlib import Path | |
import gradio as gr | |
import nltk | |
from cleantext import clean | |
from summarize import load_model_and_tokenizer, summarize_via_tokenbatches | |
# directory containing this file; load_examples() resolves its folder against it
_here = Path(__file__).parent

# TODO: find where this requirement originates from
nltk.download("stopwords")

import transformers

# keep the transformers library's own logging at error level
transformers.logging.set_verbosity_error()
# default root-logger configuration for this script's logging calls
logging.basicConfig()
def truncate_word_count(text, max_words=512):
    """
    truncate_word_count - a helper function for the gradio module

    Words are the whitespace-delimited tokens of ``text``; leading and trailing
    whitespace is never counted as a word.

    Parameters
    ----------
    text : str, required, the text to be processed
    max_words : int, optional, the maximum number of words, default=512

    Returns
    -------
    dict, the text and whether it was truncated:
        "was_truncated" : bool, True if ``text`` has more than ``max_words`` words
        "truncated_text" : str, the first ``max_words`` words joined by single
            spaces when truncated, otherwise ``text`` unchanged
    """
    # str.split() with no separator splits on runs of whitespace and, unlike
    # re.split(r"\s+", ...), does not produce empty strings for leading or
    # trailing whitespace, so padding is not miscounted as extra words.
    words = text.split()
    was_truncated = len(words) > max_words
    return {
        "was_truncated": was_truncated,
        "truncated_text": " ".join(words[:max_words]) if was_truncated else text,
    }
def proc_submission(
    input_text: str,
    model_size: str,
    num_beams,
    token_batch_length,
    length_penalty,
    repetition_penalty,
    no_repeat_ngram_size,
    max_input_length: int = 512,
):
    """
    proc_submission - a helper function for the gradio module

    Cleans the submitted text, truncates it to a word limit, summarizes it via
    summarize_via_tokenbatches and renders the outcome as an HTML string.

    Parameters
    ----------
    input_text : str, required, the text to be processed
    model_size : str, required, "base" selects model_sm/tokenizer_sm and a word
        limit of 1024; any other value selects model/tokenizer and
        max_input_length
    num_beams : forwarded to the summarizer as "num_beams"
    token_batch_length : forwarded to the summarizer as batch_length; a quarter
        of it (floor division) is used as "max_length"
    length_penalty : forwarded to the summarizer as "length_penalty"
    repetition_penalty : forwarded to the summarizer as "repetition_penalty"
    no_repeat_ngram_size : forwarded to the summarizer as "no_repeat_ngram_size"
    max_input_length : int, optional, the maximum length of the input text in
        words when model_size is not "base", default=512

    Returns
    -------
    str of HTML, one <h2> section per entry: truncation status, summary text,
    summary scores and the (cleaned, possibly truncated) input
    """
    # function-scope import: the html module is not imported at file level
    from html import escape

    settings = {
        "length_penalty": length_penalty,
        "repetition_penalty": repetition_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "encoder_no_repeat_ngram_size": 4,
        "num_beams": num_beams,
        "min_length": 4,
        "max_length": int(token_batch_length // 4),
        "early_stopping": True,
        "do_sample": False,
    }
    history = {}

    clean_text = clean(input_text, lower=False)
    max_input_length = 1024 if model_size == "base" else max_input_length
    processed = truncate_word_count(clean_text, max_input_length)

    # Always summarize the cleaned text. When nothing was truncated,
    # "truncated_text" is the cleaned text itself, so short inputs get the same
    # cleaning as long ones instead of falling back to the raw submission.
    tr_in = processed["truncated_text"]
    if processed["was_truncated"]:
        msg = f"Input text was truncated to {max_input_length} words (based on whitespace)"
        logging.warning(msg)
        history["WARNING"] = msg
    else:
        history["was_truncated"] = False

    # model/tokenizer pairs are module globals created in the __main__ section
    use_base = model_size == "base"
    _summaries = summarize_via_tokenbatches(
        tr_in,
        model_sm if use_base else model,
        tokenizer_sm if use_base else tokenizer,
        batch_length=token_batch_length,
        **settings,
    )

    # Model output and user input are escaped before being placed in markup so
    # that characters such as "<" or "&" are displayed rather than interpreted.
    sum_text = [
        f"Section {i}: " + escape(s["summary"][0]) for i, s in enumerate(_summaries)
    ]
    sum_scores = [
        f"\n - Section {i}: {round(s['summary_score'], 4)}"
        for i, s in enumerate(_summaries)
    ]

    history["Summary Text"] = "<br>".join(sum_text)
    history["Summary Scores"] = "\n".join(sum_scores)
    history["Input"] = escape(tr_in)

    # summary sections are rendered plain, every other section in bold
    sections = []
    for name, item in history.items():
        if "summary" in name.lower():
            sections.append(f"<h2>{name}:</h2><hr>{item}<br><br>")
        else:
            sections.append(f"<h2>{name}:</h2><hr><b>{item}</b><br><br>")
    return "".join(sections)
def load_examples(examples_dir="examples"):
    """
    load_examples - a helper function for the gradio module to load examples

    Parameters
    ----------
    examples_dir : str, optional, directory (resolved against this file's
        folder) that holds the ``*.txt`` example files, default="examples";
        it is created if it does not exist

    Returns
    -------
    list of lists, one row per example file in sorted path order:
        [text, "large", 2, 512, 0.7, 3.5, 3]
    i.e. the file's text followed by fixed values for the remaining
    proc_submission inputs (model_size, num_beams, token_batch_length,
    length_penalty, repetition_penalty, no_repeat_ngram_size)
    """
    src = _here / examples_dir
    src.mkdir(exist_ok=True)
    # sorted(): glob() yields files in arbitrary order; sorting keeps the
    # example list stable between runs.
    # encoding="utf-8": do not depend on the platform's default text encoding.
    return [
        [path.read_text(encoding="utf-8"), "large", 2, 512, 0.7, 3.5, 3]
        for path in sorted(src.glob("*.txt"))
    ]
if __name__ == "__main__":
    # Both checkpoints are loaded up front and kept as module globals;
    # proc_submission picks one pair by name according to the chosen model size.
    model, tokenizer = load_model_and_tokenizer("pszemraj/led-large-book-summary")
    model_sm, tokenizer_sm = load_model_and_tokenizer("pszemraj/led-base-book-summary")

    title = "Long-Form Summarization: LED & BookSum"
    description = "A simple demo of how to use a fine-tuned LED model to summarize long-form text. [This model](https://huggingface.co/pszemraj/led-large-book-summary) is a fine-tuned version of [allenai/led-large-16384](https://huggingface.co/allenai/led-large-16384) on the [BookSum dataset](https://arxiv.org/abs/2105.08209). The goal was to create a model that can generalize well and is useful in summarizing lots of text in academic and daily usage."
    article = "The model can be used with tag [pszemraj/led-large-book-summary](https://huggingface.co/pszemraj/led-large-book-summary). See the model card for details on usage & a notebook for a tutorial."

    # Input widgets, in the same order as proc_submission's positional parameters.
    input_components = [
        gr.inputs.Textbox(
            lines=10,
            label="input text",
            placeholder="Enter text to summarize, the text will be cleaned and truncated on Spaces. Narrative, academic (both papers and lecture transcription), and article text work well. May take a bit to generate depending on the input text :)",
        ),
        gr.inputs.Radio(choices=["base", "large"], label="model size", default="base"),
        gr.inputs.Slider(minimum=1, maximum=4, label="num_beams", default=1, step=1),
        gr.inputs.Slider(
            minimum=512,
            maximum=1024,
            label="token_batch_length",
            default=512,
            step=256,
        ),
        gr.inputs.Slider(
            minimum=0.5,
            maximum=1.1,
            label="length_penalty",
            default=0.7,
            step=0.05,
        ),
        gr.inputs.Slider(
            minimum=1.0,
            maximum=5.0,
            label="repetition_penalty",
            default=3.5,
            step=0.1,
        ),
        gr.inputs.Slider(
            minimum=2,
            maximum=4,
            label="no_repeat_ngram_size",
            default=3,
            step=1,
        ),
    ]

    # Build the interface, then start serving it.
    demo = gr.Interface(
        proc_submission,
        inputs=input_components,
        outputs="html",
        examples_per_page=2,
        title=title,
        description=description,
        article=article,
        examples=load_examples(),
        cache_examples=True,
    )
    demo.launch()