# -*- coding: utf-8 -*-
"""Listed_Intern.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1QMirZa5iTv4ryooNXeVJXp8K9sHVmuDd
"""

#!pip install transformers

from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
import torch
from PIL import Image

# Load the pretrained ViT-GPT2 captioning model along with its image
# preprocessor and tokenizer.
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

# Use a GPU when one is available; CPU inference works too, just more slowly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Generation settings: cap captions at 16 tokens and decode with beam search
# over 4 beams.
max_length = 16
num_beams = 4

gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

def predict_step(images):
    # Preprocess the batch of PIL images into pixel tensors on the model's device.
    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    # Generate caption token ids with beam search, then decode them to text.
    with torch.no_grad():
        output_ids = model.generate(pixel_values, **gen_kwargs)

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]

    return preds

# predict_step([Image.open("/content/Image1.png").convert("RGB")])  # expects PIL images, not file paths

#!pip install gradio

import gradio as gr
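
# Note: gr.inputs / gr.outputs and launch(enable_queue=...) below follow the
# legacy Gradio API (pre-4.0); newer releases expose these as gr.Image,
# gr.Textbox, and Interface.queue().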

inputs = [
    gr.inputs.Image(type="pil", label="Original Image")
]

outputs = [
    gr.outputs.Textbox(label="Caption")
]

title="Image Captioning"
description="AI based Caption generator"
article = "<a href = 'https://huggingface.co/nlpconnect/vit-gpt2-image-captioning'>Model Repo hugging face model hub</a>"
examples = [
    ["Image1.png","Image2.png"]
]
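
# Thin wrapper used as the Gradio prediction function: the interface passes a
# single PIL image and the Textbox expects a plain string, while predict_step
# works on a batch and returns a list of captions. The name generate_caption is
# just an illustrative choice here.
def generate_caption(image):
    return predict_step([image])[0]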

gr.Interface(
    generate_caption,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
    theme="huggingface"
).launch(debug=True, enable_queue=True)