Spaces:
Sleeping
Sleeping
File size: 5,109 Bytes
fe0e9af 904400a fe0e9af 66e7228 fe0e9af 904400a fe0e9af 66e7228 fe0e9af 9b3e02d fe0e9af 66e7228 fe0e9af 66e7228 fe0e9af 504e8b4 fe0e9af 66e7228 fe0e9af 66e7228 4fc786e 66e7228 4fc786e 66e7228 fe0e9af 504e8b4 fe0e9af 50085ad 504e8b4 66e7228 fe0e9af 9b3e02d 4fc786e 66e7228 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import logging
from pathlib import Path
import os
import re
import gradio as gr
import nltk
import torch
from cleantext import clean
from summarize import load_model_and_tokenizer, summarize_via_tokenbatches
# Directory containing this script; load_examples() resolves its examples dir against it.
_here = Path(__file__).parent
# Downloads the NLTK "stopwords" corpus at import time. Nothing in this file uses it
# directly — presumably a dependency needs it; confirm.
nltk.download("stopwords") # TODO: find where this requirement originates from
import transformers
# Raise transformers' logging threshold to ERROR to reduce console noise.
transformers.logging.set_verbosity_error()
# Configure the root logger with default settings so the logging.warning(...) calls
# in this module go to a handler.
logging.basicConfig()
def truncate_word_count(text, max_words=512):
    """
    truncate_word_count - a helper function for the gradio module

    Truncate ``text`` to at most ``max_words`` whitespace-delimited words.

    Parameters
    ----------
    text : str, required, the text to be processed
    max_words : int, optional, the maximum number of words, default=512

    Returns
    -------
    dict, with keys:
        "was_truncated" : bool, True if ``text`` has more than ``max_words`` words
        "truncated_text" : str, the first ``max_words`` words joined by single
            spaces when truncated, otherwise ``text`` unchanged
    """
    # str.split() with no separator splits on runs of whitespace AND ignores
    # leading/trailing whitespace. re.split(r"\s+", ...) would yield empty strings
    # at the ends (e.g. "  a b  " -> ['', 'a', 'b', '']) and over-count words.
    words = text.split()
    if len(words) > max_words:
        return {"was_truncated": True, "truncated_text": " ".join(words[:max_words])}
    return {"was_truncated": False, "truncated_text": text}
def proc_submission(
    input_text: str,
    num_beams,
    token_batch_length,
    length_penalty,
    repetition_penalty,
    no_repeat_ngram_size,
    max_input_length: int = 512,
):
    """
    proc_submission - a helper function for the gradio module

    Clean the input, truncate it to ``max_input_length`` words, summarize it in
    token batches with the module-level ``model`` and ``tokenizer``, and render
    the results as HTML.

    Parameters
    ----------
    input_text : str, required, the text to be processed
    num_beams : forwarded to summarize_via_tokenbatches as the "num_beams" setting
    token_batch_length : forwarded to summarize_via_tokenbatches as ``batch_length``
    length_penalty : forwarded as the "length_penalty" setting
    repetition_penalty : forwarded as the "repetition_penalty" setting
    no_repeat_ngram_size : forwarded as the "no_repeat_ngram_size" setting
    max_input_length : int, optional, the maximum number of words of input kept, default=512

    Returns
    -------
    str of HTML, one "<h2>name:</h2><hr><b>value</b>" section per history entry
    """
    # Local import keeps this block self-contained; aliased so it cannot clash
    # with other names in the module.
    from html import escape as _escape_html

    settings = {
        "length_penalty": length_penalty,
        "repetition_penalty": repetition_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "encoder_no_repeat_ngram_size": 4,
        "num_beams": num_beams,
    }
    history = {}

    clean_text = clean(input_text, lower=False)
    processed = truncate_word_count(clean_text, max_input_length)
    # Summarize the cleaned text whether or not truncation happened, so both
    # paths feed the model the same normalized input.
    history["input_text"] = processed["truncated_text"]
    history["was_truncated"] = processed["was_truncated"]
    if processed["was_truncated"]:
        # truncate_word_count limits by words, not characters
        msg = f"Input text was truncated to {max_input_length} words."
        logging.warning(msg)
        history["WARNING"] = msg

    # `model` and `tokenizer` are module-level globals assigned in the __main__ block.
    _summaries = summarize_via_tokenbatches(
        history["input_text"],
        model,
        tokenizer,
        batch_length=token_batch_length,
        **settings,
    )
    sum_text = [s["summary"][0] for s in _summaries]
    sum_scores = [f"\n - {round(s['summary_score'], 4)}" for s in _summaries]

    history["Input"] = input_text
    history["Summary Text"] = "\n\t".join(sum_text)
    history["Summary Scores"] = "\n".join(sum_scores)

    sections = []
    for name, item in history.items():
        # Values include raw user input and model output: escape them before
        # embedding in HTML so they render as text rather than markup.
        section = f"<h2>{name}:</h2><hr><b>{_escape_html(str(item))}</b>"
        # Non-summary sections get extra vertical spacing after them.
        if "summary" not in name.lower():
            section += "<br><br>"
        sections.append(section)
    return "".join(sections)
def load_examples(examples_dir="examples"):
    """
    load_examples - build the example rows shown in the gradio interface

    Parameters
    ----------
    examples_dir : str or Path, optional, directory (joined onto this script's
        directory) holding ``*.txt`` example texts; created if missing.
        default="examples"

    Returns
    -------
    list of lists, one per ``.txt`` file in sorted path order:
        [text, num_beams, token_batch_length, length_penalty,
         repetition_penalty, no_repeat_ngram_size]
    """
    src = _here / examples_dir
    src.mkdir(exist_ok=True)
    # glob() yields entries in filesystem order, which is unspecified; sort so the
    # examples appear in a stable order across machines and restarts.
    example_files = sorted(src.glob("*.txt"))
    # Each text is paired with fixed generation settings, in proc_submission's
    # parameter order after input_text. Read as UTF-8 explicitly rather than
    # relying on the platform's locale encoding.
    return [
        [path.read_text(encoding="utf-8"), 2, 1024, 0.7, 3.5, 3]
        for path in example_files
    ]
if __name__ == "__main__":
    # Bound at module scope: proc_submission reads `model` and `tokenizer` as globals.
    model, tokenizer = load_model_and_tokenizer("pszemraj/led-large-book-summary")

    title = "Long-form Summarization: LED & BookSum"
    description = (
        "This is a simple example of using the LED model to summarize a long-form text. This model is a fine-tuned version of [allenai/led-large-16384](https://huggingface.co/allenai/led-large-16384) on the booksum dataset. the goal was to create a model that can generalize well and is useful in summarizing lots of text in academic and daily usage."
    )

    # One component per proc_submission parameter, in the same order.
    # NOTE(review): gr.inputs.*, `default=` and launch(enable_queue=...) are the legacy
    # Gradio API — confirm against the Gradio version pinned for this Space.
    input_components = [
        gr.inputs.Textbox(lines=10, label="input text"),
        gr.inputs.Slider(minimum=1, maximum=6, label="num_beams", default=2, step=1),
        gr.inputs.Slider(
            minimum=512,
            maximum=2048,
            label="token_batch_length",
            default=1024,
            step=512,
        ),
        gr.inputs.Slider(
            minimum=0.5, maximum=1.1, label="length_penalty", default=0.7, step=0.05
        ),
        gr.inputs.Slider(
            minimum=1.0, maximum=5.0, label="repetition_penalty", default=3.5, step=0.1
        ),
        gr.inputs.Slider(
            minimum=2, maximum=4, label="no_repeat_ngram_size", default=3, step=1
        ),
    ]

    demo = gr.Interface(
        proc_submission,
        inputs=input_components,
        outputs="html",
        examples_per_page=4,
        title=title,
        description=description,
        examples=load_examples(),
        cache_examples=False,
    )
    demo.launch(enable_queue=True, share=True)
|