JerryAnto commited on
Commit
6c29be2
·
1 Parent(s): f754afe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -21
app.py CHANGED
@@ -12,37 +12,42 @@ Original file is located at
12
  from PIL import Image
13
  from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, PreTrainedTokenizerFast
14
  import requests
 
 
 
15
 
16
- model = VisionEncoderDecoderModel.from_pretrained("sachin/vit2distilgpt2")
17
-
18
- vit_feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
19
-
20
- tokenizer = PreTrainedTokenizerFast.from_pretrained("distilgpt2")
21
 
22
- # url = 'https://d2gp644kobdlm6.cloudfront.net/wp-content/uploads/2016/06/bigstock-Shocked-and-surprised-boy-on-t-113798588-300x212.jpg'
 
23
 
24
- # with Image.open(requests.get(url, stream=True).raw) as img:
25
- # pixel_values = vit_feature_extractor(images=img, return_tensors="pt").pixel_values
26
 
27
- #encoder_outputs = model.generate(pixel_values.to('cpu'),num_beams=5)
28
 
29
- #generated_sentences = tokenizer.batch_decode(encoder_outputs, skip_special_tokens=True)
 
 
 
 
 
 
 
 
30
 
31
- #generated_sentences
32
 
33
- #naive text processing
34
- #generated_sentences[0].split('.')[0]
35
 
36
- # inference function
37
 
38
- def vit2distilgpt2(img):
39
- pixel_values = vit_feature_extractor(images=img, return_tensors="pt").pixel_values
40
- encoder_outputs = generated_ids = model.generate(pixel_values.to('cpu'),num_beams=5)
41
- generated_sentences = tokenizer.batch_decode(encoder_outputs, skip_special_tokens=True)
42
 
43
- return(generated_sentences[0].split('.')[0])
44
 
45
- #!wget https://media.glamour.com/photos/5f171c4fd35176eaedb36823/master/w_2560%2Cc_limit/bike.jpg
46
 
47
  import gradio as gr
48
 
@@ -65,7 +70,7 @@ examples = [
65
  ]
66
 
67
  gr.Interface(
68
- vit2distilgpt2,
69
  inputs,
70
  outputs,
71
  title=title,
 
12
  from PIL import Image
13
  from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, PreTrainedTokenizerFast
14
  import requests
15
+ from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
16
+ import torch
17
+ from PIL import Image
18
 
19
+ model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
20
+ feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
21
+ tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
 
 
22
 
23
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+ model.to(device)
25
 
 
 
26
 
 
27
 
28
+ max_length = 16
29
+ num_beams = 4
30
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
31
+ def predict_step(image_paths):
32
+ images = []
33
+ for image_path in image_paths:
34
+ i_image = Image.open(image_path)
35
+ if i_image.mode != "RGB":
36
+ i_image = i_image.convert(mode="RGB")
37
 
38
+ images.append(i_image)
39
 
40
+ pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
41
+ pixel_values = pixel_values.to(device)
42
 
43
+ output_ids = model.generate(pixel_values, **gen_kwargs)
44
 
45
+ preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
46
+ preds = [pred.strip() for pred in preds]
47
+ return preds
 
48
 
49
+ #predict_step(['/content/drive/MyDrive/caption generator/horses.png'])
50
 
 
51
 
52
  import gradio as gr
53
 
 
70
  ]
71
 
72
  gr.Interface(
73
+ predict_step,
74
  inputs,
75
  outputs,
76
  title=title,