Custom BART Model for Text Summarization
This project fine-tunes a BART model for text summarization. The model was trained on custom data, and the resulting model was saved locally and uploaded to the Hugging Face Hub as rohansb10/summary.
Table of Contents
- Overview
- Installation
- Usage
- Training the Model
- Saving and Uploading the Model
- Generating Summaries
- Contributing
- License
Overview
This project fine-tunes a BART model (facebook/bart-base) on a custom summarization dataset. After training, the model generates summaries of input text, which is useful for applications such as news article summarization and report generation.
Installation
To get started, ensure you have Python installed (preferably Python 3.8 or above). Install the required dependencies using the following command:
```bash
pip install transformers torch huggingface_hub
```
Usage
Loading the Model and Tokenizer
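The code below loads the fine-tuned model and tokenizer from the Hugging Face Hub repository rohansb10/summary and generates a summary of text entered by the user. To use a locally saved copy instead, pass the path of the directory containing the trained model and tokenizer (for example ./custom_bart_model) in place of the repository name.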
```python
from transformers import BartTokenizer, BartForConditionalGeneration
import torch

# Load the fine-tuned model and tokenizer from the Hub
# (or use "./custom_bart_model" to load a locally saved copy)
model_name = "rohansb10/summary"
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)

# Move the model to the GPU if one is available and switch to inference mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()


def generate_summary(input_text):
    # Truncate inputs longer than 512 tokens; a single input needs no padding
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(device)
    with torch.no_grad():
        # Pass the attention mask as well as the input IDs
        summary_ids = model.generate(**inputs, max_length=128, num_beams=4, early_stopping=True)
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)


if __name__ == "__main__":
    user_input = input("Enter your text: ")
    output = generate_summary(user_input)
    print("\nModel Output:")
    print(output)
```
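The snippet above summarizes one text at a time. For several texts, one option is to tokenize them in batches, padding each batch to its longest text and passing the attention mask so the model ignores the padding. The following is a minimal sketch under that assumption. `summarize_batch` and its parameters are a hypothetical helper and not part of the original project. The generation settings mirror the single-text example.

```python
import torch
from transformers import BartForConditionalGeneration, BartTokenizer


def summarize_batch(texts, model, tokenizer, device, batch_size=8,
                    max_input_length=512, max_summary_length=128, num_beams=4):
    """Hypothetical helper: summarize a list of strings, batch_size texts at a time."""
    model.eval()
    summaries = []
    for start in range(0, len(texts), batch_size):
        chunk = texts[start:start + batch_size]
        # Pad each batch to its longest text; the attention mask marks the padding
        inputs = tokenizer(chunk, return_tensors="pt", truncation=True,
                           padding=True, max_length=max_input_length)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            summary_ids = model.generate(**inputs, max_length=max_summary_length,
                                         num_beams=num_beams, early_stopping=True)
        summaries.extend(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))
    return summaries


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BartTokenizer.from_pretrained("rohansb10/summary")
    model = BartForConditionalGeneration.from_pretrained("rohansb10/summary").to(device)
    # Placeholder inputs: replace with your own documents
    docs = ["Replace this with the first document.", "Replace this with the second document."]
    for summary in summarize_batch(docs, model, tokenizer, device):
        print(summary)
```

A larger `batch_size` speeds up throughput on a GPU at the cost of memory. Because padding is computed per batch, short texts are not padded to 512 tokens.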
Training the Model
The training process loads the pre-trained BART model and tokenizer, prepares a custom dataset of input texts and reference summaries, and fine-tunes the model with a PyTorch training loop that iterates over a DataLoader. See the train_model() and evaluate_model() functions in the project's training code for the detailed implementation.
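The following is a minimal sketch of such a loop. It assumes parallel lists of source texts and reference summaries. `SummarizationDataset`, `train_one_epoch`, and `evaluate_loss` are hypothetical names, not the project's `train_model()`/`evaluate_model()`. The placeholder data, learning rate, batch size, and epoch count are illustrative only. The script finishes by saving to the ./custom_bart_model directory mentioned in the usage section. The Hub upload is left commented out because it requires authentication, and the repository name in it is a placeholder.

```python
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BartForConditionalGeneration, BartTokenizer


class SummarizationDataset(Dataset):
    """Hypothetical dataset of (document, reference summary) string pairs."""

    def __init__(self, texts, summaries, tokenizer, max_source_len=512, max_target_len=128):
        assert len(texts) == len(summaries)
        self.texts, self.summaries, self.tokenizer = texts, summaries, tokenizer
        self.max_source_len, self.max_target_len = max_source_len, max_target_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        source = self.tokenizer(self.texts[idx], max_length=self.max_source_len,
                                truncation=True, padding="max_length", return_tensors="pt")
        # BART shares one vocabulary for source and target, so summaries are tokenized the same way
        target = self.tokenizer(self.summaries[idx], max_length=self.max_target_len,
                                truncation=True, padding="max_length", return_tensors="pt")
        labels = target["input_ids"].squeeze(0).clone()
        # -100 is ignored by the cross-entropy loss, so padding does not count
        labels[labels == self.tokenizer.pad_token_id] = -100
        return {"input_ids": source["input_ids"].squeeze(0),
                "attention_mask": source["attention_mask"].squeeze(0),
                "labels": labels}


def train_one_epoch(model, loader, optimizer, device):
    """Run one pass over loader and return the mean training loss."""
    model.train()
    total = 0.0
    for batch in loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        loss = model(**batch).loss  # passing labels makes the model compute the loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(loader)


@torch.no_grad()
def evaluate_loss(model, loader, device):
    """Return the mean loss over loader without updating the model."""
    model.eval()
    total = 0.0
    for batch in loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        total += model(**batch).loss.item()
    return total / len(loader)


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-base").to(device)
    # Placeholder data: replace with your own documents and reference summaries
    texts = ["Replace this with a full input document.", "Replace this with another document."]
    summaries = ["First reference summary.", "Second reference summary."]
    loader = DataLoader(SummarizationDataset(texts, summaries, tokenizer), batch_size=2, shuffle=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    for epoch in range(3):
        print(f"epoch {epoch}: train loss {train_one_epoch(model, loader, optimizer, device):.4f}")
    # Save locally so the usage example can load from this directory
    model.save_pretrained("./custom_bart_model")
    tokenizer.save_pretrained("./custom_bart_model")
    # Optional upload after `huggingface-cli login` (placeholder repo id):
    # model.push_to_hub("your-username/your-model")
    # tokenizer.push_to_hub("your-username/your-model")
```

In practice, `evaluate_loss` would be run on a held-out validation DataLoader after each epoch, built the same way from separate data.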