# txtsum / app.py
# awitecki's picture
# Fix num_beams #2
# 0ec8d43
# raw
# history blame
# 2.09 kB
# Public names exported by a star-import of this module. Every entry must be
# defined below: the Gradio interface object is bound to `iface`, so that is
# the name listed here (a name that is never defined makes the star-import
# raise AttributeError).
__all__ = ['learn', 'get_summary', 'iface']
import gradio as gr
import datasets
import pandas as pd
from fastai.text.all import *
from transformers import *
from blurr.text.data.all import *
from blurr.text.modeling.all import *
import nltk
# Fetch the NLTK 'punkt' resource at start-up, with console output suppressed.
# NOTE(review): presumably required by the ROUGE metric configured further
# down -- confirm before removing.
nltk.download('punkt', quiet=True)
# Load the 'train[:1%]' slice of CNN/DailyMail (config '3.0.0') and wrap it in
# a DataFrame; it is the data source for the DataLoaders built below.
raw_data = datasets.load_dataset('cnn_dailymail', '3.0.0', split='train[:1%]')
df = pd.DataFrame(raw_data)
# Checkpoint name handed to get_hf_objects together with the
# BartForConditionalGeneration class; the call yields the architecture
# identifier, config, tokenizer and model used throughout the rest of the file.
pretrained_model_name = "sshleifer/distilbart-cnn-6-6"
hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(pretrained_model_name, model_cls=BartForConditionalGeneration)
# Default text-generation kwargs for the 'summarization' task, derived from
# the loaded config and model; passed on to the batch transform.
text_gen_kwargs = default_text_gen_kwargs(hf_config, hf_model, task='summarization')
# Batch-level tokenization transform, configured with max_length=256 and
# max_tgt_length=130.
hf_batch_tfm = Seq2SeqBatchTokenizeTransform(
    hf_arch, hf_config, hf_tokenizer, hf_model, max_length=256, max_tgt_length=130, text_gen_kwargs=text_gen_kwargs
)
# Input side uses the seq2seq text block with the transform above; the target
# side is `noop`.
blocks = (Seq2SeqTextBlock(batch_tokenize_tfm=hf_batch_tfm), noop)
# x is read from the 'article' column and y from 'highlights'; rows are split
# by RandomSplitter with its default arguments.
dblock = DataBlock(blocks=blocks, get_x=ColReader('article'), get_y=ColReader('highlights'), splitter=RandomSplitter())
# DataLoaders with batch size 2; consumed by the Learner below.
dls = dblock.dataloaders(df, bs=2)
# Metric configuration passed to Seq2SeqMetricsCallback below. Each top-level
# key names a metric and maps to a 'compute_kwargs' dict and a 'returns' list.
# NOTE(review): presumably 'compute_kwargs' is forwarded to the metric's
# compute call and 'returns' selects which result fields are reported --
# confirm against the blurr documentation.
seq2seq_metrics = {
    'rouge': {
        'compute_kwargs': { 'rouge_types': ["rouge1", "rouge2", "rougeL"], 'use_stemmer': True },
        'returns': ["rouge1", "rouge2", "rougeL"]
    },
    'bertscore': {
        'compute_kwargs': { 'lang': 'en' },
        'returns': ["precision", "recall", "f1"]
    }
}
# Wrap the Hugging Face model; this wrapper is the model given to Learner.
model = BaseModelWrapper(hf_model)
# Callbacks handed to Learner through its `cbs` argument.
learn_cbs = [BaseModelCallback]
# NOTE(review): fit_cbs is not referenced anywhere else in the visible file
# (no fit call is made) -- confirm whether it is still needed.
fit_cbs = [Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]
# Learner used by get_summary. It is built from the DataLoaders and wrapped
# model above with the `ranger` optimizer function, CrossEntropyLossFlat as
# the loss, the BaseModelCallback list, and the blurr seq2seq splitter bound
# to the loaded architecture; the result is converted with to_fp16().
learn = Learner(dls,
                model,
                opt_func=ranger,
                loss_func=CrossEntropyLossFlat(),
                cbs=learn_cbs,
                splitter=partial(blurr_seq2seq_splitter, arch=hf_arch)).to_fp16()
# Create the optimizer explicitly, then freeze the learner.
# NOTE(review): no fit call is visible in this file; presumably the learner is
# only used for generation here -- confirm.
learn.create_opt()
learn.freeze()
def get_summary(text, sequences_num):
    """Summarize *text* with the module-level ``learn``.

    Args:
        text: The text to summarize.
        sequences_num: Requested beam count. The same value is used for both
            ``num_beams`` and ``num_return_sequences``. It is converted with
            ``int()``; ``None`` (e.g. an emptied number field) and values
            below 1 fall back to 1.

    Returns:
        The first element of the list returned by ``learn.blurr_summarize``.
    """
    # The UI number field is not guaranteed to deliver a positive integer:
    # guard against a missing value and clamp to the smallest valid beam
    # count instead of letting generation fail.
    beams = 1 if sequences_num is None else max(1, int(sequences_num))
    # Both arguments share one value so the number of returned sequences can
    # never exceed the number of beams.
    summaries = learn.blurr_summarize(
        text,
        early_stopping=True,
        num_beams=beams,
        num_return_sequences=beams,
    )
    return summaries[0]
# Gradio front end: a text input plus a numeric "sequences" field (initial
# value 5) are passed to get_summary, and its result is rendered as text.
iface = gr.Interface(
    fn=get_summary,
    inputs=[
        "text",
        gr.Number(label="sequences", value=5),
    ],
    outputs="text",
)
# Launch the interface as soon as this module is executed.
iface.launch()