Commit 8ca8d03 (parent 15183c0): Update app.py
app.py CHANGED
```diff
@@ -123,6 +123,8 @@ pipe.load_ip_adapter_instantid(face_adapter)
 pipe.set_ip_adapter_scale(0.8)
 zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
 zoe.to(device)
+
+original_pipe = copy.deepcopy(pipe)
 pipe.to(device)
 
 last_lora = ""
```
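The two added lines snapshot the pipeline before it reaches the GPU: `original_pipe = copy.deepcopy(pipe)` keeps an unmodified copy around, presumably so pristine weights can be restored after LoRAs are fused into `pipe`. A minimal sketch of that pattern, with an illustrative checkpoint name:

```python
import copy
import torch
from diffusers import DiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# Illustrative base model; the Space builds its own SDXL + ControlNet pipeline.
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")

# Snapshot *before* moving to the GPU and before any LoRA is fused,
# so an untouched copy of the weights stays available.
original_pipe = copy.deepcopy(pipe)
pipe.to(device)

# Later, to discard fused-in LoRA weights without re-downloading:
# pipe = copy.deepcopy(original_pipe).to(device)
```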
```diff
@@ -202,10 +204,58 @@ def merge_incompatible_lora(full_path_lora, lora_scale):
     )
     del weights_sd
     del lora_model
-@spaces.GPU
+@spaces.GPU
+def generate_image(prompt, negative, face_emb, face_image, image_strength, images, guidance_scale, face_strength, depth_control_scale, last_lora, repo_name, loaded_state_dict, lora_scale, sdxl_loras, selected_state_index):
+    if last_lora != repo_name:
+        if(last_fused):
+            pipe.unfuse_lora()
+            pipe.unload_lora_weights()
+        pipe.load_lora_weights(loaded_state_dict)
+        pipe.fuse_lora(lora_scale)
+        last_fused = True
+        is_pivotal = sdxl_loras[selected_state_index]["is_pivotal"]
+        if(is_pivotal):
+            #Add the textual inversion embeddings from pivotal tuning models
+            text_embedding_name = sdxl_loras[selected_state_index]["text_embedding_weights"]
+            embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
+            state_dict_embedding = load_file(embedding_path)
+            print(state_dict_embedding)
+            try:
+                pipe.unload_textual_inversion()
+                pipe.load_textual_inversion(state_dict_embedding["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
+                pipe.load_textual_inversion(state_dict_embedding["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
+            except:
+                pipe.unload_textual_inversion()
+                pipe.load_textual_inversion(state_dict_embedding["text_encoders_0"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
+                pipe.load_textual_inversion(state_dict_embedding["text_encoders_1"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
+
+    print("Processing prompt...")
+    conditioning, pooled = compel(prompt)
+    if(negative):
+        negative_conditioning, negative_pooled = compel(negative)
+    else:
+        negative_conditioning, negative_pooled = None, None
+    print("Processing image...")
+    image = pipe(
+        prompt_embeds=conditioning,
+        pooled_prompt_embeds=pooled,
+        negative_prompt_embeds=negative_conditioning,
+        negative_pooled_prompt_embeds=negative_pooled,
+        width=1024,
+        height=1024,
+        image_embeds=face_emb,
+        image=face_image,
+        strength=1-image_strength,
+        control_image=images,
+        num_inference_steps=20,
+        guidance_scale = guidance_scale,
+        controlnet_conditioning_scale=[face_strength, depth_control_scale],
+    ).images[0]
+    return image
+
 def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, sdxl_loras, progress=gr.Progress(track_tqdm=True)):
     global last_lora, last_merged, last_fused, pipe
-
+    selected_state_index = selected_state.index
     face_image = center_crop_image_as_square(face_image)
     try:
         face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
```
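This hunk is the heart of the commit: all GPU-bound work moves out of `run_lora` into a new `generate_image` function decorated with `@spaces.GPU`, the ZeroGPU pattern in which a GPU is attached only while the decorated function runs, while face detection, prompt assembly, and state lookups stay on CPU. Note that `generate_image` assigns `last_fused` without a `global` declaration, which makes it a local variable for the whole function, so the earlier `if(last_fused)` read will raise `UnboundLocalError` whenever a new LoRA is selected. A minimal sketch of the decorator pattern (function names and the `pipe` object are illustrative):

```python
import spaces  # Hugging Face `spaces` package, present on ZeroGPU hardware

@spaces.GPU  # a GPU is attached only while this function executes
def heavy_inference(prompt: str):
    # The expensive, GPU-bound work (model forward passes) lives here.
    return pipe(prompt).images[0]

def handler(prompt: str):
    # Cheap CPU-side validation and preprocessing stay outside the decorated call.
    prompt = prompt.strip() or "a person"
    return heavy_inference(prompt)
```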
```diff
@@ -216,7 +266,7 @@ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_stre
         raise gr.Error("No face found in your image. Only face images work here. Try again")
 
     for lora_list in lora_defaults:
-        if lora_list["model"] == sdxl_loras[selected_state.index]["repo"]:
+        if lora_list["model"] == sdxl_loras[selected_state_index]["repo"]:
             prompt_full = lora_list.get("prompt", None)
             if(prompt_full):
                 prompt = prompt_full.replace("<subject>", prompt)
```
```diff
@@ -224,7 +274,7 @@ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_stre
 
     print("Prompt:", prompt)
     if(prompt == ""):
-        prompt = "
+        prompt = "a person"
     #prepare face zoe
     with torch.no_grad():
         image_zoe = zoe(face_image)
```
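Inside `generate_image`, the prompt assembled above is encoded with `conditioning, pooled = compel(prompt)`. The `compel` object is constructed elsewhere in app.py and is not shown in this diff; for SDXL it would have to wrap both text encoders and return a pooled embedding, roughly as in this sketch of the standard Compel setup (an assumption, not the Space's exact code):

```python
from compel import Compel, ReturnedEmbeddingsType

# SDXL uses two tokenizers/text encoders; only the second one returns a
# pooled embedding, hence requires_pooled=[False, True].
compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)

conditioning, pooled = compel("a person")
```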
```diff
@@ -239,15 +289,15 @@ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_stre
     # else:
     #     selected_state.index *= -1
     #sdxl_loras = sdxl_loras_new
-    print("Selected State: ", selected_state.index)
-    print(sdxl_loras[selected_state.index]["repo"])
+    print("Selected State: ", selected_state_index)
+    print(sdxl_loras[selected_state_index]["repo"])
     if negative == "":
         negative = None
 
     if not selected_state:
         raise gr.Error("You must select a LoRA")
-    repo_name = sdxl_loras[selected_state.index]["repo"]
-    weight_name = sdxl_loras[selected_state.index]["weights"]
+    repo_name = sdxl_loras[selected_state_index]["repo"]
+    weight_name = sdxl_loras[selected_state_index]["weights"]
 
     full_path_lora = state_dicts[repo_name]["saved_name"]
     loaded_state_dict = copy.deepcopy(state_dicts[repo_name]["state_dict"])
```
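The lookups here read from a `state_dicts` cache keyed by LoRA repo, and `loaded_state_dict` is a deep copy of the cached tensors, plausibly because `pipe.load_lora_weights` consumes the dict it is handed. A sketch of that cache; its layout is inferred from the calls visible in this diff and the helper name is hypothetical:

```python
import copy
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Inferred layout: one entry per LoRA repo, filled once at startup.
state_dicts = {}

def cache_lora(repo_id: str, weight_name: str) -> None:
    saved_name = hf_hub_download(repo_id=repo_id, filename=weight_name)
    state_dicts[repo_id] = {
        "saved_name": saved_name,             # local path to the .safetensors file
        "state_dict": load_file(saved_name),  # tensors kept resident in RAM
    }

# Per request, hand the pipeline a copy so the cached dict survives:
# loaded_state_dict = copy.deepcopy(state_dicts[repo_name]["state_dict"])
```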
```diff
@@ -255,53 +305,8 @@ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_stre
     print("Last LoRA: ", last_lora)
     print("Current LoRA: ", repo_name)
     print("Last fused: ", last_fused)
-    if last_lora != repo_name:
-        if(last_fused):
-            pipe.unfuse_lora()
-            pipe.unload_lora_weights()
-        pipe.load_lora_weights(loaded_state_dict)
-        pipe.fuse_lora(lora_scale)
-        last_fused = True
-        is_pivotal = sdxl_loras[selected_state.index]["is_pivotal"]
-        if(is_pivotal):
-            #Add the textual inversion embeddings from pivotal tuning models
-            text_embedding_name = sdxl_loras[selected_state.index]["text_embedding_weights"]
-            embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
-            state_dict_embedding = load_file(embedding_path)
-            print(state_dict_embedding)
-            try:
-                pipe.unload_textual_inversion()
-                pipe.load_textual_inversion(state_dict_embedding["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
-                pipe.load_textual_inversion(state_dict_embedding["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
-            except:
-                pipe.unload_textual_inversion()
-                pipe.load_textual_inversion(state_dict_embedding["text_encoders_0"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
-                pipe.load_textual_inversion(state_dict_embedding["text_encoders_1"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
-
-    print("Processing prompt...")
-    conditioning, pooled = compel(prompt)
-    if(negative):
-        negative_conditioning, negative_pooled = compel(negative)
-    else:
-        negative_conditioning, negative_pooled = None, None
-    print("Processing image...")
-
-    image = pipe(
-        prompt_embeds=conditioning,
-        pooled_prompt_embeds=pooled,
-        negative_prompt_embeds=negative_conditioning,
-        negative_pooled_prompt_embeds=negative_pooled,
-        width=1024,
-        height=1024,
-        image_embeds=face_emb,
-        image=face_image,
-        strength=1-image_strength,
-        control_image=images,
-        num_inference_steps=20,
-        guidance_scale = guidance_scale,
-        controlnet_conditioning_scale=[face_strength, depth_control_scale],
-    ).images[0]
 
+    image = generate_image(prompt, negative, face_emb, face_image, image_strength, images, guidance_scale, face_strength, depth_control_scale, last_lora, repo_name, loaded_state_dict, lora_scale, sdxl_loras, selected_state_index)
     last_lora = repo_name
     return image, gr.update(visible=True)
 
```
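The large removed block is not deleted logic: old lines 258-303 are the LoRA hot-swap and pivotal-tuning embedding code that moved, essentially verbatim, into `generate_image` above. The swap follows the usual diffusers unfuse-then-fuse cycle, condensed here into a sketch (the `swap_lora` helper is hypothetical):

```python
def swap_lora(pipe, state_dict, lora_scale: float, was_fused: bool) -> bool:
    """Fuse a new LoRA into the pipeline, undoing any previously fused one."""
    if was_fused:
        pipe.unfuse_lora()          # restore the base weights
        pipe.unload_lora_weights()  # drop the old LoRA layers
    pipe.load_lora_weights(state_dict)
    pipe.fuse_lora(lora_scale=lora_scale)  # bake the LoRA into the weights
    return True  # caller records the new fused state
```

Fusing trades per-step LoRA overhead for a one-time weight merge, which is why the code tracks `last_lora` and `last_fused` and skips the swap when the same LoRA is requested twice in a row.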