reimagine-it / app.py
Alberto Carmona
Replace a deprecated value
2f08c6b
import torch
import torch.nn as nn
import numpy as np
import json
import captioning.utils.opts as opts
import captioning.models as models
import captioning.utils.misc as utils
import pytorch_lightning as pl
import gradio as gr
from diffusers import LDMTextToImagePipeline
import random
import os
# Checkpoint class
class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
    """Checkpoint callback that also saves the model on a keyboard interrupt."""

    def on_keyboard_interrupt(self, trainer, pl_module):
        """Save the model to ``<dirpath>/<prefix>interrupt.ckpt``.

        ``trainer`` and ``pl_module`` are accepted for the callback
        signature but are not used here.
        """
        interrupt_name = self.prefix + 'interrupt.ckpt'
        self._save_model(os.path.join(self.dirpath, interrupt_name))
# Device used for every model and tensor created later in this script.
device = 'cpu'
# Reward name; it is interpolated into the config file path below.
reward = 'clips_grammar'
cfg = f'./configs/phase2/clipRN50_{reward}.yml'
print("Loading cfg from", cfg)
# NOTE(review): parse=False presumably skips command-line parsing and reads
# the options from the cfg file only — confirm in captioning.utils.opts.
opt = opts.parse_opt(parse=False, cfg=cfg)
# Fetch assets from Google Drive at startup: a whole folder into save/ and a
# single file into data/. `url` is reused, so it ends up holding the second link.
import gdown
url = "https://drive.google.com/drive/folders/1nSX9aS7pPK4-OTHYtsUD_uEkwIQVIV7W"
gdown.download_folder(url, quiet=True, use_cookies=False, output="save/")
url = "https://drive.google.com/uc?id=1HNRE1MYO9wxmtMHLC8zURraoNFu157Dp"
gdown.download(url, quiet=True, use_cookies=False, output="data/")
# Load the vocabulary metadata. The context manager closes the file handle
# as soon as parsing is done instead of leaving it open until garbage
# collection.
with open('./data/cocotalk.json') as vocab_file:
    dict_json = json.load(vocab_file)
print(dict_json.keys())

ix_to_word = dict_json['ix_to_word']
vocab_size = len(ix_to_word)
print('vocab size:', vocab_size)

seq_length = 1

# Populate the option fields before building the model.
opt.vocab_size = vocab_size
opt.seq_length = seq_length
opt.batch_size = 1

# opt.vocab is attached only for the models.setup() call and removed
# immediately afterwards.
opt.vocab = ix_to_word
model = models.setup(opt)
del opt.vocab
ckpt_path = opt.checkpoint_path + '-last.ckpt'
print("Loading checkpoint from", ckpt_path)
# NOTE(review): torch.load unpickles the file, and this checkpoint is
# presumably one of the files fetched from Google Drive at startup. Only
# load checkpoints from a source you trust.
raw_state_dict = torch.load(ckpt_path, map_location=device)

# When strict is True, a checkpoint missing '_vocab' or '_opt' is rejected
# and load_state_dict() is asked to match keys strictly.
strict = True

state_dict = raw_state_dict['state_dict']

# '_vocab' and '_opt' are serialized extras stored next to the weights.
# They are moved out of the state dict before the weights are loaded.
if '_vocab' in state_dict:
    model.vocab = utils.deserialize(state_dict['_vocab'])
    del state_dict['_vocab']
elif strict:
    raise KeyError("checkpoint state_dict has no '_vocab' entry")

if '_opt' in state_dict:
    saved_model_opt = utils.deserialize(state_dict['_opt'])
    del state_dict['_opt']
    # Make sure the saved opt is compatible with the current opt.
    need_be_same = ["caption_model",
                    "rnn_type", "rnn_size", "num_layers"]
    for checkme in need_be_same:
        saved_value = getattr(saved_model_opt, checkme)
        current_value = getattr(opt, checkme)
        # 'updown' and 'topdown' are accepted as interchangeable values.
        if saved_value in ['updown', 'topdown'] and \
                current_value in ['updown', 'topdown']:
            continue
        assert saved_value == current_value, \
            "Command line argument and saved model disagree on '%s' " % checkme
elif strict:
    raise KeyError("checkpoint state_dict has no '_opt' entry")

res = model.load_state_dict(state_dict, strict)
print(res)

# Move to the target device and switch to evaluation mode for inference.
model = model.to(device)
model.eval()
import clip
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor
from PIL import Image
from timm.models.vision_transformer import resize_pos_embed
# Load the CLIP "RN50" model (non-JIT) on the configured device.
clip_model, clip_transform = clip.load("RN50", jit=False, device=device)
# Image pipeline applied in generate_text_from_image: bicubic resize to
# 448x448, centre crop to the same size, then conversion to a tensor.
preprocess = Compose([
Resize((448, 448), interpolation=Image.Resampling.BICUBIC),
CenterCrop((448, 448)),
ToTensor()
])
# Per-channel mean and std, reshaped to (3, 1, 1) so they broadcast over an
# image tensor during normalisation.
image_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to(device).reshape(3, 1, 1)
image_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to(device).reshape(3, 1, 1)
num_patches = 196 #600 * 1000 // 32 // 32
# New positional-embedding parameter of shape (1, num_patches + 1, dim),
# created from zeros; dim is taken from CLIP's existing attention-pool embedding.
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, clip_model.visual.attnpool.positional_embedding.shape[-1], device=device),)
# NOTE(review): this assigns the resized embedding to an attribute named
# `weight` on the Parameter; it does not appear to write into the Parameter's
# own values, which were created with torch.zeros above. The captioning
# weights may have been trained on features computed this way — confirm
# before changing it.
pos_embed.weight = resize_pos_embed(clip_model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed)
# Install the new parameter as CLIP's attention-pool positional embedding.
clip_model.visual.attnpool.positional_embedding = pos_embed
# End below
print('Loading the model: CompVis/ldm-text2im-large-256')
# Text-to-image pipeline called by generate_image_from_text.
ldm_pipeline = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
def generate_image_from_text(prompt, steps=100, seed=42, guidance_scale=6.0):
    """Generate one image for ``prompt`` with the module-level ``ldm_pipeline``.

    Args:
        prompt: Text passed to the pipeline as a single-element batch.
        steps: Forwarded as ``num_inference_steps``.
        seed: Seed given to ``torch.manual_seed``; the returned generator
            is handed to the pipeline.
        guidance_scale: Forwarded unchanged to the pipeline.

    Returns:
        The first entry of the pipeline output's ``"sample"`` field.
    """
    print('RUN: generate_image_from_text')
    torch.cuda.empty_cache()
    rng = torch.manual_seed(seed)
    output = ldm_pipeline(
        [prompt],
        generator=rng,
        num_inference_steps=steps,
        eta=0.3,
        guidance_scale=guidance_scale,
    )
    return output["sample"][0]
def generate_text_from_image(img):
    """Produce a caption for ``img`` using CLIP features and the caption model.

    Args:
        img: Image accepted by the module-level ``preprocess`` pipeline
            (the Gradio input below supplies a PIL image).

    Returns:
        The first sentence decoded from the sampled sequence.
    """
    print('RUN: generate_text_from_image')

    # Inference configuration: every option from `opt`, with one sample
    # per image.
    sample_opts = dict(vars(opt))
    sample_opts['sample_n'] = 1

    with torch.no_grad():
        # Batch of one image, normalised in place with the per-channel stats.
        batch = torch.tensor(np.stack([preprocess(img)])).to(device)
        batch -= image_mean
        batch /= image_std

        # The second value returned by encode_image is not used.
        grid_feats, _ = clip_model.encode_image(batch)
        grid_feats = grid_feats[0].permute(1, 2, 0)

        # The caption model receives an empty fc feature and the attention
        # features viewed as (1, 196, 2048), with no attention mask.
        fc_feats = torch.zeros((1, 0)).to(device)
        att_feats = grid_feats.view(1, 196, 2048).float().to(device)

        sampled, _ = model(
            fc_feats, att_feats, None, opt=sample_opts, mode='sample')
        captions = utils.decode_sequence(model.vocab, sampled.data)

    return captions[0]
def generate_drawing_from_image(img, steps=100, seed=42, guidance_scale=6.0):
    """Caption ``img``, then generate an image from that caption as a kid's drawing.

    Args:
        img: Input image, forwarded to ``generate_text_from_image``.
        steps: Forwarded to ``generate_image_from_text``.
        seed: Forwarded to ``generate_image_from_text``.
        guidance_scale: Forwarded to ``generate_image_from_text``.

    Returns:
        The image returned by ``generate_image_from_text``.
    """
    print('RUN: generate_drawing_from_image')
    # Prefix the caption so the generated image is styled as a child's drawing.
    prompt = "a kid's drawing of " + generate_text_from_image(img)
    print('\tcaption: ' + prompt)
    return generate_image_from_text(
        prompt,
        steps=steps,
        seed=seed,
        guidance_scale=guidance_scale,
    )
# Default value for the seed slider, drawn once when the script starts.
random_seed = random.randint(0, 2147483647)

# Build and launch the web UI. The sliders use the top-level gr.Slider
# component (with `value=` for the initial value) so they match the
# top-level gr.Image components already used here, instead of the
# deprecated gr.inputs.Slider(default=...) form.
gr.Interface(
    generate_drawing_from_image,
    title='Reimagine the same image but drawn by a kid :)',
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(1, 100, value=50, step=1, label='Inference Steps'),
        gr.Slider(0, 2147483647, value=random_seed, step=1, label='Seed'),
        gr.Slider(1.0, 20.0, value=6.0, step=0.1, label='Guidance Scale - how much the prompt will influence the results'),
    ],
    outputs=gr.Image(shape=[256, 256], type="pil", elem_id="output_image"),
    css="#output_image{width: 256px}",
).launch()