Vageesh1 committed
Commit 35004f4
Parent: 4e527a6

Upload engine.py

Files changed (1): engine.py (added)

engine.py
import os
import json

import torch
import torchvision.transforms as transforms
import wget
from PIL import Image

from neuralnet.model import SeqToSeq

# Download the pretrained checkpoint once; skip the download on later runs.
url = "https://github.com/Koushik0901/Image-Captioning/releases/download/v1.0/flickr30k.pt"
if not os.path.exists("./flickr30k.pt"):
    wget.download(url, "./flickr30k.pt")


def inference(img_path):
    # Match the 299x299 input size and normalization the encoder expects.
    transform = transforms.Compose(
        [
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    # Vocabulary holds the index-to-string ("itos") mapping used to decode tokens.
    with open("./vocab.json") as f:
        vocabulary = json.load(f)

    model_params = {"embed_size": 256, "hidden_size": 512, "vocab_size": 7666, "num_layers": 3, "device": "cpu"}
    model = SeqToSeq(**model_params)
    checkpoint = torch.load("./flickr30k.pt", map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    img = transform(Image.open(img_path).convert("RGB")).unsqueeze(0)

    # Generate up to 50 tokens; slice off the boundary tokens at either end.
    with torch.no_grad():
        out_captions = model.caption_image(img, vocabulary["itos"], 50)
    return " ".join(out_captions[1:-1])


if __name__ == "__main__":
    print(inference("./test_examples/dog.png"))
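
For context, caption_image is defined in neuralnet/model.py, which this commit does not include. Below is a minimal sketch of a greedy decoder consistent with how it is called here; the self.encoder / self.decoder.embed / .lstm / .linear submodule names and the "<EOS>" token are assumptions, not code from this repository.

# Hypothetical sketch of SeqToSeq.caption_image -- the real method lives in
# neuralnet/model.py and may differ. All submodule names below are assumed.
def caption_image(self, image, itos, max_length=50):
    result = []
    with torch.no_grad():
        # Encode the image once; the feature vector is the first LSTM input.
        x = self.encoder(image).unsqueeze(0)
        states = None
        for _ in range(max_length):
            hiddens, states = self.decoder.lstm(x, states)
            logits = self.decoder.linear(hiddens.squeeze(0))
            predicted = logits.argmax(1)
            result.append(predicted.item())
            # Feed the embedding of the predicted token back into the LSTM.
            x = self.decoder.embed(predicted).unsqueeze(0)
            # JSON object keys are strings, so index itos with str().
            if itos[str(predicted.item())] == "<EOS>":
                break
    return [itos[str(idx)] for idx in result]

A decoder of this shape would return a token list that starts with "<SOS>" and ends with "<EOS>", which would explain why inference() joins out_captions[1:-1].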