# hashtag / app.py
# Last commit d6b8b9a by SSahas: "Update app.py" (1.48 kB)
import functools
import re
import string

import gradio as gr
import nltk
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
# Fetch the NLTK corpora used when post-processing the generated captions.
for _corpus in ("stopwords", "wordnet"):
    nltk.download(_corpus)

# Pretrained BLIP captioning checkpoint: `processor` prepares the input image
# and decodes generated token ids; `model` generates the captions.
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)
model = BlipForConditionalGeneration.from_pretrained(_BLIP_CHECKPOINT)
# Caption post-processing tables, built once at import time.
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)
_NON_WORD_RE = re.compile(r"\W+")


@functools.lru_cache(maxsize=None)
def _nltk_resources():
    """Return ``(stop_words, lemmatizer)``, loading the NLTK data on first use.

    Cached so the stop-word corpus is read once per process. Requires the
    ``stopwords`` and ``wordnet`` NLTK corpora to have been downloaded.
    """
    stop_words = frozenset(nltk.corpus.stopwords.words("english"))
    return stop_words, nltk.WordNetLemmatizer()


def _captions_to_hashtags(captions):
    """Return the unique hashtags found in *captions* as one string.

    The captions are merged, stripped of ASCII punctuation, lowercased and
    split on non-word characters. Empty tokens and English stop words are
    dropped (stop words are filtered before lemmatizing), and the remaining
    tokens are WordNet-lemmatized.

    Returns:
        Hashtags joined by " ," (e.g. ``"#dog ,#beach"``) in order of first
        appearance, or ``""`` when no word survives the filtering.
    """
    stop_words, lemmatizer = _nltk_resources()
    text = " ".join(captions).translate(_PUNCTUATION_TABLE).lower()
    lemmas = (
        lemmatizer.lemmatize(token)
        for token in _NON_WORD_RE.split(text)
        if token and token not in stop_words
    )
    # dict.fromkeys de-duplicates while keeping first-appearance order; a set
    # of str iterates in a hash-seed-dependent order that changes per run.
    unique_lemmas = dict.fromkeys(lemmas)
    return " ,".join(f"#{lemma}" for lemma in unique_lemmas)


def hashtag_generator(image):
    """Generate hashtags for *image* from BLIP-generated captions.

    The image is captioned with beam search (4 beams, 4 returned sequences,
    at most 32 tokens). The captions are then reduced to unique, lemmatized,
    non-stop-word hashtags.

    Args:
        image: image data accepted by ``PIL.Image.fromarray``. Presumably the
            numpy array supplied by the Gradio ``Image`` input (TODO confirm
            if the input type is changed).

    Returns:
        Hashtags joined by " ," (e.g. ``"#dog ,#beach"``), or ``""`` if the
        captions contain no usable words.
    """
    raw_image = Image.fromarray(image).convert("RGB")
    inputs = processor(raw_image, return_tensors="pt")
    out = model.generate(
        **inputs,
        num_return_sequences=4,
        max_length=32,
        early_stopping=True,
        num_beams=4,
        no_repeat_ngram_size=2,
        length_penalty=0.8,
    )
    captions = [processor.decode(seq, skip_special_tokens=True) for seq in out]
    return _captions_to_hashtags(captions)
# Web UI: the uploaded image is passed to hashtag_generator and the resulting
# hashtag string is shown in a textbox. With live=True, Gradio re-runs the
# function whenever the input changes instead of waiting for a Submit click.
# The top-level gr.Image / gr.Textbox components replace the gr.inputs /
# gr.outputs namespaces, which were deprecated in Gradio 3 and removed in 4.
demo = gr.Interface(
    fn=hashtag_generator,
    inputs=gr.Image(),
    outputs=gr.Textbox(),
    live=True,
)
demo.launch()