Update Predict.py
Browse files- Predict.py +29 -35
Predict.py
CHANGED
@@ -1,35 +1,29 @@
|
|
1 |
-
from Model import BaseModel
|
2 |
-
import json
|
3 |
-
import numpy as np
|
4 |
-
from PIL import Image
|
5 |
-
from torchvision import transforms as T
|
6 |
-
import torch
|
7 |
-
|
8 |
-
|
9 |
-
checkpoint = torch.load('
|
10 |
-
with open('
|
11 |
-
answer_space = json.load(f)
|
12 |
-
swap_space = {v : k for k, v in answer_space.items()}
|
13 |
-
|
14 |
-
|
15 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
16 |
-
model = BaseModel().to(device)
|
17 |
-
model.load_state_dict(checkpoint['model_state_dict'])
|
18 |
-
|
19 |
-
def generate_caption(image, question):
|
20 |
-
if isinstance(image, np.ndarray):
|
21 |
-
image = Image.fromarray(image)
|
22 |
-
elif isinstance(image, str):
|
23 |
-
image = Image.open(image).convert("RGB")
|
24 |
-
transform = T.Compose([T.Resize((224, 224)),T.ToTensor()])
|
25 |
-
image = transform(image).unsqueeze(0)
|
26 |
-
with torch.no_grad():
|
27 |
-
logits = model(image, question)
|
28 |
-
idx = torch.argmax(logits)
|
29 |
-
return swap_space[idx.item()]
|
30 |
-
|
31 |
-
if __name__ == "__main__":
|
32 |
-
image = 'Dataset/train/68857.jpg'
|
33 |
-
question = 'màu của chiếc bình là gì'
|
34 |
-
pred = generate_caption(image, question)
|
35 |
-
print(pred)
|
|
|
1 |
+
from Model import BaseModel
|
2 |
+
import json
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
from torchvision import transforms as T
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
# Inference-time setup. Everything below runs once, when the module is imported.

# Choose the device first so the checkpoint can be mapped onto it while loading.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# map_location makes a checkpoint that was saved on a GPU loadable on a
# CPU-only host; without it torch.load tries to restore the tensors onto the
# device they were saved from.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
checkpoint = torch.load('last_checkpoint.pt', map_location=device)

# answer_space is the mapping stored in answer.json; swap_space is its inverse
# (value -> key). generate_caption looks the predicted index up in swap_space,
# so the JSON values are presumably integer class indices - TODO confirm.
with open('answer.json', 'r', encoding='utf8') as f:
    answer_space = json.load(f)
swap_space = {v: k for k, v in answer_space.items()}

model = BaseModel().to(device)
model.load_state_dict(checkpoint['model_state_dict'])
# Put the network in inference mode so layers such as dropout or batch-norm
# (if BaseModel has any) behave deterministically during prediction.
model.eval()
|
18 |
+
|
19 |
+
def generate_caption(image, question):
    """Predict the answer to ``question`` about ``image``.

    Args:
        image: A NumPy array, a path to an image file (``str``) or a PIL
            image. All three are converted to RGB before preprocessing. Any
            other object is handed to the transforms unchanged.
        question: Forwarded to the model as-is.

    Returns:
        The ``swap_space`` entry whose key is the index of the largest logit.

    Raises:
        KeyError: If the predicted index is not a key of ``swap_space``.
    """
    # Normalise every supported input to an RGB PIL image so the model always
    # receives the same number of channels (previously only file paths were
    # converted, so grayscale/RGBA arrays produced a different channel count).
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image).convert("RGB")
    elif isinstance(image, str):
        # The context manager closes the file handle as soon as the pixels
        # have been copied by convert().
        with Image.open(image) as opened:
            image = opened.convert("RGB")
    elif isinstance(image, Image.Image):
        image = image.convert("RGB")

    transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
    # Add a batch dimension and move the input to the device the model was
    # placed on; a CPU tensor fed to a CUDA model raises a RuntimeError.
    batch = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch, question)

    # argmax over the flattened logits; with a single-image batch this is the
    # index of the highest-scoring entry.
    idx = torch.argmax(logits)
    return swap_space[idx.item()]
|
|
|
|
|
|
|
|
|
|
|
|