windy2612 commited on
Commit
db55aba
1 Parent(s): 9868b55

Update Predict.py

Browse files
Files changed (1) hide show
  1. Predict.py +29 -35
Predict.py CHANGED
@@ -1,35 +1,29 @@
1
- from Model import BaseModel
2
- import json
3
- import numpy as np
4
- from PIL import Image
5
- from torchvision import transforms as T
6
- import torch
7
-
8
-
9
- checkpoint = torch.load('Checkpoint/checkpoint.pt')
10
- with open('Dataset/answer.json', 'r', encoding = 'utf8') as f:
11
- answer_space = json.load(f)
12
- swap_space = {v : k for k, v in answer_space.items()}
13
-
14
-
15
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
- model = BaseModel().to(device)
17
- model.load_state_dict(checkpoint['model_state_dict'])
18
-
19
- def generate_caption(image, question):
20
- if isinstance(image, np.ndarray):
21
- image = Image.fromarray(image)
22
- elif isinstance(image, str):
23
- image = Image.open(image).convert("RGB")
24
- transform = T.Compose([T.Resize((224, 224)),T.ToTensor()])
25
- image = transform(image).unsqueeze(0)
26
- with torch.no_grad():
27
- logits = model(image, question)
28
- idx = torch.argmax(logits)
29
- return swap_space[idx.item()]
30
-
31
- if __name__ == "__main__":
32
- image = 'Dataset/train/68857.jpg'
33
- question = 'màu của chiếc bình là gì'
34
- pred = generate_caption(image, question)
35
- print(pred)
 
1
+ from Model import BaseModel
2
+ import json
3
+ import numpy as np
4
+ from PIL import Image
5
+ from torchvision import transforms as T
6
+ import torch
7
+
8
+
9
+ checkpoint = torch.load('last_checkpoint.pt')
10
+ with open('answer.json', 'r', encoding = 'utf8') as f:
11
+ answer_space = json.load(f)
12
+ swap_space = {v : k for k, v in answer_space.items()}
13
+
14
+
15
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
+ model = BaseModel().to(device)
17
+ model.load_state_dict(checkpoint['model_state_dict'])
18
+
19
+ def generate_caption(image, question):
20
+ if isinstance(image, np.ndarray):
21
+ image = Image.fromarray(image)
22
+ elif isinstance(image, str):
23
+ image = Image.open(image).convert("RGB")
24
+ transform = T.Compose([T.Resize((224, 224)),T.ToTensor()])
25
+ image = transform(image).unsqueeze(0)
26
+ with torch.no_grad():
27
+ logits = model(image, question)
28
+ idx = torch.argmax(logits)
29
+ return swap_space[idx.item()]