import gc
import os

import numpy as np
import torch
from diffusers import AutoencoderKL, DiffusionPipeline, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
from matplotlib import pyplot as plt
from pathlib import Path
from PIL import Image
from torch import autocast
from torchvision import transforms as tfms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, logging

# configuration
height, width       = 512, 512   # output image size
guidance_scale      = 8          # classifier-free guidance strength
custom_loss_scale   = 200        # weight applied to the custom loss gradient
batch_size          = 1
torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"


pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4"
pipe = DiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path,
    torch_dtype=torch.float32
).to(torch_device)

# Load SD concept tokens (textual-inversion embeddings)
sdconcepts = ['<morino-hon>', '<space-style>', '<tesla-bot>', '<midjourney-style>', '<hanfu-anime-style>']

pipe.load_textual_inversion("sd-concepts-library/morino-hon-style") 
pipe.load_textual_inversion("sd-concepts-library/space-style") 
pipe.load_textual_inversion("sd-concepts-library/tesla-bot") 
pipe.load_textual_inversion("sd-concepts-library/midjourney-style") 
pipe.load_textual_inversion("sd-concepts-library/hanfu-anime-style")

# define seeds
seed_list = [1, 2, 3, 4, 5]


def custom_loss(images):
    """Total-variation style loss: mean absolute difference between neighbouring pixels."""
    # gradient loss along width (x) and height (y)
    gradient_x = torch.abs(images[:, :, :, :-1] - images[:, :, :, 1:]).mean()
    gradient_y = torch.abs(images[:, :, :-1, :] - images[:, :, 1:, :]).mean()
    error = gradient_x + gradient_y

    # equivalent variational formulation:
    # diff_x = torch.abs(images[:, :, :, :-1] - images[:, :, :, 1:])
    # diff_y = torch.abs(images[:, :, :-1, :] - images[:, :, 1:, :])
    # error = diff_x.mean() + diff_y.mean()

    return error

def latents_to_pil(latents):
    # batch of latents -> list of PIL images
    latents = (1 / 0.18215) * latents  # undo the SD latent scaling factor
    with torch.no_grad():
        image = pipe.vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1) # 0 to 1
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images
    
def generate_latents(prompts, num_inference_steps, seed_nums, loss_apply=False):
    """Run the diffusion loop for `prompts` and return the final latents."""
    generator = torch.manual_seed(seed_nums)

    # scheduler
    scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
    scheduler.set_timesteps(num_inference_steps)
    scheduler.timesteps = scheduler.timesteps.to(torch.float32)

    # text embeddings of the prompt
    text_input = pipe.tokenizer(prompts, padding="max_length", max_length=pipe.tokenizer.model_max_length, truncation=True, return_tensors="pt")
    input_ids = text_input.input_ids.to(torch_device)

    with torch.no_grad():
        text_embeddings = pipe.text_encoder(input_ids)[0]

    # unconditional (empty-prompt) embeddings for classifier-free guidance
    max_length = text_input.input_ids.shape[-1]
    uncond_input = pipe.tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )

    with torch.no_grad():
        uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(torch_device))[0]

    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])  # shape (2, 77, 768)

    # initial random latent (float32, matching the pipeline dtype)
    latents = torch.randn(
        (batch_size, pipe.unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to(torch.float32)

    latents = latents.to(torch_device)
    latents = latents * scheduler.init_noise_sigma

    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):

        # duplicate latents for classifier-free guidance (unconditional + conditional)
        latent_model_input = torch.cat([latents] * 2)
        sigma = scheduler.sigmas[i]
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        with torch.no_grad():
            noise_pred = pipe.unet(latent_model_input.to(torch.float32), t, encoder_hidden_states=text_embeddings)["sample"]

        # classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # every 5th step, nudge the latents along the gradient of the custom loss
        if loss_apply and i % 5 == 0:

            latents = latents.detach().requires_grad_()
            # latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample  # this line does not work
            latents_x0 = latents - sigma * noise_pred  # predicted clean latents

            # decode the predicted x0 with the VAE so the loss is computed in image space
            denoised_images = pipe.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5  # range (0, 1)

            loss = custom_loss(denoised_images) * custom_loss_scale
            print(f"Custom gradient loss {loss}")

            cond_grad = torch.autograd.grad(loss, latents)[0]
            latents = latents.detach() - cond_grad * sigma**2

        latents = scheduler.step(noise_pred, t, latents).prev_sample

    return latents

    
# Function to convert PIL images to NumPy arrays
def pil_to_np(image):
    return np.array(image)
    
def generate_gradio_images(prompt, num_inference_steps, loss_flag=False):
    """Generate one image per (seed, SD concept) pair and return them as PIL images."""
    latents_list = []
    for seed_no, sd in zip(seed_list, sdconcepts):
        prompts = [f'{prompt} {sd}']
        latents = generate_latents(prompts, num_inference_steps, seed_no, loss_apply=loss_flag)
        latents_list.append(latents)
    # decode all latents in a single batch
    latents_list = torch.vstack(latents_list)
    images = latents_to_pil(latents_list)
    return images
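

# A minimal usage sketch (not part of the original script): the prompt, step count,
# and the matplotlib grid below are illustrative assumptions only.
if __name__ == "__main__":
    demo_images = generate_gradio_images("a peaceful lakeside cabin", num_inference_steps=30, loss_flag=True)

    # one panel per SD concept, titled with the concept token
    fig, axes = plt.subplots(1, len(demo_images), figsize=(4 * len(demo_images), 4))
    for ax, img, concept in zip(axes, demo_images, sdconcepts):
        ax.imshow(pil_to_np(img))
        ax.set_title(concept)
        ax.axis("off")
    plt.tight_layout()
    plt.show()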