Paarth commited on
Commit
5795073
1 Parent(s): e07a25d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import numpy as np
4
+ import gradio as gr
5
+ import pytorch_lightning as pl
6
+ from torch.utils.data import Dataset, DataLoader
7
+ from datasets import load_dataset
8
+ from pytorch_lightning.callbacks import ModelCheckpoint
9
+ from pytorch_lightning.loggers import TensorBoardLogger
10
+ from datasets.dataset_dict import DatasetDict
11
+ from transformers import AdamW, T5ForConditionalGeneration, T5TokenizerFast
12
+ from tqdm.auto import tqdm
13
+ import warnings
14
+ warnings.simplefilter('ignore')
15
+
16
+ from models.summarizer import SummarizerModel
17
+ from transformers import AutoTokenizer
18
+ MODEL_NAME = 'Salesforce/codet5-base-multi-sum'
19
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
20
+ model = SummarizerModel(MODEL_NAME)
21
+
22
+ def summarize(text: str,
23
+ tokenizer = tokenizer,
24
+ trained_model = model):
25
+ """
26
+ Summarizes a given code in text format.
27
+ Args:
28
+ text: The code in string format that needs to be summarized.
29
+ tokenizer: The tokeniszer used in the trained T5 model.
30
+ trained_model: A SummarizerModel fine-tuned instance of
31
+ T5 model family.
32
+ """
33
+ text_encoding = tokenizer.encode_plus(
34
+ text,
35
+ padding = 'max_length',
36
+ max_length = 512,
37
+ add_special_tokens = True,
38
+ return_attention_mask = True,
39
+ truncation = True,
40
+ return_tensors = 'pt'
41
+ )
42
+ generated_ids = trained_model.model.generate(
43
+ input_ids = text_encoding['input_ids'],
44
+ attention_mask = text_encoding['attention_mask'],
45
+ max_length = 150,
46
+ num_beams = 2,
47
+ repetition_penalty = 2.5,
48
+ length_penalty = 1.0,
49
+ early_stopping = True
50
+ )
51
+ preds = [tokenizer.decode(gen_id, skip_special_tokens = True,
52
+ clean_up_tokenization_spaces=True)
53
+ for gen_id in generated_ids]
54
+ return "".join(preds)
55
+
56
+ outputs = gr.outputs.Textbox()
57
+ iface = gr.Interface(fn=summarize,
58
+ inputs=['text'],
59
+ outputs=outputs,
60
+ description="This is the summarization")
61
+ iface.launch()