# ForgeT5 / app.py
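"""Gradio demo for ForgeT5: takes Python code as input and returns a
generated natural-language summary from a fine-tuned CodeT5 model."""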
import torch
import gradio as gr
from transformers import AutoTokenizer

import warnings
warnings.simplefilter('ignore')

from summarizer import SummarizerModel
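# Load the tokenizer of the base checkpoint and restore the fine-tuned summarizer weights.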
MODEL_NAME = 'Salesforce/codet5-base-multi-sum'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = SummarizerModel(MODEL_NAME)
model.load_state_dict(torch.load('codet5-base-1_epoch-val_loss-0.80.pth'))
model.eval()  # inference mode: disables dropout during generation
def summarize(text: str,
              tokenizer = tokenizer,
              trained_model = model):
    """
    Summarizes a given piece of code passed as a string.

    Args:
        text: The code, as a string, that needs to be summarized.
        tokenizer: The tokenizer used by the trained T5 model.
        trained_model: A fine-tuned SummarizerModel instance built on
            a T5-family model.
    """
    text_encoding = tokenizer.encode_plus(
        text,
        padding = 'max_length',
        max_length = 512,
        add_special_tokens = True,
        return_attention_mask = True,
        truncation = True,
        return_tensors = 'pt'
    )
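    # Generate a summary with beam search; repetition_penalty discourages repeated tokens.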
    generated_ids = trained_model.model.generate(
        input_ids = text_encoding['input_ids'],
        attention_mask = text_encoding['attention_mask'],
        max_length = 150,
        num_beams = 2,
        repetition_penalty = 2.5,
        length_penalty = 1.0,
        early_stopping = True
    )
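    # Decode the generated token ids back into text and join them into a single string.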
    preds = [tokenizer.decode(gen_id, skip_special_tokens = True,
                              clean_up_tokenization_spaces = True)
             for gen_id in generated_ids]
    return "".join(preds)
outputs = gr.Textbox()
iface = gr.Interface(fn=summarize,
                     inputs=['text'],
                     outputs=outputs,
                     description="Demo for ForgeT5 | Input: Python code | Output: The generated code summary")
iface.launch(inline = False)