# HashTag / main.py
# Uploaded by vama09 — commit 46612d1 ("Upload 2 files")
import itertools
import re
from collections import Counter

import easyocr
import nltk
import numpy as np
import streamlit as st
from nltk.corpus import stopwords
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
# set up Streamlit app
# set_page_config must be the first Streamlit command; the cached loader below
# may render a spinner, so configure the page before calling it.
st.set_page_config(layout='wide', page_title='Image Hashtag Recommender')

MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"


@st.cache_resource
def _load_resources():
    """Download NLTK stop words and load the captioning model and OCR reader.

    Streamlit re-executes this whole script on every user interaction; caching
    the result keeps these expensive objects alive across reruns instead of
    re-downloading / reloading them each time.

    Returns:
        Tuple ``(model, feature_extractor, tokenizer, reader)``.
    """
    nltk.download('stopwords')
    loaded_model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
    loaded_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
    loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    loaded_reader = easyocr.Reader(['en'])
    return loaded_model, loaded_extractor, loaded_tokenizer, loaded_reader


# load the model, tokenizer and OCR reader (cached across reruns)
model, feature_extractor, tokenizer, reader = _load_resources()
# Runs of characters that cannot appear in a hashtag (anything but letters,
# digits and underscore).
_NON_WORD = re.compile(r"\W+")


def _caption_words(image):
    """Caption *image* with the ViT-GPT2 model and return its words, lowercased."""
    pixel_values = feature_extractor(images=[image], return_tensors="pt").pixel_values
    output_ids = model.generate(pixel_values)
    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return [word.lower() for word in caption.split() if not word.startswith("#")]


def _ocr_words(image):
    """Return the whitespace-separated words EasyOCR detects in *image*."""
    detections = reader.readtext(np.array(image))
    # item[1] of each readtext result is the recognised text string
    return " ".join(item[1] for item in detections).split()


def _normalize_words(words, stop_words):
    """Lowercase *words*, strip non-word characters, drop stop words and empties."""
    cleaned = []
    for word in words:
        lowered = word.lower()
        # check before stripping so stop words written with punctuation
        # (e.g. contractions) are still recognised
        if lowered in stop_words:
            continue
        stripped = _NON_WORD.sub("", lowered)
        if stripped and stripped not in stop_words:
            cleaned.append(stripped)
    return cleaned


def _rank_hashtags(words, max_words, top_k):
    """Join every 1..max_words combination of *words* into a tag; rank by count."""
    # The candidate count grows roughly as len(words) ** max_words, so a
    # text-heavy image combined with a large max_words gets expensive quickly.
    counts = Counter(
        "#" + "".join(combination)
        for size in range(1, max_words + 1)
        for combination in itertools.combinations(words, size)
    )
    # most_common keeps first-encountered order for ties, so output is stable.
    return [tag for tag, _ in counts.most_common(top_k)]


def generate_hashtags(image_file, max_words=3, top_k=10):
    """Suggest hashtags for an uploaded image.

    Candidate words come from two sources:

    - the caption produced by the ViT-GPT2 captioning model;
    - any text EasyOCR detects in the image.

    Words from both sources are then cleaned:

    - lowercased;
    - stripped of non-word characters;
    - filtered against NLTK's English stop words.

    Every combination of 1..``max_words`` of the remaining words is joined into
    a ``#tag``. Tags are ranked by how often they occur.

    Args:
        image_file: Anything ``PIL.Image.open`` accepts, e.g. a Streamlit
            ``UploadedFile``.
        max_words: Largest number of words joined into a single hashtag.
        top_k: Maximum number of hashtags to return.

    Returns:
        Up to ``top_k`` hashtag strings, most frequent first. Ties keep the
        order in which they were generated.
    """
    with Image.open(image_file) as raw:
        image = raw.convert('RGB')
    stop_words = set(stopwords.words('english'))
    words = _normalize_words(_caption_words(image) + _ocr_words(image), stop_words)
    return _rank_hashtags(words, max_words, top_k)
# display the Streamlit app
st.title("Image Hashtag Recommender")
image_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

# if the user has submitted an image, generate hashtags
if image_file is not None:
    try:
        # captioning + OCR can take several seconds; show progress meanwhile
        with st.spinner("Generating hashtags..."):
            hashtags = generate_hashtags(image_file)
    except Exception as e:  # top-level UI boundary: surface any failure to the user
        st.error(f"Error: {e}")
    else:
        if hashtags:
            # report the real count; fewer than 10 tags may be returned
            st.write(f"Top {len(hashtags)} hashtags for this image:")
            for tag in hashtags:
                st.write(tag)
        else:
            st.write("No hashtags found for this image.")