tsqn committed
Commit 2d0ea67
1 Parent(s): f721c2a

Update app.py

Files changed (1): app.py (+79, -87)
app.py CHANGED
@@ -2,7 +2,6 @@ import spaces
 
 import torch
 import torchvision.transforms.functional as TF
-import tomesd
 import numpy as np
 import random
 import os
@@ -22,10 +21,6 @@ from aspect_ratio_template import aspect_ratios
 
 # global variable
 base_model_path = 'SG161222/RealVisXL_V5.0'
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
-torch.backends.cudnn.benchmark = True
-torch.set_grad_enabled(False)
 face_detector = FaceAnalysis2(providers=['CPUExecutionProvider', 'CUDAExecutionProvider'], allowed_modules=['detection', 'recognition'])
 face_detector.prepare(ctx_id=0, det_size=(640, 640))
 
@@ -74,6 +69,9 @@ pipe.text_encoder = pipe.text_encoder.to(device=device, dtype=torch_dtype)
 pipe.text_encoder_2 = pipe.text_encoder_2.to(device=device, dtype=torch_dtype)
 pipe.vae = pipe.vae.to(device=device, dtype=torch_dtype)
 
+pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead")
+pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead")
+
 pipe.load_photomaker_adapter(
     os.path.dirname(photomaker_ckpt),
     subfolder="",
@@ -88,12 +86,9 @@ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
 pipe.fuse_lora()
 pipe.to(device)
 
+
 pipe.enable_vae_slicing()
 pipe.enable_vae_tiling()
-pipe.enable_xformers_memory_efficient_attention()
-
-torch.cuda.empty_cache()
-
 @spaces.GPU(duration=120)
 def generate_image(
     upload_images,
@@ -112,85 +107,82 @@ def generate_image(
     adapter_conditioning_factor,
     progress=gr.Progress(track_tqdm=True)
 ):
-    with torch.inference_mode():
-        torch.cuda.empty_cache()
-        if use_doodle:
-            sketch_image = sketch_image["composite"]
-            r, g, b, a = sketch_image.split()
-            sketch_image = a.convert("RGB")
-            sketch_image = TF.to_tensor(sketch_image) > 0.5 # Inversion
-            sketch_image = TF.to_pil_image(sketch_image.to(torch.float32))
-            adapter_conditioning_scale = adapter_conditioning_scale
-            adapter_conditioning_factor = adapter_conditioning_factor
-        else:
-            adapter_conditioning_scale = 0.
-            adapter_conditioning_factor = 0.
-            sketch_image = None
-
-        # check the trigger word
-        image_token_id = pipe.tokenizer.convert_tokens_to_ids(pipe.trigger_word)
-        input_ids = pipe.tokenizer.encode(prompt)
-        if image_token_id not in input_ids:
-            raise gr.Error(f"Cannot find the trigger word '{pipe.trigger_word}' in text prompt! Please refer to step 2️⃣")
-
-        if input_ids.count(image_token_id) > 1:
-            raise gr.Error(f"Cannot use multiple trigger words '{pipe.trigger_word}' in text prompt!")
-
-        # determine output dimensions by the aspect ratio
-        output_w, output_h = aspect_ratios[aspect_ratio_name]
-        print(f"[Debug] Generate image using aspect ratio [{aspect_ratio_name}] => {output_w} x {output_h}")
-
-        # apply the style template
-        prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
-
-        if upload_images is None:
-            raise gr.Error(f"Cannot find any input face image! Please refer to step 1️⃣")
-
-        input_id_images = []
-        for img in upload_images:
-            input_id_images.append(load_image(img))
-
-        id_embed_list = []
-
-        for img in input_id_images:
-            img = np.array(img)
-            img = img[:, :, ::-1]
-            faces = analyze_faces(face_detector, img)
-            if len(faces) > 0:
-                id_embed_list.append(torch.from_numpy((faces[0]['embedding'])))
-
-        if len(id_embed_list) == 0:
-            raise gr.Error(f"No face detected, please update the input face image(s)")
-
-        id_embeds = torch.stack(id_embed_list)
+    if use_doodle:
+        sketch_image = sketch_image["composite"]
+        r, g, b, a = sketch_image.split()
+        sketch_image = a.convert("RGB")
+        sketch_image = TF.to_tensor(sketch_image) > 0.5 # Inversion
+        sketch_image = TF.to_pil_image(sketch_image.to(torch.float32))
+        adapter_conditioning_scale = adapter_conditioning_scale
+        adapter_conditioning_factor = adapter_conditioning_factor
+    else:
+        adapter_conditioning_scale = 0.
+        adapter_conditioning_factor = 0.
+        sketch_image = None
+
+    # check the trigger word
+    image_token_id = pipe.tokenizer.convert_tokens_to_ids(pipe.trigger_word)
+    input_ids = pipe.tokenizer.encode(prompt)
+    if image_token_id not in input_ids:
+        raise gr.Error(f"Cannot find the trigger word '{pipe.trigger_word}' in text prompt! Please refer to step 2️⃣")
+
+    if input_ids.count(image_token_id) > 1:
+        raise gr.Error(f"Cannot use multiple trigger words '{pipe.trigger_word}' in text prompt!")
+
+    # determine output dimensions by the aspect ratio
+    output_w, output_h = aspect_ratios[aspect_ratio_name]
+    print(f"[Debug] Generate image using aspect ratio [{aspect_ratio_name}] => {output_w} x {output_h}")
+
+    # apply the style template
+    prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
+
+    if upload_images is None:
+        raise gr.Error(f"Cannot find any input face image! Please refer to step 1️⃣")
+
+    input_id_images = []
+    for img in upload_images:
+        input_id_images.append(load_image(img))
 
-        generator = torch.Generator(device=device).manual_seed(seed)
+    id_embed_list = []
+
+    for img in input_id_images:
+        img = np.array(img)
+        img = img[:, :, ::-1]
+        faces = analyze_faces(face_detector, img)
+        if len(faces) > 0:
+            id_embed_list.append(torch.from_numpy((faces[0]['embedding'])))
+
+    if len(id_embed_list) == 0:
+        raise gr.Error(f"No face detected, please update the input face image(s)")
 
-        print("Start inference...")
-        print(f"[Debug] Seed: {seed}")
-        print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")
-        start_merge_step = int(float(style_strength_ratio) / 100 * num_steps)
-        if start_merge_step > 30:
-            start_merge_step = 30
-        print(start_merge_step)
-        tomesd.apply_patch(pipe, ratio=0.5)
-        images = pipe(
-            prompt=prompt,
-            width=output_w,
-            height=output_h,
-            input_id_images=input_id_images,
-            negative_prompt=negative_prompt,
-            num_images_per_prompt=num_outputs,
-            num_inference_steps=num_steps,
-            start_merge_step=start_merge_step,
-            generator=generator,
-            guidance_scale=guidance_scale,
-            id_embeds=id_embeds,
-            image=sketch_image,
-            adapter_conditioning_scale=adapter_conditioning_scale,
-            adapter_conditioning_factor=adapter_conditioning_factor,
-        ).images
-        return images, gr.update(visible=True)
+    id_embeds = torch.stack(id_embed_list)
+
+    generator = torch.Generator(device=device).manual_seed(seed)
+
+    print("Start inference...")
+    print(f"[Debug] Seed: {seed}")
+    print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")
+    start_merge_step = int(float(style_strength_ratio) / 100 * num_steps)
+    if start_merge_step > 30:
+        start_merge_step = 30
+    print(start_merge_step)
+    images = pipe(
+        prompt=prompt,
+        width=output_w,
+        height=output_h,
+        input_id_images=input_id_images,
+        negative_prompt=negative_prompt,
+        num_images_per_prompt=num_outputs,
+        num_inference_steps=num_steps,
+        start_merge_step=start_merge_step,
+        generator=generator,
+        guidance_scale=guidance_scale,
+        id_embeds=id_embeds,
+        image=sketch_image,
+        adapter_conditioning_scale=adapter_conditioning_scale,
+        adapter_conditioning_factor=adapter_conditioning_factor,
+    ).images
+    return images, gr.update(visible=True)
 
 def swap_to_gallery(images):
     return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)
@@ -448,4 +440,4 @@ with gr.Blocks(css=css) as demo:
 
     gr.Markdown(article)
 
-demo.launch()
+demo.launch()
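
Note on the change: this commit drops the tomesd token-merging patch, the xformers memory-efficient attention call, and the TF32/grad toggles, and instead compiles the UNet and VAE with torch.compile(mode="reduce-overhead"); it also removes the torch.inference_mode() wrapper inside generate_image. The snippet below is a minimal, self-contained sketch of that compilation pattern using the stock diffusers SDXL pipeline rather than this Space's PhotoMaker pipeline; the dtype, device, and prompt are illustrative assumptions, not taken from this repo. It assumes PyTorch >= 2.0, diffusers, and a CUDA GPU.

# Minimal sketch of the torch.compile path this commit switches to (illustrative, not the Space's code).
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "SG161222/RealVisXL_V5.0",   # same base model id the Space uses
    torch_dtype=torch.float16,   # assumption: half precision on GPU
).to("cuda")

# Compile the heavy modules once, as in the commit; the first call pays the
# compilation cost, subsequent calls reuse the compiled graph.
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead")
pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead")

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=30).images[0]
image.save("out.png")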