ViVQA / Predict.py
windy2612's picture
Update Predict.py
db55aba verified
raw
history blame
924 Bytes
from Model import BaseModel
import json
import numpy as np
from PIL import Image
from torchvision import transforms as T
import torch
# Pick the inference device first so the checkpoint can be mapped onto it;
# without map_location a checkpoint saved from a GPU run fails to load on a
# CPU-only host.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
# checkpoint files.
checkpoint = torch.load('last_checkpoint.pt', map_location=device)
# answer_space presumably maps answer text -> integer class index (JSON object
# keys are always strings); swap_space inverts it so a predicted index can be
# turned back into answer text.
with open('answer.json', 'r', encoding = 'utf8') as f:
    answer_space = json.load(f)
swap_space = {v : k for k, v in answer_space.items()}
model = BaseModel().to(device)
model.load_state_dict(checkpoint['model_state_dict'])
# load_state_dict does not change the module's mode, and a freshly built module
# is in training mode; switch to eval mode so layers such as dropout /
# batch-norm (if BaseModel has any) behave deterministically at inference.
model.eval()
def generate_caption(image, question):
    """Predict an answer to ``question`` about ``image`` with the loaded model.

    Despite its name, this returns a single answer string: the entry of
    ``swap_space`` for the highest-scoring logit, not a free-form caption.

    Args:
        image: a ``numpy.ndarray`` (converted via ``PIL.Image.fromarray``),
            a file-path string, or an image object accepted by the torchvision
            transform (presumably a ``PIL.Image.Image``).
        question: passed unchanged to ``model``; presumably the question text
            -- TODO confirm against ``BaseModel.forward``.

    Returns:
        The answer text looked up in ``swap_space`` for the argmax index.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif isinstance(image, str):
        # Context manager closes the file handle once pixels are decoded;
        # convert() returns an independent in-memory image.
        with Image.open(image) as opened:
            image = opened.convert("RGB")
    # Normalise every PIL input to 3-channel RGB, as the path branch already
    # did, so grayscale / RGBA arrays and images yield the same tensor shape.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
    # Add a batch dimension and place the input on the model's device;
    # otherwise inference fails when the model was moved to CUDA.
    pixel_batch = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(pixel_batch, question)
    # argmax over the flattened logits is fine here because the batch size is 1.
    idx = torch.argmax(logits)
    return swap_space[idx.item()]