Spaces:
Running
Running
import argparse | |
import glob | |
import os | |
import cv2 | |
import numpy | |
import torch | |
from PIL import Image | |
from Model import TRCaptionNet, clip_transform | |
def demo(opt):
    """Caption every ``*.jpg`` image in a directory and preview each one.

    Builds a ``TRCaptionNet`` model, loads its weights from
    ``opt.model_ckpt``, then for each ``*.jpg`` file in ``opt.input_dir``
    (in sorted path order) prints ``"<file name> : <caption>"`` and shows
    the image, resized to a height of 800 pixels, in an OpenCV window until
    a key is pressed.

    Args:
        opt: Object exposing ``model_ckpt`` (checkpoint path), ``input_dir``
            (directory searched for ``*.jpg`` files) and ``device`` (a
            string accepted by ``torch.device``).

    Returns:
        None.
    """
    preprocess = clip_transform(224)
    model = TRCaptionNet({
        "max_length": 35,
        "clip": "ViT-L/14",
        "bert": "dbmdz/bert-base-turkish-cased",
        "proj": True,
        "proj_num_head": 16
    })
    device = torch.device(opt.device)
    # The checkpoint is expected to store the weights under the "model" key.
    checkpoint = torch.load(opt.model_ckpt, map_location=device)
    model.load_state_dict(checkpoint["model"], strict=True)
    model = model.to(device)
    model.eval()

    image_paths = sorted(glob.glob(os.path.join(opt.input_dir, '*.jpg')))
    if not image_paths:
        print(f"No .jpg images found in {opt.input_dir}")
        return

    for image_path in image_paths:
        # basename() honours the platform's path separator, unlike split('/').
        img_name = os.path.basename(image_path)

        # Force three channels so grayscale/palette/RGBA inputs do not break
        # the channel reversal below; convert() loads the pixel data, so the
        # file handle can be released right away.
        with Image.open(image_path) as img_file:
            img0 = img_file.convert("RGB")

        batch = preprocess(img0).unsqueeze(0).to(device)
        # Inference only: no autograd bookkeeping is needed.
        with torch.no_grad():
            caption = model.generate(batch, min_length=11, repetition_penalty=1.6)[0]
        print(f"{img_name} :", caption)

        # PIL yields RGB while OpenCV displays BGR; reverse the channel axis
        # and copy into a contiguous buffer instead of passing OpenCV a
        # negative-stride view.
        orj_img = numpy.ascontiguousarray(numpy.array(img0)[:, :, ::-1])
        h, w, _ = orj_img.shape
        new_h = 800
        # Keep the aspect ratio, but never let the width truncate to zero
        # for very tall, narrow images (cv2.resize rejects empty sizes).
        new_w = max(1, int(new_h * (w / h)))
        orj_img = cv2.resize(orj_img, (new_w, new_h))
        cv2.imshow("image", orj_img)
        cv2.waitKey(0)

    # Close the preview window once every image has been shown.
    cv2.destroyAllWindows()
    return
if __name__ == '__main__':
    # Script entry point: every option is a string flag with a default, so
    # the flags are declared as (name, default) pairs and registered in a loop.
    cli_options = (
        ('--model-ckpt', './checkpoints/TRCaptionNet_L14_berturk.pth'),
        ('--input-dir', './images/'),
        ('--device', 'cuda:0'),
    )
    cli = argparse.ArgumentParser(description='Turkish-Image-Captioning!')
    for flag, default_value in cli_options:
        cli.add_argument(flag, type=str, default=default_value)
    # Hand the parsed namespace straight to the demo routine.
    demo(cli.parse_args())