Peter
improve HTML formatting
7dcd8f3
raw
history blame
6.39 kB
import logging
import re
from pathlib import Path
import gradio as gr
import nltk
from cleantext import clean
from summarize import load_model_and_tokenizer, summarize_via_tokenbatches
# directory containing this file; example texts are resolved relative to it
_here = Path(__file__).parent
# fetched at import time, so importing this module may hit the network
nltk.download("stopwords") # TODO=find where this requirement originates from
import transformers
# only surface error-level messages from the transformers library
transformers.logging.set_verbosity_error()
# install the default root-logger configuration used by logging.warning below
logging.basicConfig()
def truncate_word_count(text, max_words=512):
    """
    truncate_word_count - a helper function for the gradio module

    Counts whitespace-separated words and, when there are more than
    ``max_words`` of them, keeps only the first ``max_words`` joined by
    single spaces.

    Parameters
    ----------
    text : str, required, the text to be processed
    max_words : int, optional, the maximum number of words, default=512

    Returns
    -------
    dict, with keys "was_truncated" (bool) and "truncated_text" (str);
    the text is returned unchanged when it is within the limit
    """
    # str.split() with no arguments splits on runs of whitespace and ignores
    # leading/trailing whitespace, so surrounding spaces or newlines are never
    # counted as empty "words" (which would inflate the count and could
    # trigger a spurious truncation)
    words = text.split()
    if len(words) > max_words:
        return {
            "was_truncated": True,
            "truncated_text": " ".join(words[:max_words]),
        }
    return {"was_truncated": False, "truncated_text": text}
def proc_submission(
    input_text: str,
    model_size: str,
    num_beams,
    token_batch_length,
    length_penalty,
    repetition_penalty,
    no_repeat_ngram_size,
    max_input_length: int = 512,
):
    """
    proc_submission - a helper function for the gradio module

    Cleans the input, truncates it to a word limit, summarizes it in token
    batches with the model selected by ``model_size``, and renders the
    results as an HTML string.

    Parameters
    ----------
    input_text : str, required, the text to be processed
    model_size : str, required, "base" selects the module-level
        model_sm/tokenizer_sm pair, any other value selects model/tokenizer
    num_beams : forwarded to summarize_via_tokenbatches as "num_beams"
    token_batch_length : forwarded as batch_length; a quarter of it
        (floor division) is used as the "max_length" setting
    length_penalty : forwarded as "length_penalty"
    repetition_penalty : forwarded as "repetition_penalty"
    no_repeat_ngram_size : forwarded as "no_repeat_ngram_size"
    max_input_length : int, optional, the maximum length of the input text
        in whitespace-separated words, default=512; replaced by 1024 when
        model_size is "base"

    Returns
    -------
    str of HTML, one <h2> section per result entry (truncation notice,
    summary text, summary scores, and the text that was summarized)
    """
    # function-scope import keeps the block self-contained; escaping prevents
    # markup in user input or model output from being interpreted as HTML
    from html import escape

    use_base = model_size == "base"
    settings = {
        "length_penalty": length_penalty,
        "repetition_penalty": repetition_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "encoder_no_repeat_ngram_size": 4,
        "num_beams": num_beams,
        "min_length": 4,
        "max_length": int(token_batch_length // 4),
        "early_stopping": True,
        "do_sample": False,
    }
    history = {}
    clean_text = clean(input_text, lower=False)
    max_input_length = 1024 if use_base else max_input_length
    processed = truncate_word_count(clean_text, max_input_length)

    # always continue with the cleaned text: when nothing was truncated,
    # "truncated_text" is the cleaned text itself, not the raw submission
    tr_in = processed["truncated_text"]
    if processed["was_truncated"]:
        msg = f"Input text was truncated to {max_input_length} words (based on whitespace)"
        logging.warning(msg)
        history["WARNING"] = msg
    else:
        history["was_truncated"] = False

    _summaries = summarize_via_tokenbatches(
        tr_in,
        model_sm if use_base else model,
        tokenizer_sm if use_base else tokenizer,
        batch_length=token_batch_length,
        **settings,
    )
    # model output is escaped per section so the <br> separators stay markup
    sum_text = [
        f"Section {i}: " + escape(s["summary"][0]) for i, s in enumerate(_summaries)
    ]
    sum_scores = [
        f"\n - Section {i}: {round(s['summary_score'],4)}"
        for i, s in enumerate(_summaries)
    ]

    history["Summary Text"] = "<br>".join(sum_text)
    history["Summary Scores"] = "\n".join(sum_scores)
    history["Input"] = escape(tr_in)

    # summary sections are rendered plain, every other section in bold
    sections = []
    for name, item in history.items():
        if "summary" in name.lower():
            sections.append(f"<h2>{name}:</h2><hr>{item}<br><br>")
        else:
            sections.append(f"<h2>{name}:</h2><hr><b>{item}</b><br><br>")
    return "".join(sections)
def load_examples(examples_dir="examples"):
    """
    load_examples - a helper function for the gradio module to load examples

    Parameters
    ----------
    examples_dir : str, optional, directory (relative to this file) that
        holds the *.txt example files, default="examples"; it is created
        if it does not exist

    Returns:
        list of list, one row per example file: the file's text followed by
        the preset values "large", 2, 512, 0.7, 3.5, 3
    """
    src = _here / examples_dir
    src.mkdir(exist_ok=True)
    # sorted: glob order depends on the filesystem, and a stable order keeps
    # the example listing the same from run to run
    # explicit UTF-8: do not depend on the platform's default text encoding
    return [
        [path.read_text(encoding="utf-8"), "large", 2, 512, 0.7, 3.5, 3]
        for path in sorted(src.glob("*.txt"))
    ]
if __name__ == "__main__":
    # these four names are read as module globals by proc_submission
    model, tokenizer = load_model_and_tokenizer("pszemraj/led-large-book-summary")
    model_sm, tokenizer_sm = load_model_and_tokenizer("pszemraj/led-base-book-summary")

    title = "Long-Form Summarization: LED & BookSum"
    description = "A simple demo of how to use a fine-tuned LED model to summarize long-form text. [This model](https://huggingface.co/pszemraj/led-large-book-summary) is a fine-tuned version of [allenai/led-large-16384](https://huggingface.co/allenai/led-large-16384) on the [BookSum dataset](https://arxiv.org/abs/2105.08209). The goal was to create a model that can generalize well and is useful in summarizing lots of text in academic and daily usage."

    # one component per proc_submission argument, in positional order
    text_box = gr.inputs.Textbox(
        lines=10,
        label="input text",
        placeholder="Enter text to summarize, the text will be cleaned and truncated on Spaces. Narrative, academic (both papers and lecture transcription), and article text work well. May take a bit to generate depending on the input text :)",
    )
    size_radio = gr.inputs.Radio(
        choices=["base", "large"], label="model size", default="base"
    )
    beams_slider = gr.inputs.Slider(
        minimum=1, maximum=4, label="num_beams", default=1, step=1
    )
    batch_slider = gr.inputs.Slider(
        minimum=512,
        maximum=1024,
        label="token_batch_length",
        default=512,
        step=256,
    )
    length_slider = gr.inputs.Slider(
        minimum=0.5, maximum=1.1, label="length_penalty", default=0.7, step=0.05
    )
    repetition_slider = gr.inputs.Slider(
        minimum=1.0,
        maximum=5.0,
        label="repetition_penalty",
        default=3.5,
        step=0.1,
    )
    ngram_slider = gr.inputs.Slider(
        minimum=2, maximum=4, label="no_repeat_ngram_size", default=3, step=1
    )

    demo = gr.Interface(
        proc_submission,
        inputs=[
            text_box,
            size_radio,
            beams_slider,
            batch_slider,
            length_slider,
            repetition_slider,
            ngram_slider,
        ],
        outputs="html",
        examples_per_page=2,
        title=title,
        description=description,
        article="The model can be used with tag [pszemraj/led-large-book-summary](https://huggingface.co/pszemraj/led-large-book-summary). See the model card for details on usage & a notebook for a tutorial.",
        examples=load_examples(),
        cache_examples=True,
    )
    demo.launch()