import json

import numpy as np
import torch
from PIL import Image
from torchvision import transforms as T

from Model import BaseModel

device = torch.device('cpu')
# map_location lets a checkpoint saved on GPU load on this CPU-only setup.
checkpoint = torch.load('last_checkpoint.pt', map_location=device)

# answer.json maps each answer string to a class index; invert it so a
# predicted index can be mapped back to its answer string.
with open('answer.json', 'r', encoding='utf8') as f:
    answer_space = json.load(f)
swap_space = {v: k for k, v in answer_space.items()}
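
# For reference, answer_space is assumed to look roughly like the following
# (the concrete entries here are hypothetical):
#   {"yes": 0, "no": 1, "red": 2, ...}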

model = BaseModel().to(device)
model.load_state_dict(checkpoint['model_state_dict'])
# Switch dropout/batch-norm layers to inference behaviour.
model.eval()


def generate_caption(image, question):
    """Answer a question about an image given as a NumPy array,
    a file path, or a PIL image."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif isinstance(image, str):
        image = Image.open(image).convert("RGB")

    # Resize to 224x224 and convert to a (1, 3, 224, 224) float tensor.
    transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
    image = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(image, question)

    # Highest-scoring class index, mapped back to its answer string.
    idx = torch.argmax(logits, dim=-1)
    return swap_space[idx.item()]
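

# Minimal usage sketch; 'example.jpg' and the question are hypothetical
# placeholders, not assets that ship with this repo.
if __name__ == '__main__':
    print(generate_caption('example.jpg', 'What color is the car?'))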