# Custom BART Model for Text Summarization
This project fine-tunes a BART model for text summarization. The model is trained on custom data, and the resulting checkpoint is saved locally and uploaded to Hugging Face for further use.
## Table of Contents
- [Overview](#overview)
- [Installation](#installation)
- [Usage](#usage)
- [Training the Model](#training-the-model)
- [Saving and Uploading the Model](#saving-and-uploading-the-model)
- [Generating Summaries](#generating-summaries)
- [Contributing](#contributing)
- [License](#license)
## Overview
This project fine-tunes a BART model (`facebook/bart-base`) on custom summarization tasks. After training, the model can generate summaries for input text, for applications such as news article summarization and report generation.
## Installation
To get started, ensure you have Python installed (preferably Python 3.8 or above). Install the required dependencies using the following command:
```bash
pip install transformers torch huggingface_hub
```
## Usage
### Loading the Model and Tokenizer
The snippet below loads the fine-tuned model and tokenizer directly from the Hugging Face Hub (`rohansb10/summary`) and generates a summary from user input. If you have saved your own trained model and tokenizer locally (e.g. in `./custom_bart_model`), pass that directory to `from_pretrained` instead.
```python
from transformers import BartTokenizer, BartForConditionalGeneration
import torch

# Load the fine-tuned model and tokenizer from the Hugging Face Hub
model_name = "rohansb10/summary"
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)

# Move the model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

def generate_summary(input_text):
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512,
    ).to(device)
    with torch.no_grad():
        summary_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],  # mask out padding tokens
            max_length=128,
            num_beams=4,
            early_stopping=True,
        )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

user_input = input("Enter your text: ")
output = generate_summary(user_input)
print("\nModel Output:")
print(output)
```
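The generation settings above are sensible defaults rather than requirements: `num_beams` trades decoding quality for speed, and `max_length` caps the summary length. For longer summaries you can also set `min_length` and `length_penalty`; for example, this variant of the `generate` call (illustrative values, same standard API) favors longer outputs:

```python
summary_ids = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    min_length=40,        # avoid overly terse summaries
    max_length=200,       # allow longer outputs
    length_penalty=1.5,   # values > 1.0 favor longer sequences in beam search
    num_beams=4,
    early_stopping=True,
)
```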
## Training the Model
The training process loads the pre-trained BART model and tokenizer, prepares a custom dataset, and trains the model with a PyTorch `DataLoader`. Refer to the `train_model()` and `evaluate_model()` functions in the code for the detailed implementation; a minimal sketch of the loop is shown below.
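Since `train_model()` and `evaluate_model()` are not reproduced in this README, the following is only an illustrative fine-tuning loop, not the exact training code used for this model. The `texts` and `summaries` lists are hypothetical placeholders for your document/summary pairs.

```python
from transformers import BartTokenizer, BartForConditionalGeneration
from torch.utils.data import DataLoader
import torch

# Hypothetical toy dataset: parallel lists of documents and reference summaries.
texts = ["A long article about renewable energy...", "A report on quarterly earnings..."]
summaries = ["Article on renewable energy.", "Quarterly earnings report."]

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def collate(batch):
    docs, refs = zip(*batch)
    inputs = tokenizer(list(docs), return_tensors="pt",
                       truncation=True, padding=True, max_length=512)
    labels = tokenizer(list(refs), return_tensors="pt",
                       truncation=True, padding=True, max_length=128)["input_ids"]
    labels[labels == tokenizer.pad_token_id] = -100  # ignore padding in the loss
    inputs["labels"] = labels
    return inputs

loader = DataLoader(list(zip(texts, summaries)), batch_size=2,
                    shuffle=True, collate_fn=collate)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

model.train()
for epoch in range(3):
    for batch in loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        loss = model(**batch).loss  # BART computes cross-entropy from `labels`
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    print(f"epoch {epoch}: last-batch loss {loss.item():.4f}")
```

## Saving and Uploading the Model
After training, the model and tokenizer can be written to a local directory and pushed to the Hugging Face Hub with the standard `save_pretrained` / `push_to_hub` APIs. The sketch below assumes you have authenticated beforehand (e.g. with `huggingface-cli login`); the repository id matches the model loaded in the Usage section.

```python
# Save the fine-tuned model and tokenizer locally
model.save_pretrained("./custom_bart_model")
tokenizer.save_pretrained("./custom_bart_model")

# Push both to the Hugging Face Hub (requires prior authentication)
model.push_to_hub("rohansb10/summary")
tokenizer.push_to_hub("rohansb10/summary")
```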
Feel free to modify any section to better fit your project’s needs!