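# Hashtag generator: BLIP proposes captions for an uploaded image, NLTK cleans
# and lemmatizes the caption words, and Gradio provides the web interface.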
import gradio as gr
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

import re
import string
import nltk
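
# Download the NLTK resources used for stopword removal and lemmatization.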
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('omw-1.4')
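
# Load the BLIP image-captioning model and its processor from the Hugging Face Hub.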
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

stopwords = nltk.corpus.stopwords.words('english')
wn = nltk.WordNetLemmatizer()
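

# Caption the uploaded image with BLIP, clean and lemmatize the caption words,
# drop stopwords, and return the unique remaining words formatted as hashtags.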
def hashtag_generator(image):
    # Gradio hands the image over as a NumPy array; convert it to an RGB PIL image.
    raw_image = Image.fromarray(image).convert('RGB')
    inputs = processor(raw_image, return_tensors="pt")
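
    # Beam search with multiple returned sequences produces several candidate
    # captions, widening the pool of words available for hashtags.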
    out = model.generate(
        **inputs,
        num_return_sequences=4,
        max_length=32,
        early_stopping=True,
        num_beams=4,
        no_repeat_ngram_size=2,
        length_penalty=0.8
    )
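
    # Decode each generated sequence and join the captions into one string.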
    captions = ""
    for caption in out:
        captions += processor.decode(caption, skip_special_tokens=True) + ", "
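
    # Lowercase the text, strip punctuation, split on non-word characters,
    # lemmatize each token, and drop English stopwords.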
    text = "".join([ch.lower() for ch in captions if ch not in string.punctuation])
    tokens = re.split(r'\W+', text)
    text = [wn.lemmatize(word) for word in tokens if word not in stopwords]
    words = set(text)
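
    # Turn each unique, non-empty word into a hashtag.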
    hashtags = ""
    for hashtag in words:
        if len(hashtag) == 0:
            continue
        hashtags += f", #{hashtag}"
    # Strip the leading ", " before returning.
    return hashtags[2:]
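

# Build the Gradio UI: an image input and a text output, with live=True so the
# function re-runs whenever the input image changes. launch() serves the app
# locally (by default at http://127.0.0.1:7860) and prints the URL.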
gr.Interface(hashtag_generator, inputs=gr.Image(), outputs=gr.Textbox(), live=True).launch()