__all__ = ['learn', 'get_summary', 'intf']
import gradio as gr
import datasets
import pandas as pd
from fastai.text.all import *
from transformers import *
from blurr.text.data.all import *
from blurr.text.modeling.all import *
import nltk
nltk.download('punkt', quiet=True)
# Load a 1% slice of the CNN/DailyMail summarization corpus (v3.0.0) to keep
# startup fast, then convert to a DataFrame for the fastai DataBlock API below.
raw_data = datasets.load_dataset('cnn_dailymail', '3.0.0', split='train[:1%]')
df = pd.DataFrame(raw_data)
# Distilled 6-encoder/6-decoder BART checkpoint already fine-tuned on CNN/DailyMail.
pretrained_model_name = "sshleifer/distilbart-cnn-6-6"
# Resolve the architecture name, config, tokenizer, and model from the HF hub.
hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(pretrained_model_name, model_cls=BartForConditionalGeneration)
# Default text-generation kwargs for the summarization task (beam settings, etc.).
text_gen_kwargs = default_text_gen_kwargs(hf_config, hf_model, task='summarization')
# Batch transform: tokenize source articles to at most 256 tokens and target
# summaries to at most 130 tokens — presumably sized to the highlights' length.
hf_batch_tfm = Seq2SeqBatchTokenizeTransform(
hf_arch, hf_config, hf_tokenizer, hf_model, max_length=256, max_tgt_length=130, text_gen_kwargs=text_gen_kwargs
)
# x = the 'article' column, y = the 'highlights' column; random train/valid split.
blocks = (Seq2SeqTextBlock(batch_tokenize_tfm=hf_batch_tfm), noop)
dblock = DataBlock(blocks=blocks, get_x=ColReader('article'), get_y=ColReader('highlights'), splitter=RandomSplitter())
dls = dblock.dataloaders(df, bs=2)  # tiny batch size: seq2seq models are memory-heavy
# Metrics computed over generated summaries: ROUGE (unigram/bigram/longest
# common subsequence, with stemming) and BERTScore for English. 'returns'
# names the values extracted from each metric's compute() output.
seq2seq_metrics = {
'rouge': {
'compute_kwargs': { 'rouge_types': ["rouge1", "rouge2", "rougeL"], 'use_stemmer': True },
'returns': ["rouge1", "rouge2", "rougeL"]
},
'bertscore': {
'compute_kwargs': { 'lang': 'en' },
'returns': ["precision", "recall", "f1"]
}
}
# Wrap the raw HF model so fastai's Learner can drive its forward pass.
model = BaseModelWrapper(hf_model)
learn_cbs = [BaseModelCallback]
# Metrics callback is passed at fit time (not here) so it only runs during training.
fit_cbs = [Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]
# Mixed-precision Learner; the splitter defines parameter groups so the
# pretrained body can be frozen independently of the head.
learn = Learner(dls,
model,
opt_func=ranger,
loss_func=CrossEntropyLossFlat(),
cbs=learn_cbs,
splitter=partial(blurr_seq2seq_splitter, arch=hf_arch)).to_fp16()
# Create the optimizer now, then freeze all but the final parameter group so
# inference/fine-tuning starts from the pretrained weights.
learn.create_opt()
learn.freeze()
def get_summary(text, sequences_num):
    """Generate an abstractive summary of `text` and return the top-ranked one.

    `sequences_num` (possibly a float, since it comes from a Gradio Number
    widget) sets both the beam count and the number of candidates requested.
    """
    beams = int(sequences_num)
    candidates = learn.blurr_summarize(
        text,
        early_stopping=True,
        num_beams=beams,
        num_return_sequences=beams,
    )
    return candidates[0]
# Wire the summarizer into a minimal web UI: a text box for the article plus a
# numeric control for how many beams / candidate summaries to request.
sequences_input = gr.Number(value=5, label="sequences")
iface = gr.Interface(fn=get_summary, inputs=["text", sequences_input], outputs="text")
iface.launch()