cbensimon (HF staff) committed
Commit e227409 · 1 Parent(s): 1a00ee4
Files changed (1):
  1. app.py +223 -0
app.py ADDED
@@ -0,0 +1,223 @@
+ import spaces
+ import os
+ import datetime
+ import einops
+ import gradio as gr
+ from gradio_imageslider import ImageSlider
+ import numpy as np
+ import torch
+ import random
+ from PIL import Image
+ from pathlib import Path
+ from torchvision import transforms
+ import torch.nn.functional as F
+ from torchvision.models import resnet50, ResNet50_Weights
+
+ from pytorch_lightning import seed_everything
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
+ from diffusers import AutoencoderKL, DDIMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, UniPCMultistepScheduler
+
+ from pipelines.pipeline_pasd import StableDiffusionControlNetPipeline
+ from myutils.misc import load_dreambooth_lora, rand_name
+ from myutils.wavelet_color_fix import wavelet_color_fix
+ from annotator.retinaface import RetinaFaceDetection
+
+ use_pasd_light = False
+ face_detector = RetinaFaceDetection()
+
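+ # Select which PASD model definitions to import: the "light" variant or the
+ # default full UNet/ControlNet implementations.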
+ if use_pasd_light:
+     from models.pasd_light.unet_2d_condition import UNet2DConditionModel
+     from models.pasd_light.controlnet import ControlNetModel
+ else:
+     from models.pasd.unet_2d_condition import UNet2DConditionModel
+     from models.pasd.controlnet import ControlNetModel
+
+ pretrained_model_path = "checkpoints/stable-diffusion-v1-5"
+ ckpt_path = "runs/pasd/checkpoint-100000"
+ #dreambooth_lora_path = "checkpoints/personalized_models/toonyou_beta3.safetensors"
+ dreambooth_lora_path = "checkpoints/personalized_models/majicmixRealistic_v6.safetensors"
+ #dreambooth_lora_path = "checkpoints/personalized_models/Realistic_Vision_V5.1.safetensors"
+ weight_dtype = torch.float16
+ device = "cuda"
+
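+ # Load the frozen Stable Diffusion v1.5 components (scheduler, text encoder,
+ # tokenizer, VAE, feature extractor) and the PASD-finetuned UNet/ControlNet,
+ # then merge the personalized DreamBooth/LoRA weights into them.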
+ scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
+ text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
+ feature_extractor = CLIPImageProcessor.from_pretrained(f"{pretrained_model_path}/feature_extractor")
+ unet = UNet2DConditionModel.from_pretrained(ckpt_path, subfolder="unet")
+ controlnet = ControlNetModel.from_pretrained(ckpt_path, subfolder="controlnet")
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ unet.requires_grad_(False)
+ controlnet.requires_grad_(False)
+
+ unet, vae, text_encoder = load_dreambooth_lora(unet, vae, text_encoder, dreambooth_lora_path)
+
+ text_encoder.to(device, dtype=weight_dtype)
+ vae.to(device, dtype=weight_dtype)
+ unet.to(device, dtype=weight_dtype)
+ controlnet.to(device, dtype=weight_dtype)
+
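+ # Assemble the ControlNet pipeline with the safety checker disabled; tiled VAE
+ # decoding keeps memory usage bounded when decoding large upscaled images.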
+ validation_pipeline = StableDiffusionControlNetPipeline(
+     vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, feature_extractor=feature_extractor,
+     unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=None, requires_safety_checker=False,
+ )
+ #validation_pipeline.enable_vae_tiling()
+ validation_pipeline._init_tiled_vae(decoder_tile_size=224)
+
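+ # ImageNet-pretrained ResNet-50, used inside `inference` to guess the image's
+ # subject and append the predicted category to the prompt.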
+ weights = ResNet50_Weights.DEFAULT
+ preprocess = weights.transforms()
+ resnet = resnet50(weights=weights)
+ resnet.eval()
+
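+ # Resize an image to the target height while preserving its aspect ratio.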
+ def resize_image(image_path, target_height):
+     # Open the image file
+     with Image.open(image_path) as img:
+         # Calculate the ratio to resize the image to the target height
+         ratio = target_height / float(img.size[1])
+         # Calculate the new width based on the aspect ratio
+         new_width = int(float(img.size[0]) * ratio)
+         # Resize the image
+         resized_img = img.resize((new_width, target_height), Image.LANCZOS)
+         # Save the resized image
+         #resized_img.save(output_path)
+         return resized_img
+
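+ # Main inference entry point, run on GPU via the `spaces` decorator: resize the
+ # input, optionally enrich the prompt with a ResNet-50 class label, upsample,
+ # run the PASD ControlNet pipeline, and color-fix the result.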
+ @spaces.GPU(enable_queue=True)
+ def inference(input_image, prompt, a_prompt, n_prompt, denoise_steps, upscale, alpha, cfg, seed):
+
+     # temporary fix for seed == -1
+     if seed == -1:
+         seed = 0
+
+     input_image = resize_image(input_image, 512)
+     process_size = 768
+     resize_preproc = transforms.Compose([
+         transforms.Resize(process_size, interpolation=transforms.InterpolationMode.BILINEAR),
+     ])
+
+     # Get the current timestamp
+     timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+
+     with torch.no_grad():
+         seed_everything(seed)
+         generator = torch.Generator(device=device)
+
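+         # Classify the input with ResNet-50 and, if the top-1 score is
+         # confident enough, append the predicted category to the prompt.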
+         input_image = input_image.convert('RGB')
+         batch = preprocess(input_image).unsqueeze(0)
+         prediction = resnet(batch).squeeze(0).softmax(0)
+         class_id = prediction.argmax().item()
+         score = prediction[class_id].item()
+         category_name = weights.meta["categories"][class_id]
+         if score >= 0.1:
+             prompt += f"{category_name}" if prompt=='' else f", {category_name}"
+
+         prompt = a_prompt if prompt=='' else f"{prompt}, {a_prompt}"
+
+         ori_width, ori_height = input_image.size
+         resize_flag = False
+
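+         # Upsample the input by the requested scale, then snap both dimensions
+         # down to multiples of 8 as required by the latent VAE.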
+         rscale = upscale
+         input_image = input_image.resize((input_image.size[0]*rscale, input_image.size[1]*rscale))
+
+         #if min(validation_image.size) < process_size:
+         #    validation_image = resize_preproc(validation_image)
+
+         input_image = input_image.resize((input_image.size[0]//8*8, input_image.size[1]//8*8))
+         width, height = input_image.size
+         resize_flag = True
+
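+         # Run the PASD ControlNet pipeline; the wavelet color fix transfers the
+         # input's low-frequency color statistics onto the generated image. Any
+         # failure falls back to a blank 512x512 placeholder.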
+         try:
+             image = validation_pipeline(
+                 None, prompt, input_image, num_inference_steps=denoise_steps, generator=generator, height=height, width=width, guidance_scale=cfg,
+                 negative_prompt=n_prompt, conditioning_scale=alpha, eta=0.0,
+             ).images[0]
+
+             if True: #alpha<1.0:
+                 image = wavelet_color_fix(image, input_image)
+
+             if resize_flag:
+                 image = image.resize((ori_width*rscale, ori_height*rscale))
+         except Exception as e:
+             print(e)
+             image = Image.new(mode="RGB", size=(512, 512))
+
+     # Convert and save the result image as JPEG
+     image.save(f'result_{timestamp}.jpg', 'JPEG')
+
+     # Convert and save the input image as JPEG
+     input_image.save(f'input_{timestamp}.jpg', 'JPEG')
+
+     return (f"input_{timestamp}.jpg", f"result_{timestamp}.jpg"), f"result_{timestamp}.jpg"
+
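+ # Page metadata and CSS for the Blocks layout (title/description/article are
+ # defined for reference but not used by the UI below).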
+ title = "Pixel-Aware Stable Diffusion for Real-ISR"
+ description = "Gradio Demo for PASD Real-ISR. To use it, simply upload your image, or click one of the examples to load them."
+ article = "<a href='https://github.com/yangxy/PASD' target='_blank'>Github Repo Pytorch</a>"
+ #examples=[['samples/27d38eeb2dbbe7c9.png'],['samples/629e4da70703193b.png']]
+
+ css = """
+ #col-container{
+     margin: 0 auto;
+     max-width: 720px;
+ }
+ #project-links{
+     margin: 0 0 12px !important;
+     column-gap: 8px;
+     display: flex;
+     justify-content: center;
+     flex-wrap: nowrap;
+     flex-direction: row;
+     align-items: center;
+ }
+ """
+
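+ # Build the Gradio Blocks UI: header, input image and prompt, advanced
+ # settings, and a before/after slider plus a downloadable result file.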
+ with gr.Blocks(css=css) as demo:
+     with gr.Column(elem_id="col-container"):
+         gr.HTML(f"""
+         <h2 style="text-align: center;">
+             PASD Magnify
+         </h2>
+         <p style="text-align: center;">
+             Pixel-Aware Stable Diffusion for Realistic Image Super-resolution and Personalized Stylization
+         </p>
+         <p id="project-links" align="center">
+             <a href='https://github.com/yangxy/PASD'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://huggingface.co/papers/2308.14469'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
+         </p>
+         <p style="margin:12px auto;display: flex;justify-content: center;">
+             <a href="https://huggingface.co/spaces/fffiloni/PASD?duplicate=true"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg.svg" alt="Duplicate this Space"></a>
+         </p>
+
+         """)
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(type="filepath", sources=["upload"], value="samples/frog.png")
+                 prompt_in = gr.Textbox(label="Prompt", value="Frog")
+                 with gr.Accordion(label="Advanced settings", open=False):
+                     added_prompt = gr.Textbox(label="Added Prompt", value='clean, high-resolution, 8k, best quality, masterpiece')
+                     neg_prompt = gr.Textbox(label="Negative Prompt", value='dotted, noise, blur, lowres, oversmooth, longbody, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+                     denoise_steps = gr.Slider(label="Denoise Steps", minimum=10, maximum=50, value=20, step=1)
+                     upsample_scale = gr.Slider(label="Upsample Scale", minimum=1, maximum=4, value=2, step=1)
+                     condition_scale = gr.Slider(label="Conditioning Scale", minimum=0.5, maximum=1.5, value=1.1, step=0.1)
+                     classifier_free_guidance = gr.Slider(label="Classifier-free Guidance", minimum=0.1, maximum=10.0, value=7.5, step=0.1)
+                     seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+                 submit_btn = gr.Button("Submit")
+             with gr.Column():
+                 b_a_slider = ImageSlider(label="B/A result", position=0.5)
+                 file_output = gr.File(label="Downloadable image result")
+
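+     # Wire the Submit button to the inference function.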
+     submit_btn.click(
+         fn = inference,
+         inputs = [
+             input_image, prompt_in,
+             added_prompt, neg_prompt,
+             denoise_steps,
+             upsample_scale, condition_scale,
+             classifier_free_guidance, seed
+         ],
+         outputs = [
+             b_a_slider,
+             file_output
+         ]
+     )
+ demo.queue(max_size=20).launch()