"""Image + question inference script.

Loads a ``BaseModel`` checkpoint and the answer vocabulary (``answer.json``)
once at import time, then exposes :func:`generate_caption`, which returns the
vocabulary entry for the model's highest-scoring output.
"""
import json
import os

import numpy as np
import torch
from PIL import Image
from torchvision import transforms as T

from Model import BaseModel

device = torch.device('cpu')

checkpoint = torch.load('last_checkpoint.pt', map_location=device)

with open('answer.json', 'r', encoding='utf8') as f:
    answer_space = json.load(f)
# Invert the JSON mapping so an argmax index can be looked up directly.
# NOTE(review): this assumes answer.json maps answer strings -> integer
# indices; if it instead maps string indices -> answers, the lookup in
# generate_caption will raise KeyError. Confirm against the training code.
swap_space = {v: k for k, v in answer_space.items()}

model = BaseModel().to(device)
model.load_state_dict(checkpoint['model_state_dict'])
# Put the model in inference mode. Without this, any dropout / batch-norm
# layers in BaseModel keep training-time behaviour and predictions become
# noisy and non-deterministic.
model.eval()

# The preprocessing pipeline is stateless, so build it once instead of on
# every call.
transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])


def generate_caption(image, question):
    """Return the predicted answer for ``question`` about ``image``.

    Args:
        image: A file path (``str`` or ``os.PathLike``), a NumPy array
            accepted by ``PIL.Image.fromarray``, or a ``PIL.Image.Image``.
            Every input is converted to RGB, resized to 224x224 and turned
            into a float tensor via ``ToTensor``.
        question: Passed unchanged to ``model`` alongside the image batch.
            NOTE(review): the expected format (raw string vs. tokenized
            input) is defined by BaseModel; confirm there.

    Returns:
        The ``swap_space`` entry for the argmax of the (flattened) logits.

    Raises:
        TypeError: If ``image`` is not one of the supported types.
        KeyError: If the predicted index is missing from ``swap_space``.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif isinstance(image, (str, os.PathLike)):
        # The context manager closes the underlying file. convert() returns
        # a fully loaded copy that stays valid after the file is closed.
        with Image.open(image) as img:
            image = img.convert("RGB")
    elif not isinstance(image, Image.Image):
        raise TypeError(
            "image must be a file path, a numpy.ndarray or a PIL.Image.Image, "
            f"got {type(image).__name__}"
        )

    # Normalise every input to 3 channels, as the file-path branch already
    # does. Otherwise grayscale / RGBA inputs yield 1- or 4-channel tensors.
    # convert() returns a new image, so the caller's object is not mutated.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Add a leading batch dimension of 1 and place the tensor on the model's
    # device.
    batch = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch, question)
    # argmax without dim flattens the logits; this is correct for the
    # single-item batch built above.
    idx = torch.argmax(logits)
    return swap_space[idx.item()]