SSahas commited on
Commit
d6b8b9a
·
1 Parent(s): ade5400

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -1
app.py CHANGED
@@ -1,3 +1,50 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- gr.Interface.load("models/Salesforce/blip-image-captioning-base").launch()
 
1
  import gradio as gr
2
+ from PIL import Image
3
+ from transformers import BlipProcessor, BlipForConditionalGeneration
4
+
5
+ import re
6
+ import string
7
+ import nltk
8
+ nltk.download('stopwords')
9
+ nltk.download('wordnet')
10
+
11
+ processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
12
+ model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
13
+
14
+ def hashtag_generator(image):
15
+ raw_image = Image.fromarray(image).convert('RGB')
16
+ inputs = processor(raw_image, return_tensors="pt")
17
+ out = model.generate(
18
+ **inputs,
19
+ num_return_sequences=4,
20
+ max_length=32,
21
+ early_stopping=True,
22
+ num_beams=4,
23
+ no_repeat_ngram_size=2,
24
+ length_penalty=0.8
25
+ )
26
+ captions = ""
27
+ for i, caption in enumerate(out):
28
+ captions = captions +processor.decode(caption, skip_special_tokens=True) + " ,"
29
+
30
+ text = "".join([word.lower() for word in captions if word not in string.punctuation])
31
+ tokens = re.split('\W+', text)
32
+ text = [wn.lemmatize(word) for word in tokens if word not in stopwords]
33
+ words = set(text)
34
+ hashtags = ""
35
+ for hashtag in words:
36
+ if len(hashtag) == 0:
37
+ pass
38
+ else:
39
+ hashtags = hashtags + f" ,#{hashtag}"
40
+ return hashtags[2:]
41
+
42
+ gr.Interface(hashtag_generator, inputs= gr.inputs.Image(), outputs = gr.outputs.Textbox(), live = True).launch()
43
+
44
+
45
+
46
+
47
+
48
+
49
+
50