Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import numpy as np | |
import json | |
import captioning.utils.opts as opts | |
import captioning.models as models | |
import captioning.utils.misc as utils | |
import pytorch_lightning as pl | |
import gradio as gr | |
from diffusers import LDMTextToImagePipeline | |
import random | |
import os | |
class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
    """Lightning ``ModelCheckpoint`` that additionally saves on Ctrl-C.

    When training is interrupted from the keyboard, a checkpoint named
    ``<prefix>interrupt.ckpt`` is written into ``self.dirpath``.

    NOTE(review): ``on_keyboard_interrupt``, ``self.prefix`` and
    ``self._save_model(filepath)`` belong to older PyTorch Lightning releases;
    confirm they still exist in the pinned version. This class does not appear
    to be used anywhere else in this file.
    """

    def on_keyboard_interrupt(self, trainer, pl_module):
        """Persist the current model to ``<dirpath>/<prefix>interrupt.ckpt``."""
        target_dir = self.dirpath
        interrupt_path = os.path.join(target_dir, self.prefix + 'interrupt.ckpt')
        self._save_model(interrupt_path)
# Inference runs entirely on CPU. The reward name picks which phase-2
# CLIP-RN50 YAML config is parsed into `opt`.
device = 'cpu'
reward = 'clips_grammar'
cfg = './configs/phase2/clipRN50_{}.yml'.format(reward)
print("Loading cfg from", cfg)
opt = opts.parse_opt(parse=False, cfg=cfg)
import gdown

# Fetch run-time assets from Google Drive without browser cookies and with
# progress output suppressed:
#   * a folder saved under save/ (presumably the trained checkpoint files),
#   * a single file saved under data/ (presumably data/cocotalk.json).
url = "https://drive.google.com/drive/folders/1nSX9aS7pPK4-OTHYtsUD_uEkwIQVIV7W"
gdown.download_folder(url, output="save/", use_cookies=False, quiet=True)

url = "https://drive.google.com/uc?id=1HNRE1MYO9wxmtMHLC8zURraoNFu157Dp"
gdown.download(url, output="data/", use_cookies=False, quiet=True)
# Load the captioning vocabulary file; its 'ix_to_word' entry maps token
# indices to words. Use a context manager so the handle is closed promptly,
# and read it as UTF-8 (the JSON standard encoding) instead of relying on the
# platform default.
with open('./data/cocotalk.json', encoding='utf-8') as vocab_file:
    dict_json = json.load(vocab_file)
print(dict_json.keys())
ix_to_word = dict_json['ix_to_word']
vocab_size = len(ix_to_word)
print('vocab size:', vocab_size)

# Fill in the model options that depend on the vocabulary / inference setup.
seq_length = 1
opt.vocab_size = vocab_size
opt.seq_length = seq_length
opt.batch_size = 1
# `opt.vocab` is set only for the `models.setup` call and removed right after
# (presumably so the word map does not travel inside `opt` afterwards -- confirm).
opt.vocab = ix_to_word
model = models.setup(opt)
del opt.vocab
# Load the trained captioning weights from the Lightning checkpoint onto `device`.
ckpt_path = opt.checkpoint_path + '-last.ckpt'
print("Loading checkpoint from", ckpt_path)
# NOTE(review): torch.load unpickles the file, and the checkpoint is fetched
# from Google Drive at start-up -- only trust it as far as that source is trusted.
raw_state_dict = torch.load(
    ckpt_path,
    map_location=device)

strict = True
state_dict = raw_state_dict['state_dict']

# The checkpoint stores the vocabulary and the training options as serialized
# extra entries; take them out so only model weights remain for load_state_dict.
if '_vocab' in state_dict:
    model.vocab = utils.deserialize(state_dict['_vocab'])
    del state_dict['_vocab']
elif strict:
    raise KeyError(f"'_vocab' entry missing from checkpoint state_dict: {ckpt_path}")

if '_opt' in state_dict:
    saved_model_opt = utils.deserialize(state_dict['_opt'])
    del state_dict['_opt']
    # Make sure the saved opt is compatible with the current opt
    need_be_same = ["caption_model",
                    "rnn_type", "rnn_size", "num_layers"]
    for checkme in need_be_same:
        # 'updown' and 'topdown' are accepted as interchangeable values.
        if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
                getattr(opt, checkme) in ['updown', 'topdown']:
            continue
        # Explicit raise (not `assert`) so the check survives `python -O`.
        if getattr(saved_model_opt, checkme) != getattr(opt, checkme):
            raise ValueError(
                "Command line argument and saved model disagree on '%s' " % checkme)
elif strict:
    raise KeyError(f"'_opt' entry missing from checkpoint state_dict: {ckpt_path}")

res = model.load_state_dict(state_dict, strict)
print(res)
model = model.to(device)
model.eval()
import clip
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor
from PIL import Image
from timm.models.vision_transformer import resize_pos_embed

# CLIP RN50 (non-JIT) is the image feature extractor for the captioner.
clip_model, clip_transform = clip.load("RN50", jit=False, device=device)

# Fixed 448x448 bicubic resize; the centre crop is a no-op at that size.
preprocess = Compose(
    [
        Resize((448, 448), interpolation=Image.Resampling.BICUBIC),
        CenterCrop((448, 448)),
        ToTensor(),
    ]
)

# Standard CLIP per-channel mean/std, shaped (3, 1, 1) to broadcast over (C, H, W).
image_mean = torch.Tensor(
    [0.48145466, 0.4578275, 0.40821073]
).to(device).view(3, 1, 1)
image_std = torch.Tensor(
    [0.26862954, 0.26130258, 0.27577711]
).to(device).view(3, 1, 1)

# Replace the attention-pool positional embedding with one sized for a
# 196-cell grid plus one extra token (presumably (448 / 32) ** 2 = 196).
num_patches = 196  # alternative once considered: 600 * 1000 // 32 // 32
pos_embed = nn.Parameter(
    torch.zeros(
        1,
        num_patches + 1,
        clip_model.visual.attnpool.positional_embedding.shape[-1],
        device=device,
    ),
)
# NOTE(review): assigning `.weight` only attaches a Python attribute to the
# Parameter; its data stays all zeros, so the resized embedding computed here
# is discarded. Left as-is on purpose: the captioner only consumes the first
# output of encode_image, and its training features may have been extracted
# with this same code. Confirm before switching to
# `nn.Parameter(resize_pos_embed(...))`.
pos_embed.weight = resize_pos_embed(
    clip_model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed
)
clip_model.visual.attnpool.positional_embedding = pos_embed
# Text-to-image latent-diffusion pipeline that renders the final drawing.
print('Loading the model: CompVis/ldm-text2im-large-256')
ldm_pipeline = LDMTextToImagePipeline.from_pretrained(
    "CompVis/ldm-text2im-large-256"
)
def generate_image_from_text(prompt, steps=100, seed=42, guidance_scale=6.0):
    """Render ``prompt`` with the latent-diffusion pipeline.

    Args:
        prompt: Text prompt passed (as a one-element batch) to ``ldm_pipeline``.
        steps: Number of denoising steps (``num_inference_steps``). Cast to
            ``int`` in case the UI slider delivers a float.
        seed: Seed for ``torch.manual_seed``. The generator it returns is
            handed to the pipeline so a given seed reproduces the same image.
        guidance_scale: Guidance strength forwarded to the pipeline.

    Returns:
        The first image of the pipeline's output batch.
    """
    print('RUN: generate_image_from_text')
    # Release cached GPU memory between requests; harmless when on CPU.
    torch.cuda.empty_cache()
    # torch.manual_seed also seeds the global RNG. That is kept deliberately in
    # case a scheduler draws noise without using the passed generator.
    generator = torch.manual_seed(int(seed))
    output = ldm_pipeline(
        [prompt],
        generator=generator,
        num_inference_steps=int(steps),
        eta=0.3,
        guidance_scale=guidance_scale,
    )
    # Newer diffusers releases return an output object exposing `.images`;
    # early releases returned a plain dict keyed by "sample".
    images = output.images if hasattr(output, "images") else output["sample"]
    return images[0]
def generate_text_from_image(img):
    """Caption ``img`` with the CLIP-feature captioning model.

    The image is preprocessed to 448x448 and normalised with the CLIP
    mean/std. It is then encoded by ``clip_model``, and the spatial feature
    map, flattened to ``(1, 196, 2048)``, is fed to ``model`` in sampling mode.

    Args:
        img: PIL image of any mode; it is converted to RGB first.

    Returns:
        The decoded caption (first sentence from ``utils.decode_sequence``).

    Raises:
        ValueError: if ``img`` is None (e.g. submitted without an image).
    """
    print('RUN: generate_text_from_image')
    if img is None:
        raise ValueError('No input image was provided.')
    # The normalisation constants are 3-channel, so RGBA / greyscale / palette
    # uploads would fail the broadcast below; force RGB first.
    img = img.convert('RGB')

    # Inference configuration: the parsed options, sampling one caption per image.
    eval_kwargs = dict(vars(opt))
    eval_kwargs['sample_n'] = 1

    with torch.no_grad():
        pixels = preprocess(img).unsqueeze(0).to(device)
        pixels = (pixels - image_mean) / image_std
        # encode_image returns a pair here; only the first element (the spatial
        # map, shaped (2048, H, W) per the reshape below) feeds the captioner.
        att_map, _ = clip_model.encode_image(pixels)
        # (C, H, W) -> (H, W, C) -> (1, H*W, C). `reshape` also copes with the
        # non-contiguous result of `permute`.
        att_feats = att_map[0].permute(1, 2, 0).reshape(1, 196, 2048).float().to(device)
        # The captioner gets an empty (1, 0) fc feature, so all image
        # information comes from att_feats.
        fc_feats = torch.zeros((1, 0)).to(device)
        att_masks = None
        seq, _ = model(
            fc_feats, att_feats, att_masks, opt=eval_kwargs, mode='sample')

    sents = utils.decode_sequence(model.vocab, seq.data)
    return sents[0]
def generate_drawing_from_image(img, steps=100, seed=42, guidance_scale=6.0):
    """Caption ``img`` and re-render that caption as a kid's drawing.

    Args:
        img: Input image handed to ``generate_text_from_image``.
        steps, seed, guidance_scale: Forwarded unchanged to
            ``generate_image_from_text``.

    Returns:
        The image generated for the prompt ``"a kid's drawing of <caption>"``.
    """
    print('RUN: generate_drawing_from_image')
    prompt = "a kid's drawing of " + generate_text_from_image(img)
    print('\tcaption: ' + prompt)
    return generate_image_from_text(
        prompt, steps=steps, seed=seed, guidance_scale=guidance_scale
    )
# Default seed shown in the UI, drawn once at start-up.
random_seed = random.randint(0, 2147483647)
gr.Interface(
    generate_drawing_from_image,
    title='Reimagine the same image but drawn by a kid :)',
    inputs=[
        gr.Image(type="pil"),
        # gr.inputs.Slider(default=...) is the deprecated pre-3.0 API that newer
        # Gradio releases removed; gr.Slider(value=...) is the current form and
        # matches the gr.Image components used here.
        gr.Slider(minimum=1, maximum=100, value=50, step=1,
                  label='Inference Steps'),
        gr.Slider(minimum=0, maximum=2147483647, value=random_seed, step=1,
                  label='Seed'),
        gr.Slider(minimum=1.0, maximum=20.0, value=6.0, step=0.1,
                  label='Guidance Scale - how much the prompt will influence the results'),
    ],
    # `shape=` only resized images passed *into* the function (and no longer
    # exists in newer Gradio); the CSS below controls the displayed width.
    outputs=gr.Image(type="pil", elem_id="output_image"),
    css="#output_image{width: 256px}",
).launch()