__all__ = ['learn', 'get_summary', 'iface']

import gradio as gr
import datasets
import pandas as pd
from fastai.text.all import *
from transformers import *
from blurr.text.data.all import *
from blurr.text.modeling.all import *
import nltk

nltk.download('punkt', quiet=True)

# Load a 1% sample of CNN/DailyMail, just enough to build the dataloaders.
raw_data = datasets.load_dataset('cnn_dailymail', '3.0.0', split='train[:1%]')
df = pd.DataFrame(raw_data)

# Pull the architecture name, config, tokenizer, and model for a distilled
# BART checkpoint already fine-tuned on CNN/DailyMail summarization.
pretrained_model_name = "sshleifer/distilbart-cnn-6-6"
hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(
    pretrained_model_name, model_cls=BartForConditionalGeneration
)

# Sensible default generation kwargs for this config/model on summarization.
text_gen_kwargs = default_text_gen_kwargs(hf_config, hf_model, task='summarization')

hf_batch_tfm = Seq2SeqBatchTokenizeTransform(
    hf_arch, hf_config, hf_tokenizer, hf_model,
    max_length=256,       # max tokens for the source article
    max_tgt_length=130,   # max tokens for the target summary
    text_gen_kwargs=text_gen_kwargs
)

blocks = (Seq2SeqTextBlock(batch_tokenize_tfm=hf_batch_tfm), noop)
dblock = DataBlock(
    blocks=blocks,
    get_x=ColReader('article'),
    get_y=ColReader('highlights'),
    splitter=RandomSplitter()
)
dls = dblock.dataloaders(df, bs=2)

# Metrics computed against generated summaries during validation.
seq2seq_metrics = {
    'rouge': {
        'compute_kwargs': {
            'rouge_types': ["rouge1", "rouge2", "rougeL"],
            'use_stemmer': True
        },
        'returns': ["rouge1", "rouge2", "rougeL"]
    },
    'bertscore': {
        'compute_kwargs': {'lang': 'en'},
        'returns': ["precision", "recall", "f1"]
    }
}

model = BaseModelWrapper(hf_model)
learn_cbs = [BaseModelCallback]
# Metrics callback is passed via cbs= at fit time, not to the Learner itself.
fit_cbs = [Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]

learn = Learner(
    dls,
    model,
    opt_func=ranger,
    loss_func=CrossEntropyLossFlat(),
    cbs=learn_cbs,
    splitter=partial(blurr_seq2seq_splitter, arch=hf_arch)
).to_fp16()

learn.create_opt()
learn.freeze()

def get_summary(text, sequences_num):
    """Summarize `text` with beam search and return the top candidate."""
    return learn.blurr_summarize(
        text,
        early_stopping=True,
        num_beams=int(sequences_num),
        num_return_sequences=int(sequences_num)
    )[0]

iface = gr.Interface(
    fn=get_summary,
    inputs=["text", gr.Number(value=5, label="sequences")],
    outputs="text"
)
iface.launch()
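
# --- Optional fine-tuning sketch (not part of the original app flow) ---
# `fit_cbs` is defined above but never used: this script serves the pretrained
# checkpoint as-is. A minimal sketch of how the metrics callback would be
# wired into training, assuming fastai's standard `fit_one_cycle` API; the
# epoch count and learning rate below are illustrative guesses, not values
# from the original source.
#
#   learn.fit_one_cycle(1, lr_max=3e-5, cbs=fit_cbs)
#   learn.recorder.plot_loss()  # inspect training/validation loss curves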