# Commit b2df366 by Peter: ":bug: set min beams to 2"
# (file-viewer residue: raw | history blame | 6.39 kB)
import logging
import re
from pathlib import Path
import gradio as gr
import nltk
from cleantext import clean
from summarize import load_model_and_tokenizer, summarize_via_tokenbatches
import transformers

# Directory containing this script; load_examples resolves its folder against it.
_here = Path(__file__).parent

# TODO: find where this requirement originates from.
# NOTE(review): presumably needed by `summarize` or `cleantext` -- confirm.
nltk.download("stopwords")

# Only surface errors from transformers; plain root logging setup for this app.
transformers.logging.set_verbosity_error()
logging.basicConfig()
def truncate_word_count(text, max_words=512):
    """
    truncate_word_count - a helper function for the gradio module

    Limit ``text`` to at most ``max_words`` whitespace-separated words.

    Parameters
    ----------
    text : str, required, the text to be processed
    max_words : int, optional, the maximum number of words, default=512

    Returns
    -------
    dict, the text and whether it was truncated, with keys:
        "was_truncated" : bool, True if ``text`` has more than ``max_words`` words
        "truncated_text" : str, the first ``max_words`` words joined by single
            spaces when truncated, otherwise ``text`` unchanged
    """
    # str.split() with no argument splits on runs of whitespace and, unlike
    # re.split(r"\s+", ...), never yields empty strings for leading/trailing
    # whitespace. Those phantom "words" used to inflate the count, truncating
    # text that actually fit and dropping real words when truncating.
    words = text.split()
    if len(words) > max_words:
        return {"was_truncated": True, "truncated_text": " ".join(words[:max_words])}
    return {"was_truncated": False, "truncated_text": text}
def proc_submission(
    input_text: str,
    model_size: str,
    num_beams,
    token_batch_length,
    length_penalty,
    repetition_penalty,
    no_repeat_ngram_size,
    max_input_length: int = 512,
):
    """
    proc_submission - a helper function for the gradio module

    Clean the input, truncate it to a word budget, summarize it in token
    batches with the selected model, and render the results as HTML.

    Parameters
    ----------
    input_text : str, required, the text to be processed
    model_size : str, required, "base" uses the module globals model_sm /
        tokenizer_sm; any other value uses model / tokenizer (both pairs are
        bound in the ``__main__`` block)
    num_beams : int, number of beams for generation
    token_batch_length : int, tokens per batch; the per-batch summary
        max_length is token_batch_length // 4
    length_penalty : float, generation length penalty
    repetition_penalty : float, generation repetition penalty
    no_repeat_ngram_size : int, size of n-grams that may not repeat in the output
    max_input_length : int, optional, the maximum number of input words,
        default=512 (replaced by 1024 when model_size == "base")

    Returns
    -------
    str of HTML, the rendered summaries, scores, and (escaped) input text
    """
    from html import escape  # stdlib; escapes untrusted text placed into HTML

    # Slider widgets may deliver whole numbers as floats; these settings are
    # counts/sizes, so coerce them to int before they reach generation.
    num_beams = int(num_beams)
    token_batch_length = int(token_batch_length)
    no_repeat_ngram_size = int(no_repeat_ngram_size)

    settings = {
        "length_penalty": length_penalty,
        "repetition_penalty": repetition_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "encoder_no_repeat_ngram_size": 4,
        "num_beams": num_beams,
        "min_length": 4,
        "max_length": token_batch_length // 4,
        "early_stopping": True,
        "do_sample": False,
    }
    history = {}
    clean_text = clean(input_text, lower=False)
    max_input_length = 1024 if model_size == "base" else max_input_length
    processed = truncate_word_count(clean_text, max_input_length)
    # Always summarize the *cleaned* text. Previously the raw input_text was
    # used whenever no truncation happened, silently skipping the cleaning.
    tr_in = processed["truncated_text"]
    if processed["was_truncated"]:
        msg = f"Input text was truncated to {max_input_length} words (based on whitespace)"
        logging.warning(msg)
        history["WARNING"] = msg
    else:
        history["was_truncated"] = False

    use_base = model_size == "base"
    _summaries = summarize_via_tokenbatches(
        tr_in,
        model_sm if use_base else model,
        tokenizer_sm if use_base else tokenizer,
        batch_length=token_batch_length,
        **settings,
    )

    # Summary entries are pre-rendered HTML: escape each piece of model output,
    # then join with <br>. Scores also use <br>, since "\n" collapses to a
    # single space when rendered as HTML.
    sum_text = [
        f"Section {i}: " + escape(s["summary"][0]) for i, s in enumerate(_summaries)
    ]
    sum_scores = [
        f" - Section {i}: {round(s['summary_score'], 4)}"
        for i, s in enumerate(_summaries)
    ]
    history["Summary Text"] = "<br>".join(sum_text)
    history["Summary Scores"] = "<br>".join(sum_scores)
    history["Input"] = tr_in

    # Non-summary items (warning, flags, the user's input) are plain text and
    # are escaped here so user-supplied markup cannot inject HTML.
    return "".join(
        f"<h2>{name}:</h2><hr>{item}<br><br>"
        if "summary" in name.lower()
        else f"<h2>{name}:</h2><hr><b>{escape(str(item))}</b><br><br>"
        for name, item in history.items()
    )
def load_examples(examples_dir="examples"):
    """
    load_examples - a helper function for the gradio module to load examples

    Reads every ``*.txt`` file in ``examples_dir`` (resolved relative to this
    script's directory and created if missing) and pairs each text with a
    fixed set of widget values.

    Returns:
        list of list, one row per example file, in sorted filename order:
        [text, model_size, num_beams, token_batch_length, length_penalty,
         repetition_penalty, no_repeat_ngram_size]
    """
    src = _here / examples_dir
    src.mkdir(exist_ok=True)  # an empty folder simply yields no examples
    # sorted(): glob order is filesystem-dependent, so sort for a stable UI order.
    # Explicit UTF-8: the platform default encoding (e.g. cp1252) can fail or mis-decode.
    return [
        [path.read_text(encoding="utf-8"), "large", 2, 512, 0.7, 3.5, 3]
        for path in sorted(src.glob("*.txt"))
    ]
if __name__ == "__main__":
    # Loaded here as module globals; proc_submission picks model_sm/tokenizer_sm
    # for "base" and model/tokenizer otherwise.
    model, tokenizer = load_model_and_tokenizer("pszemraj/led-large-book-summary")
    model_sm, tokenizer_sm = load_model_and_tokenizer("pszemraj/led-base-book-summary")

    title = "Long-Form Summarization: LED & BookSum"
    description = "A simple demo of how to use a fine-tuned LED model to summarize long-form text. [This model](https://huggingface.co/pszemraj/led-large-book-summary) is a fine-tuned version of [allenai/led-large-16384](https://huggingface.co/allenai/led-large-16384) on the [BookSum dataset](https://arxiv.org/abs/2105.08209). The goal was to create a model that can generalize well and is useful in summarizing lots of text in academic and daily usage."
    article = "The model can be used with tag [pszemraj/led-large-book-summary](https://huggingface.co/pszemraj/led-large-book-summary). See the model card for details on usage & a notebook for a tutorial."

    # Widget order must match proc_submission's positional parameters and the
    # rows built by load_examples.
    input_widgets = [
        gr.inputs.Textbox(
            label="input text",
            lines=10,
            placeholder="Enter text to summarize, the text will be cleaned and truncated on Spaces. Narrative, academic (both papers and lecture transcription), and article text work well. May take a bit to generate depending on the input text :)",
        ),
        gr.inputs.Radio(
            label="model size", choices=["base", "large"], default="base"
        ),
        gr.inputs.Slider(
            label="num_beams", minimum=2, maximum=4, step=1, default=2
        ),
        gr.inputs.Slider(
            label="token_batch_length",
            minimum=512,
            maximum=1024,
            step=256,
            default=512,
        ),
        gr.inputs.Slider(
            label="length_penalty", minimum=0.5, maximum=1.1, step=0.05, default=0.7
        ),
        gr.inputs.Slider(
            label="repetition_penalty",
            minimum=1.0,
            maximum=5.0,
            step=0.1,
            default=3.5,
        ),
        gr.inputs.Slider(
            label="no_repeat_ngram_size", minimum=2, maximum=4, step=1, default=3
        ),
    ]

    demo = gr.Interface(
        proc_submission,
        inputs=input_widgets,
        outputs="html",
        examples_per_page=2,
        title=title,
        description=description,
        article=article,
        examples=load_examples(),
        cache_examples=True,
    )
    demo.launch()