from Model import BaseModel
import json
import numpy as np
from PIL import Image
from torchvision import transforms as T
import torch
# --- Inference setup (runs once at import time) ---
device = torch.device('cpu')

# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoint files.
checkpoint = torch.load('last_checkpoint.pt', map_location=device)

# answer.json maps answer strings -> class indices; invert it so a predicted
# class index can be mapped back to its answer string.
with open('answer.json', 'r', encoding='utf8') as f:
    answer_space = json.load(f)
swap_space = {v: k for k, v in answer_space.items()}

model = BaseModel().to(device)
model.load_state_dict(checkpoint['model_state_dict'])
# Fix: switch the model to evaluation mode so dropout is disabled and
# batch-norm uses running statistics during inference.
model.eval()
def generate_caption(image, question):
    """Answer ``question`` about ``image`` with the module-level VQA model.

    Parameters
    ----------
    image : numpy.ndarray | str | PIL.Image.Image
        The image as an array, a file path, or an already-open PIL image.
        Arrays and paths are converted to PIL images first.
    question : object
        Passed straight through to ``model(image, question)``; presumably a
        natural-language string the model tokenizes internally — confirm
        against BaseModel.

    Returns
    -------
    The entry of ``swap_space`` (answer string) whose class index scored
    highest in the model's logits.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif isinstance(image, str):
        image = Image.open(image).convert("RGB")
    transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
    # Fix: move the batch tensor onto the model's device so this function
    # keeps working if `device` is ever changed from CPU (no-op today).
    image = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(image, question)
    idx = torch.argmax(logits)
    # json keys round-trip: swap_space maps predicted index -> answer label
    return swap_space[idx.item()]