import os

import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
from torch import autocast, inference_mode
from torch.optim import AdamW
from tqdm import tqdm

# local scheduler variant that adds a deterministic reverse_step() for DDIM inversion
from scheduling_ddim import DDIMScheduler


device = "cuda"
# add your Hugging Face auth token here, or drop use_auth_token if you are already logged in
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    scheduler=DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"),
    use_auth_token="",
).to(device)

# freeze all model weights; only the unconditional text embedding is optimized later
_ = pipe.vae.requires_grad_(False)
_ = pipe.text_encoder.requires_grad_(False)
_ = pipe.unet.requires_grad_(False)
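
# How the demo works, in brief: the input image is encoded to a VAE latent,
# pushed backwards along the DDIM trajectory under the source prompt (inversion),
# and the recorded trajectory is then used as pivot targets while re-sampling
# under the target prompt with an optimized unconditional embedding.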
                                               
def preprocess(image):
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # round dimensions down to a multiple of 32
    image = image.resize((w, h), resample=Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)  # HWC -> NCHW
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0  # rescale from [0, 1] to the [-1, 1] range the VAE expects

def im2latent(pipe, im, generator):
    init_image = preprocess(im).to(pipe.device)
    init_latent_dist = pipe.vae.encode(init_image).latent_dist
    init_latents = init_latent_dist.sample(generator=generator)

    return init_latents * 0.18215  # Stable Diffusion's latent scaling factor
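
# A minimal usage sketch for the helper above (the image path comes from the
# examples list at the bottom of this file):
#   g = torch.Generator(device=device).manual_seed(42)
#   z0 = im2latent(pipe, Image.open("images/00001.jpg").convert("RGB"), g)
#   # z0 has shape (1, 4, h // 8, w // 8) for an h x w input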


def image_mod(init_image, source_prompt, prompt, scale, steps, seed):
    # fix the seed so runs are reproducible (the original hardcoded 84 and ignored
    # the seed input; we use the value from the UI)
    g = torch.Generator(device=pipe.device).manual_seed(int(seed))

    image_latents = im2latent(pipe, init_image, g)
    pipe.scheduler.set_timesteps(int(steps))
    # encode a prompt describing the input image, e.g. "a photo of a woman";
    # _encode_prompt args: (prompt, device, num_images_per_prompt,
    # do_classifier_free_guidance, negative_prompt)
    context = pipe._encode_prompt(source_prompt, pipe.device, 1, False, "")
    
    decoded_latents = image_latents.clone()
    with autocast(device), inference_mode():
        # timesteps are flipped because inversion walks the chain in the opposite,
        # noising direction (clean latent -> noise)
        timesteps = pipe.scheduler.timesteps.flip(0)
        # the latents recorded here become the pivot targets for the optimization below
        init_trajectory = torch.empty(len(timesteps), *decoded_latents.size()[1:],
                                      device=decoded_latents.device, dtype=decoded_latents.dtype)
        for i, t in enumerate(tqdm(timesteps)):
            init_trajectory[i:i+1] = decoded_latents
            noise_pred = pipe.unet(decoded_latents, t, encoder_hidden_states=context).sample
            decoded_latents = pipe.scheduler.reverse_step(noise_pred, t, decoded_latents).next_sample
            
    # flip the recorded trajectory back so it lines up with the sampling timesteps
    init_trajectory = init_trajectory.cpu().flip(0)
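
    # Why this works: the DDIM update is deterministic, so stepping t -> t+1 with
    # the same predicted noise approximately inverts the t+1 -> t sampling step;
    # re-sampling from decoded_latents should retrace this trajectory if nothing
    # else were changed.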
    
    latents = decoded_latents.clone()
    context_uncond = pipe._encode_prompt("", pipe.device, 1, False, "")
    # the unconditional ("null") text embedding is the only thing we optimize
    context_uncond.requires_grad_(True)

    # the target prompt; to reconstruct the input unchanged, reuse the source prompt,
    # e.g. prompt = "a photo of a woman"
    context_cond = pipe._encode_prompt(prompt, pipe.device, 1, False, "")

    # the default AdamW learning rate works well here
    opt = AdamW([context_uncond])

    # duplicate latents for the uncond/cond branches of classifier-free guidance
    latents = torch.cat([latents, latents])
    latents.requires_grad_(True)
    context = torch.cat((context_uncond, context_cond))
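
    # Each step below does two things: a classifier-free-guidance DDIM step,
    # noise = uncond + scale * (cond - uncond), followed by one optimizer update
    # that nudges the unconditional embedding so the guided latents stay close
    # to the recorded inversion pivot for that timestep.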
    
    with autocast(device):
        for i, t in enumerate(tqdm(pipe.scheduler.timesteps)):
            latents = pipe.scheduler.scale_model_input(latents, t)
            uncond, cond = pipe.unet(latents, t, encoder_hidden_states=context).sample.chunk(2)
            with torch.enable_grad():
                latents = pipe.scheduler.step(uncond + scale * (cond - uncond), t, latents, generator=g).prev_sample

            # optimize the uncond text embedding to pin the latents to the pivot;
            # the original signed mean (latents - pivot_value).mean() does not pull
            # latents toward the target, so we use the MSE instead
            opt.zero_grad()
            pivot_value = init_trajectory[[i]].to(pipe.device)
            ((latents - pivot_value) ** 2).mean().backward()
            opt.step()
            latents = latents.detach()
    
    images = pipe.decode_latents(latents)
    # both halves of the CFG batch received identical updates, so keep the first image
    im = pipe.numpy_to_pil(images)[0]
    return im
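
# A quick smoke test that bypasses the UI (hypothetical, assumes the example
# image exists next to this file):
#   out = image_mod(Image.open("images/00001.jpg"), "a photo of a person",
#                   "a photo of a person", scale=0.5, steps=51, seed=42)
#   out.save("edited.png")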


demo = gr.Interface(
    image_mod,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox("a photo of a person", label="source prompt"),
        gr.Textbox("a photo of a person", label="target prompt"),
        gr.Slider(0, 10, value=0.5, step=0.1, label="guidance scale"),
        gr.Slider(0, 100, value=51, step=1, label="steps"),
        gr.Number(42, label="seed"),
    ],
    outputs="image",
    flagging_options=["blurry", "incorrect", "other"],
    # each example row must provide a value for every input component
    examples=[[
        os.path.join(os.path.dirname(__file__), "images/00001.jpg"),
        "a photo of a person", "a photo of a person", 0.5, 51, 42,
    ]],
)

if __name__ == "__main__":
    demo.launch()
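
# launch() starts a local web server and prints its URL. Each request runs a full
# DDIM inversion plus guided sampling on the GPU, so calls are slow; enabling
# demo.queue() before launch() is one way to serialize concurrent requests.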