""" | |
* Tag2Text | |
* Written by Xinyu Huang | |
""" | |
import argparse

import torch
import torchvision.transforms as transforms
from PIL import Image

from models.tag2text import tag2text_caption

parser = argparse.ArgumentParser(
    description="Tag2Text inference for tagging and captioning"
)
parser.add_argument(
    "--image",
    metavar="DIR",
    help="path to input image",
    default="images/1641173_2291260800.jpg",
)
parser.add_argument(
    "--pretrained",
    metavar="DIR",
    help="path to pretrained model",
    default="pretrained/tag2text_swin_14m.pth",
)
parser.add_argument(
    "--image-size",
    default=384,
    type=int,
    metavar="N",
    help="input image size (default: 384)",
)
parser.add_argument(
    "--thre", default=0.68, type=float, metavar="N",
    help="confidence threshold for tag prediction"
)
parser.add_argument(
    "--specified-tags", default="None",
    help="comma-separated user-specified tags to condition the caption on"
)
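
# Example invocation (a sketch, not verified against this repository's
# layout; the script name "inference.py" and the tag string are assumptions):
#   python inference.py --image images/1641173_2291260800.jpg \
#       --pretrained pretrained/tag2text_swin_14m.pth \
#       --thre 0.68 --specified-tags "dog,beach"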


def inference(image, model, input_tag="None"):
    """Tag and caption an image.

    Returns (model-identified tags, user-specified tags or None, caption).
    """
    # First pass: the model predicts its own tags and captions from them.
    with torch.no_grad():
        caption, tag_predict = model.generate(
            image, tag_input=None, max_length=50, return_tag_predict=True
        )

    if input_tag in ("", "none", "None"):
        return tag_predict[0], None, caption[0]

    # Second pass: the user supplied tags, so caption again conditioned on
    # them. The model expects tags joined by " | " rather than commas.
    input_tag_list = [input_tag.replace(",", " | ")]
    with torch.no_grad():
        caption, input_tag = model.generate(
            image, tag_input=input_tag_list, max_length=50, return_tag_predict=True
        )
    return tag_predict[0], input_tag[0], caption[0]
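
# Programmatic usage sketch (assumes `image` is a preprocessed 1xCxHxW tensor
# and `model` a loaded tag2text_caption model, as in the __main__ block below;
# the tag string is an illustrative assumption):
#   tags, user_tags, caption = inference(image, model, input_tag="dog,beach")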

if __name__ == "__main__":
    args = parser.parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Standard ImageNet normalization statistics.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    transform = transforms.Compose(
        [
            transforms.Resize((args.image_size, args.image_size)),
            transforms.ToTensor(),
            normalize,
        ]
    )
    # Delete some tags that may disturb captioning:
    # 127: "quarter"; 2961: "back"; 3351: "two"; 3265: "three"; 3338: "four";
    # 3355: "five"; 3359: "one"
    delete_tag_index = [127, 2961, 3351, 3265, 3338, 3355, 3359]
    # Load the Tag2Text captioning model.
    model = tag2text_caption(
        pretrained=args.pretrained,
        image_size=args.image_size,
        vit="swin_b",
        delete_tag_index=delete_tag_index,
    )
    model.threshold = args.thre  # confidence threshold for tagging
    model.eval()
    model = model.to(device)
    # Convert to RGB so grayscale or RGBA inputs do not break normalization;
    # the transform already resizes to the model's input size.
    raw_image = Image.open(args.image).convert("RGB")
    image = transform(raw_image).unsqueeze(0).to(device)
    tags, user_tags, caption = inference(image, model, args.specified_tags)
    print("Model Identified Tags: ", tags)
    print("User Specified Tags: ", user_tags)
    print("Image Caption: ", caption)