CLIP-Caption-Reward / predict.py
akhaliq's picture
akhaliq HF staff
add files
c80917c
raw
history blame
6.03 kB
import os
import numpy as np
import json
import torch
import torch.nn as nn
import clip
import pytorch_lightning as pl
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image
from timm.models.vision_transformer import resize_pos_embed
from cog import BasePredictor, Path, Input
import captioning.utils.opts as opts
import captioning.models as models
import captioning.utils.misc as utils
class Predictor(BasePredictor):
    """Cog predictor that generates a caption for a single image.

    Request-independent resources (vocabulary, CLIP RN50 encoder, image
    preprocessing pipeline) are loaded once in :meth:`setup`. The captioning
    model itself depends on the ``reward`` chosen per request, so it is loaded
    from its checkpoint inside :meth:`predict`.
    """

    # Vocabulary file; its "ix_to_word" entry maps token indices to words.
    VOCAB_PATH = "./data/cocotalk.json"
    # Side length (pixels) images are resized / cropped to before encoding.
    IMAGE_SIZE = 448
    # Number of attention positions fed to the captioning model.
    # NOTE(review): presumably a 14x14 grid for a 448x448 input -- confirm.
    NUM_PATCHES = 196
    # Channel dimension of the attention features passed to the model.
    ATT_FEAT_DIM = 2048
    # Per-channel normalisation constants applied after ``ToTensor``.
    IMAGE_MEAN = (0.48145466, 0.4578275, 0.40821073)
    IMAGE_STD = (0.26862954, 0.26130258, 0.27577711)

    def setup(self) -> None:
        """Load the vocabulary, the CLIP encoder and the preprocessing pipeline.

        Everything loaded here is independent of the request, so it is done
        once rather than on every call to :meth:`predict`.
        """
        # NOTE(review): presumably the pickled checkpoints reference
        # ``__main__.ModelCheckpoint``, so the name must exist on ``__main__``
        # before ``torch.load`` is called -- confirm against the checkpoints.
        import __main__

        __main__.ModelCheckpoint = pl.callbacks.ModelCheckpoint

        self.device = torch.device("cuda:0")

        # Use a context manager so the file handle is closed deterministically.
        with open(self.VOCAB_PATH, encoding="utf-8") as vocab_file:
            self.dict_json = json.load(vocab_file)
        self.ix_to_word = self.dict_json["ix_to_word"]
        self.vocab_size = len(self.ix_to_word)

        self.clip_model, self.clip_transform = clip.load(
            "RN50", jit=False, device=self.device
        )
        self._install_positional_embedding()

        self.preprocess = Compose(
            [
                Resize(
                    (self.IMAGE_SIZE, self.IMAGE_SIZE),
                    interpolation=Image.BICUBIC,
                ),
                CenterCrop((self.IMAGE_SIZE, self.IMAGE_SIZE)),
                ToTensor(),
            ]
        )
        self._is_set_up = True

    def predict(
        self,
        image: Path = Input(
            description="Input image.",
        ),
        reward: str = Input(
            choices=["mle", "cider", "clips", "cider_clips", "clips_grammar"],
            default="clips_grammar",
            description="Choose a reward criterion.",
        ),
    ) -> str:
        """Return a caption for ``image`` from the model trained with ``reward``.

        Args:
            image: Path of the image file to caption.
            reward: Name of the reward criterion; selects which config and
                checkpoint are loaded.

        Returns:
            The first decoded caption string.

        Raises:
            KeyError: If the checkpoint's state dict lacks the ``_vocab`` or
                ``_opt`` entry.
            AssertionError: If the saved options disagree with the config.
        """
        # Cog calls setup() before predict(); the guard keeps direct calls
        # (without a prior setup()) working as they did before.
        if not getattr(self, "_is_set_up", False):
            self.setup()

        model, opt = self._load_caption_model(reward)
        att_feats = self._extract_att_feats(image)

        # Inference configuration: the parsed options plus a single sample.
        sample_kwargs = {**vars(opt), "sample_n": 1}

        with torch.no_grad():
            # The model is given an empty (1, 0) fc feature tensor and no
            # attention mask; only the attention features carry the image.
            fc_feats = torch.zeros((1, 0)).to(self.device)
            seq, _ = model(
                fc_feats, att_feats, None, opt=sample_kwargs, mode="sample"
            )
            seq = seq.data
            sents = utils.decode_sequence(model.vocab, seq)

        return sents[0]

    def _install_positional_embedding(self) -> None:
        """Replace the CLIP attention-pool positional embedding.

        Must run exactly once per freshly loaded CLIP model: it reads the
        model's current embedding and then overwrites it.
        """
        attnpool = self.clip_model.visual.attnpool
        pos_embed = nn.Parameter(
            torch.zeros(
                1,
                self.NUM_PATCHES + 1,
                attnpool.positional_embedding.shape[-1],
                device=self.device,
            ),
        )
        # NOTE(review): this stores the resized embedding as a ``.weight``
        # attribute on the Parameter; the Parameter's own values stay zero.
        # Kept as-is on purpose -- changing it would change the features the
        # captioning checkpoints consume. Confirm before "fixing".
        pos_embed.weight = resize_pos_embed(
            attnpool.positional_embedding.unsqueeze(0), pos_embed
        )
        attnpool.positional_embedding = pos_embed

    def _load_caption_model(self, reward):
        """Build the captioning model for ``reward`` and load its checkpoint.

        Returns:
            A ``(model, opt)`` pair: the model in eval mode on ``self.device``
            and the parsed options it was built from.
        """
        # Only the "mle" model lives under phase1; every other reward is phase2.
        phase = "phase1" if reward == "mle" else "phase2"
        cfg = f"configs/{phase}/clipRN50_{reward}.yml"
        print("Loading cfg from", cfg)
        opt = opts.parse_opt(parse=False, cfg=cfg)

        print("vocab size:", self.vocab_size)
        opt.vocab_size = self.vocab_size
        opt.seq_length = 1
        opt.batch_size = 1
        # ``vocab`` is only attached while the model is being constructed.
        opt.vocab = self.ix_to_word
        print(opt.caption_model)
        model = models.setup(opt)
        del opt.vocab

        ckpt_path = opt.checkpoint_path + "-last.ckpt"
        print("Loading checkpoint from", ckpt_path)
        # NOTE: torch.load unpickles arbitrary objects; only load checkpoints
        # from a trusted source.
        raw_state_dict = torch.load(ckpt_path, map_location=self.device)
        state_dict = raw_state_dict["state_dict"]

        # "_vocab" and "_opt" are serialized metadata stored next to the
        # weights; they must be removed before a strict load_state_dict.
        if "_vocab" not in state_dict:
            raise KeyError("checkpoint state_dict has no '_vocab' entry")
        model.vocab = utils.deserialize(state_dict["_vocab"])
        del state_dict["_vocab"]

        if "_opt" not in state_dict:
            raise KeyError("checkpoint state_dict has no '_opt' entry")
        saved_model_opt = utils.deserialize(state_dict["_opt"])
        del state_dict["_opt"]

        # Make sure the saved opt is compatible with the current opt.
        for checkme in ("caption_model", "rnn_type", "rnn_size", "num_layers"):
            saved_value = getattr(saved_model_opt, checkme)
            current_value = getattr(opt, checkme)
            # "updown" and "topdown" are accepted as matching each other.
            if saved_value in ("updown", "topdown") and current_value in (
                "updown",
                "topdown",
            ):
                continue
            assert saved_value == current_value, (
                "Command line argument and saved model disagree on '%s' " % checkme
            )

        res = model.load_state_dict(state_dict, True)
        print(res)

        model = model.to(self.device)
        model.eval()
        return model, opt

    def _extract_att_feats(self, image):
        """Encode the image file at ``image`` into attention features.

        Returns:
            A float tensor of shape ``(1, NUM_PATCHES, ATT_FEAT_DIM)`` on
            ``self.device``.
        """
        mean = torch.tensor(self.IMAGE_MEAN, device=self.device).reshape(3, 1, 1)
        std = torch.tensor(self.IMAGE_STD, device=self.device).reshape(3, 1, 1)

        with torch.no_grad():
            # Close the image file as soon as the pixels have been converted.
            with Image.open(str(image)) as pil_image:
                pixels = self.preprocess(pil_image.convert("RGB"))
            # Add the batch dimension directly instead of going through NumPy.
            batch = pixels.unsqueeze(0).to(self.device)
            batch -= mean
            batch /= std

            # NOTE(review): encode_image is unpacked as a 2-tuple here, which
            # stock OpenAI CLIP does not return -- presumably a patched clip
            # package is installed; confirm.
            grid_feats, _ = self.clip_model.encode_image(batch)
            # Move the leading axis of the first batch element to the end so
            # the flattened view below is (positions, channels).
            grid_feats = grid_feats[0].permute(1, 2, 0)

            return (
                grid_feats.reshape(1, self.NUM_PATCHES, self.ATT_FEAT_DIM)
                .float()
                .to(self.device)
            )