lvelho committed on
Commit
6fcf918
1 Parent(s): 9c41328
Files changed (1)
  1. app.py +38 -36
app.py CHANGED
@@ -1,36 +1,38 @@
- import os
- import io
- from PIL import Image
-
- from transformers import pipeline
-
- get_completion = pipeline("image-to-text",model="Salesforce/blip-image-captioning-base")
-
- def summarize(input):
-     output = get_completion(input)
-     return output[0]['generated_text']
-
- image_url = "https://free-images.com/sm/9596/dog_animal_greyhound_983023.jpg"
-
- get_completion(image_url)
-
- import gradio as gr
-
- def captioner(image):
-     result = get_completion(image)
-     return result[0]['generated_text']
-
- gr.close_all()
-
- christmas_dog = "https://free-images.com/sm/9596/dog_animal_greyhound_983023.jpg"
- bird = "https://free-images.com/sm/0a00/bird_exotic_bird_green.jpg"
- cow = "https://free-images.com/sm/ee7b/cow_animal_cow_head.jpg"
-
- demo = gr.Interface(fn=captioner,
-                     inputs=[gr.Image(label="Upload image", type="pil", value=christmas_dog)],
-                     outputs=[gr.Textbox(label="Caption")],
-                     title="Image Captioning with BLIP",
-                     description="Caption any image using the BLIP model",
-                     allow_flagging="never",
-                     examples=[christmas_dog, bird, cow])
- demo.launch(share=True)
+ import torch
+ import re
+ import gradio as gr
+ from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
+
+ device='cpu'
+ encoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
+ decoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
+ model_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
+ feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
+ tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
+ model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
+
+ def predict(image, max_length=64, num_beams=4):
+     image = image.convert('RGB')
+     image = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
+     clean_text = lambda x: x.replace('<|endoftext|>','').split('\n')[0]
+     caption_ids = model.generate(image, max_length=max_length)[0]
+     caption_text = clean_text(tokenizer.decode(caption_ids))
+     return caption_text
+
+ input = gr.inputs.Image(label="Upload any Image", type='pil', optional=True)
+ output = gr.outputs.Textbox(type="auto", label="Captions")
+ examples = [f"example{i}.jpg" for i in range(1,7)]
+
+ title = "Image Captioning "
+ description = "Made by : shreyasdixit.tech"
+ interface = gr.Interface(
+     fn=predict,
+     description=description,
+     inputs=input,
+     theme="grass",
+     outputs=output,
+     examples=examples,
+     title=title,
+ )
+ interface.launch(debug=True)
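
This commit replaces the BLIP captioner built on transformers' pipeline("image-to-text") with an explicitly loaded ViT-GPT2 VisionEncoderDecoderModel (nlpconnect/vit-gpt2-image-captioning). A minimal sketch of exercising the new predict() outside the Gradio UI, assuming a local test image named example1.jpg (a hypothetical file matching the examples list):

    from PIL import Image

    # Caption one image with the predict() defined in the new app.py above.
    # "example1.jpg" is a hypothetical local file; any RGB-convertible image works.
    image = Image.open("example1.jpg")
    print(predict(image, max_length=64))

Note that predict() accepts num_beams but never forwards it to model.generate, so decoding uses the library default (greedy search); also, the gr.inputs / gr.outputs namespaces used here were deprecated in later Gradio releases in favor of gr.Image and gr.Textbox.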