File size: 904 Bytes
db55aba
 
 
 
 
 
 
10439c7
 
db55aba
 
 
 
 
10439c7
db55aba
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from Model import BaseModel
import json
import numpy as np
from PIL import Image
from torchvision import transforms as T
import torch 

# All inference in this module runs on the CPU; the checkpoint tensors are
# remapped onto that device while loading.
device = torch.device('cpu')
# NOTE(review): torch.load unpickles the file, so 'last_checkpoint.pt' must
# come from a trusted source.
checkpoint = torch.load('last_checkpoint.pt', map_location=device)

with open('answer.json', 'r', encoding='utf8') as f:
    answer_space = json.load(f)
# Inverse of answer_space (value -> key). generate_caption looks up the
# predicted index here, so the JSON values are presumably integer class ids
# and the keys the answer strings -- TODO confirm against answer.json.
swap_space = {v: k for k, v in answer_space.items()}

model = BaseModel().to(device)
model.load_state_dict(checkpoint['model_state_dict'])
# Put the network in evaluation mode: without this, layers such as dropout
# and batch normalisation keep their training-time behaviour during
# prediction.
model.eval()

# Built once at import time instead of on every call: resize to 224x224 and
# convert to a tensor.
_PREPROCESS = T.Compose([T.Resize((224, 224)), T.ToTensor()])


def generate_caption(image, question):
    """Return the model's predicted answer for ``question`` about ``image``.

    Args:
        image: One of
            * a NumPy array, converted with ``Image.fromarray``;
            * a ``str`` path, opened with ``Image.open``;
            * anything else is handed to the preprocessing transforms
              unchanged (e.g. an already-loaded PIL image).
        question: Forwarded unchanged as the second argument to ``model``.

    Returns:
        The ``swap_space`` entry for the arg-max index of the model output.

    Raises:
        KeyError: If the predicted index has no entry in ``swap_space``.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif isinstance(image, str):
        # Image.open is lazy and keeps the file open; the context manager
        # closes it as soon as convert() has produced an in-memory copy.
        with Image.open(image) as opened:
            image = opened.convert("RGB")

    # The path branch always yields RGB, so give array / PIL inputs the same
    # channel layout (grayscale or RGBA would otherwise reach the model with
    # 1 or 4 channels). Presumably the model expects 3 channels -- TODO
    # confirm against BaseModel.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    # Add a leading batch dimension and place the tensor on the same device
    # as the model.
    batch = _PREPROCESS(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch, question)
        # argmax without ``dim`` works on the flattened output, which for a
        # single-item batch is the index of the best-scoring answer.
        idx = torch.argmax(logits)
        return swap_space[idx.item()]