|
from Model import BaseModel
|
|
import json
|
|
import numpy as np
|
|
from PIL import Image
|
|
from torchvision import transforms as T
|
|
import torch
|
|
|
|
|
|
# Load the saved checkpoint onto the CPU first. Without map_location, tensors
# are restored to the device they were saved from, so a checkpoint written on
# a GPU machine cannot be opened on a CPU-only host. Loading the state dict
# into a model afterwards copies the values onto the model's own device.
# NOTE(security): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
checkpoint = torch.load('Checkpoint/checkpoint.pt', map_location='cpu')

# Answer vocabulary stored as a JSON object.
# Presumably maps answer text -> class index; verify against the dataset.
with open('Dataset/answer.json', 'r', encoding='utf8') as f:
    answer_space = json.load(f)

# Inverted vocabulary (value -> key). generate_caption uses it to turn the
# predicted index back into the corresponding entry of answer_space.
swap_space = {v: k for k, v in answer_space.items()}
|
|
|
|
|
|
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the network on the selected device and restore the trained weights
# from the checkpoint loaded above.
model = BaseModel().to(device)
model.load_state_dict(checkpoint['model_state_dict'])

# This script only performs inference: switch to evaluation mode so that any
# dropout / batch-norm layers inside the model behave deterministically
# instead of using their training-time behaviour.
model.eval()
|
|
|
|
def generate_caption(image, question):
    """Predict the answer for a question about an image.

    The image is resized to 224x224, converted to a tensor with a leading
    batch dimension of 1, and passed to the module-level ``model`` together
    with ``question``. The index of the largest logit is looked up in
    ``swap_space`` and that entry is returned.

    Args:
        image: A ``numpy.ndarray``, a path to an image file, or an already
            opened PIL image.
        question: The question, forwarded to the model unchanged.

    Returns:
        The ``swap_space`` entry for the highest-scoring index.

    Raises:
        KeyError: If the predicted index is not a key of ``swap_space``.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif isinstance(image, str):
        image = Image.open(image)

    # Normalise every input to 3-channel RGB. Previously only the file-path
    # branch did this, so grayscale or RGBA arrays / PIL images produced a
    # tensor with a different number of channels.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")

    transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])

    # Add the batch dimension and move the tensor to the device the model
    # was placed on; without this a CUDA model receives a CPU tensor.
    image = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(image, question)
        idx = torch.argmax(logits)

    return swap_space[idx.item()]
|
|
|
|
if __name__ == "__main__":
    # Manual smoke run: ask one fixed question about one fixed dataset image
    # and print whatever generate_caption returns.
    sample_image = 'Dataset/train/68857.jpg'
    sample_question = 'màu của chiếc bình là gì'
    print(generate_caption(sample_image, sample_question))