yy1636 committed
Commit 1919c24 · verified · 1 Parent(s): 0b0244d

Create app.py

Files changed (1)
  1. app.py +697 -0
app.py ADDED
@@ -0,0 +1,697 @@
+ import gradio as gr
+ import torch
+ import cv2
+ import numpy as np
+ import mediapipe as mp
+ import matplotlib.pyplot as plt
+ from PIL import Image
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionControlNetInpaintPipeline
+ from transformers import AutoTokenizer
+ import base64
+ import requests
+ import json
+ from rembg import remove
+ from scipy import ndimage
+ from moviepy.editor import ImageSequenceClip
+ from tqdm import tqdm
+ import os
+ import shutil
+ import time
+ from huggingface_hub import snapshot_download
+ import subprocess
+ import sys
+
+ def download_liveportrait():
+     """
+     Clone the LivePortrait repository and prepare its dependencies.
+     """
+     liveportrait_path = "./LivePortrait"
+     try:
+         if not os.path.exists(liveportrait_path):
+             print("Cloning LivePortrait repository...")
+             os.system(f"git clone https://github.com/KwaiVGI/LivePortrait.git {liveportrait_path}")
+
+         # Install LivePortrait dependencies
+         os.chdir(liveportrait_path)
+         print("Installing LivePortrait dependencies...")
+         os.system("pip install -r requirements.txt")
+
+         # Build the MultiScaleDeformableAttention module
+         dependency_path = "src/utils/dependencies/XPose/models/UniPose/ops"
+         os.chdir(dependency_path)
+         print("Building MultiScaleDeformableAttention...")
+         os.system("python setup.py build")
+         os.system("python setup.py install")
+
+         # Make sure the module path is importable
+         module_path = os.path.abspath(dependency_path)
+         if module_path not in sys.path:
+             sys.path.append(module_path)
+
+         # Return to the LivePortrait directory
+         os.chdir("../../../../../../../")
+         print("LivePortrait setup completed")
+     except Exception as e:
+         print("Failed to initialize LivePortrait:", e)
+         raise
+
+ def download_huggingface_resources():
+     """
+     Download additional necessary resources from Hugging Face using the CLI.
+     """
+     try:
+         local_dir = "./pretrained_weights"
+         os.makedirs(local_dir, exist_ok=True)
+
+         # Use the Hugging Face CLI for downloading
+         cmd = [
+             "huggingface-cli", "download",
+             "KwaiVGI/LivePortrait",
+             "--local-dir", local_dir,
+             "--exclude", "*.git*", "README.md", "docs"
+         ]
+         print("Executing command:", " ".join(cmd))
+         subprocess.run(cmd, check=True)
+
+         print("Resources successfully downloaded to:", local_dir)
+     except subprocess.CalledProcessError as e:
+         print("Error during Hugging Face CLI download:", e)
+         raise
+     except Exception as e:
+         print("General error in downloading resources:", e)
+         raise
+
+ def get_project_root():
+     """Get the root directory of the current project."""
+     return os.path.abspath(os.path.dirname(__file__))
+
+ # Ensure working directory is project root
+ os.chdir(get_project_root())
+
+ # Initialize the necessary models and components
+ mp_pose = mp.solutions.pose
+ mp_drawing = mp.solutions.drawing_utils
+
+ # Load ControlNet model
+ controlnet = ControlNetModel.from_pretrained('lllyasviel/sd-controlnet-openpose', torch_dtype=torch.float16)
+
+ # Load Stable Diffusion model with ControlNet
+ pipe_controlnet = StableDiffusionControlNetPipeline.from_pretrained(
+     'runwayml/stable-diffusion-v1-5',
+     controlnet=controlnet,
+     torch_dtype=torch.float16
+ )
+
+ # Load ControlNet inpainting pipeline
+ pipe_inpaint_controlnet = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+     "runwayml/stable-diffusion-inpainting",
+     controlnet=controlnet,
+     torch_dtype=torch.float16
+ )
+
+ # Move to GPU if available
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ pipe_controlnet.to(device)
+ pipe_controlnet.enable_attention_slicing()
+ pipe_inpaint_controlnet.to(device)
+ pipe_inpaint_controlnet.enable_attention_slicing()
+
+ def resize_to_multiple_of_64(width, height):
+     return (width // 64) * 64, (height // 64) * 64
+
+ def expand_mask(mask, kernel_size):
+     mask_array = np.array(mask)
+     structuring_element = np.ones((kernel_size, kernel_size), dtype=np.uint8)
+     expanded_mask_array = ndimage.binary_dilation(
+         mask_array, structure=structuring_element
+     ).astype(np.uint8) * 255
+     return Image.fromarray(expanded_mask_array)
+
+ def crop_face_to_square(image_rgb, padding_ratio=0.2):
+     """
+     Detects the face in the input image and crops an enlarged square region around it.
+     """
+     face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
+     gray_image = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
+     faces = face_cascade.detectMultiScale(gray_image, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
+
+     if len(faces) == 0:
+         print("No face detected.")
+         return None
+
+     x, y, w, h = faces[0]
+     center_x, center_y = x + w // 2, y + h // 2
+     side_length = max(w, h)
+     padded_side_length = int(side_length * (1 + padding_ratio))
+     half_side = padded_side_length // 2
+
+     top_left_x = max(center_x - half_side, 0)
+     top_left_y = max(center_y - half_side, 0)
+     bottom_right_x = min(center_x + half_side, image_rgb.shape[1])
+     bottom_right_y = min(center_y + half_side, image_rgb.shape[0])
+
+     cropped_image = image_rgb[top_left_y:bottom_right_y, top_left_x:bottom_right_x]
+     resized_image = cv2.resize(cropped_image, (768, 768), interpolation=cv2.INTER_AREA)
+
+     return resized_image
+
+ def spirit_animal_baseline(image_path, num_images=4):
+
+     image = cv2.imread(image_path)
+     image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+     image_rgb = crop_face_to_square(image_rgb)
+
+     original_height, original_width, _ = image_rgb.shape
+     aspect_ratio = original_width / original_height
+
+     if aspect_ratio > 1:
+         gen_width = 768
+         gen_height = int(gen_width / aspect_ratio)
+     else:
+         gen_height = 768
+         gen_width = int(gen_height * aspect_ratio)
+
+     gen_width, gen_height = resize_to_multiple_of_64(gen_width, gen_height)
+
+     with mp_pose.Pose(static_image_mode=True) as pose:
+         results = pose.process(image_rgb)
+
+     if results.pose_landmarks:
+         annotated_image = image_rgb.copy()
+         mp_drawing.draw_landmarks(
+             annotated_image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS
+         )
+     else:
+         print("No pose detected.")
+         return "No pose detected.", []
+
+     pose_image = np.zeros_like(image_rgb)
+     for connection in mp_pose.POSE_CONNECTIONS:
+         start_idx, end_idx = connection
+         start, end = results.pose_landmarks.landmark[start_idx], results.pose_landmarks.landmark[end_idx]
+         if start.visibility > 0.5 and end.visibility > 0.5:
+             x1, y1 = int(start.x * pose_image.shape[1]), int(start.y * pose_image.shape[0])
+             x2, y2 = int(end.x * pose_image.shape[1]), int(end.y * pose_image.shape[0])
+             cv2.line(pose_image, (x1, y1), (x2, y2), (255, 255, 255), 2)
+
+     pose_pil = Image.fromarray(cv2.resize(pose_image, (gen_width, gen_height), interpolation=cv2.INTER_LANCZOS4))
+
+     base64_image = base64.b64encode(cv2.imencode('.jpg', image_rgb)[1]).decode()
+     # Read the OpenAI API key from the environment instead of hardcoding it
+     api_key = os.getenv("OPENAI_API_KEY")
+     headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
+     payload = {
+         "model": "gpt-4o-mini",
+         "messages": [
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "text", "text": "Based on the provided image, think of one spirit animal that is right for the person, and answer in the following format: An ultra-realistic, highly detailed photograph of a single {animal} with facial features characterized by {description}, standing upright in a human-like pose, looking directly at the camera, against a solid, neutral background. Generate one sentence without any other responses or numbering."},
+                     {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
+                 ]
+             }
+         ],
+         "max_tokens": 100
+     }
+
+     response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
+     prompt = response.json()['choices'][0]['message']['content'] if 'choices' in response.json() else "A majestic animal"
+
+     generated_images = []
+     with torch.no_grad():
+         with torch.autocast(device_type=device.type):
+             for _ in range(num_images):
+                 images = pipe_controlnet(
+                     prompt=prompt,
+                     negative_prompt="multiple heads, extra limbs, duplicate faces, mutated anatomy, disfigured, blurry",
+                     num_inference_steps=20,
+                     image=pose_pil,
+                     guidance_scale=5,
+                     width=gen_width,
+                     height=gen_height,
+                 ).images
+                 generated_images.append(images[0])
+
+     return prompt, generated_images
+
+ def spirit_animal_with_background(image_path, num_images=4):
+
+     image = cv2.imread(image_path)
+     image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+     # image_rgb = crop_face_to_square(image_rgb)
+
+     original_height, original_width, _ = image_rgb.shape
+     aspect_ratio = original_width / original_height
+
+     if aspect_ratio > 1:
+         gen_width = 768
+         gen_height = int(gen_width / aspect_ratio)
+     else:
+         gen_height = 768
+         gen_width = int(gen_height * aspect_ratio)
+
+     gen_width, gen_height = resize_to_multiple_of_64(gen_width, gen_height)
+
+     with mp_pose.Pose(static_image_mode=True) as pose:
+         results = pose.process(image_rgb)
+
+     if results.pose_landmarks:
+         annotated_image = image_rgb.copy()
+         mp_drawing.draw_landmarks(
+             annotated_image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS
+         )
+     else:
+         print("No pose detected.")
+         return "No pose detected.", []
+
+     pose_image = np.zeros_like(image_rgb)
+     for connection in mp_pose.POSE_CONNECTIONS:
+         start_idx, end_idx = connection
+         start, end = results.pose_landmarks.landmark[start_idx], results.pose_landmarks.landmark[end_idx]
+         if start.visibility > 0.5 and end.visibility > 0.5:
+             x1, y1 = int(start.x * pose_image.shape[1]), int(start.y * pose_image.shape[0])
+             x2, y2 = int(end.x * pose_image.shape[1]), int(end.y * pose_image.shape[0])
+             cv2.line(pose_image, (x1, y1), (x2, y2), (255, 255, 255), 2)
+
+     pose_pil = Image.fromarray(cv2.resize(pose_image, (gen_width, gen_height), interpolation=cv2.INTER_LANCZOS4))
+
+     base64_image = base64.b64encode(cv2.imencode('.jpg', image_rgb)[1]).decode()
+     # Read the OpenAI API key from the environment instead of hardcoding it
+     api_key = os.getenv("OPENAI_API_KEY")
+     headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
+     payload = {
+         "model": "gpt-4o-mini",
+         "messages": [
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "text", "text": "Based on the provided image, think of one spirit animal that is right for the person, and answer in the following format: An ultra-realistic, highly detailed photograph of a single {animal} with facial features characterized by {description}, standing upright in a human-like pose, looking directly at the camera, against a solid, neutral background. Generate one sentence without any other responses or numbering."},
+                     {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
+                 ]
+             }
+         ],
+         "max_tokens": 100
+     }
+
+     response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
+     prompt = response.json()['choices'][0]['message']['content'] if 'choices' in response.json() else "A majestic animal"
+
+     mask_image = remove(Image.fromarray(image_rgb))
+     initial_mask = mask_image.split()[-1].convert('L')
+
+     kernel_size = min(gen_width, gen_height) // 15
+     expanded_mask = expand_mask(initial_mask, kernel_size)
+
+     generated_images = []
+     with torch.no_grad():
+         with torch.autocast(device_type=device.type):
+             for _ in range(num_images):
+                 images = pipe_inpaint_controlnet(
+                     prompt=prompt,
+                     negative_prompt="multiple heads, extra limbs, duplicate faces, mutated anatomy, disfigured, blurry",
+                     num_inference_steps=20,
+                     image=Image.fromarray(image_rgb),
+                     mask_image=expanded_mask,
+                     control_image=pose_pil,
+                     width=gen_width,
+                     height=gen_height,
+                     guidance_scale=5,
+                 ).images
+                 generated_images.append(images[0])
+
+     return prompt, generated_images
+
+ def generate_multiple_animals(image_path, keep_background=True, num_images=4):
+
+     image = cv2.imread(image_path)
+     image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+     image_rgb = crop_face_to_square(image_rgb)
+
+     original_image = Image.fromarray(image_rgb)
+     original_width, original_height = original_image.size
+
+     aspect_ratio = original_width / original_height
+     if aspect_ratio > 1:
+         gen_width = 768
+         gen_height = int(gen_width / aspect_ratio)
+     else:
+         gen_height = 768
+         gen_width = int(gen_height * aspect_ratio)
+
+     gen_width, gen_height = resize_to_multiple_of_64(gen_width, gen_height)
+
+     base64_image = base64.b64encode(cv2.imencode('.jpg', image_rgb)[1]).decode()
+     # Read the OpenAI API key from the environment instead of hardcoding it
+     api_key = os.getenv("OPENAI_API_KEY")
+     headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
+     payload = {
+         "model": "gpt-4o-mini",
+         "messages": [
+             {
+                 "role": "user",
+                 "content": [
+                     {
+                         "type": "text",
+                         "text": "Based on the provided image, think of " + str(num_images) + " different spirit animals that are right for the person, and answer in the following format for each: An ultra-realistic, highly detailed photograph of a {animal} with facial features characterized by {description}, standing upright in a human-like pose, looking directly at the camera, against a solid, neutral background. Generate these sentences without any other responses or numbering. For the animal choose between owl, bear, fox, koala, lion, dog"
+                     },
+                     {
+                         "type": "image_url",
+                         "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}
+                     }
+                 ]
+             }
+         ],
+         "max_tokens": 500
+     }
+
+     response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
+     response_json = response.json()
+
+     if 'choices' in response_json and len(response_json['choices']) > 0:
+         content = response_json['choices'][0]['message']['content']
+         prompts = [prompt.strip() for prompt in content.strip().split('.') if prompt.strip()]
+         negative_prompt = (
+             "multiple heads, extra limbs, duplicate faces, mutated anatomy, disfigured, "
+             "blurry, deformed, text, watermark, logo, low resolution"
+         )
+         formatted_prompts = "\n".join(f"{i+1}. {prompt}" for i, prompt in enumerate(prompts))
+     else:
+         print("Failed to generate prompts from the OpenAI response.")
+         return "Failed to generate prompts.", []
+
+     with mp_pose.Pose(static_image_mode=True) as pose:
+         results = pose.process(image_rgb)
+
+     if results.pose_landmarks:
+         annotated_image = image_rgb.copy()
+         mp_drawing.draw_landmarks(
+             annotated_image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS
+         )
+     else:
+         print("No pose detected.")
+         return "No pose detected.", []
+
+     pose_image = np.zeros_like(image_rgb)
+     for connection in mp_pose.POSE_CONNECTIONS:
+         start_idx, end_idx = connection
+         start, end = results.pose_landmarks.landmark[start_idx], results.pose_landmarks.landmark[end_idx]
+         if start.visibility > 0.5 and end.visibility > 0.5:
+             x1, y1 = int(start.x * pose_image.shape[1]), int(start.y * pose_image.shape[0])
+             x2, y2 = int(end.x * pose_image.shape[1]), int(end.y * pose_image.shape[0])
+             cv2.line(pose_image, (x1, y1), (x2, y2), (255, 255, 255), 2)
+
+     pose_pil = Image.fromarray(cv2.resize(pose_image, (gen_width, gen_height), interpolation=cv2.INTER_LANCZOS4))
+
+     if keep_background:
+         mask_image = remove(original_image)
+         initial_mask = mask_image.split()[-1].convert('L')
+         expanded_mask = expand_mask(initial_mask, kernel_size=min(gen_width, gen_height) // 15)
+     else:
+         expanded_mask = None
+
+     generated_images = []
+
+     if keep_background:
+         with torch.no_grad():
+             with torch.amp.autocast("cuda"):
+                 for prompt in prompts:
+                     images = pipe_inpaint_controlnet(
+                         prompt=prompt,
+                         negative_prompt=negative_prompt,
+                         num_inference_steps=20,
+                         image=Image.fromarray(image_rgb),
+                         mask_image=expanded_mask,
+                         control_image=pose_pil,
+                         width=gen_width,
+                         height=gen_height,
+                         guidance_scale=5,
+                     ).images
+                     generated_images.append(images[0])
+     else:
+         with torch.no_grad():
+             with torch.amp.autocast("cuda"):
+                 for prompt in prompts:
+                     images = pipe_controlnet(
+                         prompt=prompt,
+                         negative_prompt=negative_prompt,
+                         num_inference_steps=20,
+                         image=pose_pil,
+                         guidance_scale=5,
+                         width=gen_width,
+                         height=gen_height,
+                     ).images
+                     generated_images.append(images[0])
+
+     return formatted_prompts, generated_images
+
+ def wait_for_file(file_path, timeout=500):
+     """
+     Wait for a file to be created, with a specified timeout.
+     Args:
+         file_path (str): The path of the file to wait for.
+         timeout (int): Maximum time to wait in seconds.
+     Returns:
+         bool: True if the file is created, False if timeout occurs.
+     """
+     start_time = time.time()
+     while not os.path.exists(file_path):
+         if time.time() - start_time > timeout:
+             return False
+         time.sleep(0.5)  # Check every 0.5 seconds
+     return True
+
+ def generate_spirit_animal_video(driving_video_path):
+     os.chdir(".")
+     try:
+         # Step 1: Extract the first frame
+         cap = cv2.VideoCapture(driving_video_path)
+         if not cap.isOpened():
+             print("Error: Unable to open video.")
+             return None
+
+         ret, frame = cap.read()
+         cap.release()
+         if not ret:
+             print("Error: Unable to read the first frame.")
+             return None
+
+         # Save the first frame
+         first_frame_path = "./first_frame.jpg"
+         cv2.imwrite(first_frame_path, frame)
+         print(f"First frame saved to: {first_frame_path}")
+
+         # Step 2: Generate the spirit animal image
+         _, input_image = generate_multiple_animals(first_frame_path, True, 1)
+         if input_image is None or not input_image:
+             print("Error: Spirit animal generation failed.")
+             return None
+
+         spirit_animal_path = "./animal.jpeg"
+         cv2.imwrite(spirit_animal_path, cv2.cvtColor(np.array(input_image[0]), cv2.COLOR_RGB2BGR))
+         print(f"Spirit animal image saved to: {spirit_animal_path}")
+
+         # Step 3: Run inference
+         output_path = "./animations/animal--uploaded_video_compressed.mp4"
+         script_path = os.path.abspath("../LivePortrait/inference_animals.py")
+
+         if not os.path.exists(script_path):
+             print(f"Error: Inference script not found at {script_path}.")
+             return None
+
+         command = f"python {script_path} -s {spirit_animal_path} -d {driving_video_path} --driving_multiplier 1.75 --no_flag_stitching"
+         print(f"Running command: {command}")
+         result = os.system(command)
+
+         if result != 0:
+             print(f"Error: Command failed with exit code {result}.")
+             return None
+
+         # Verify output file exists
+         if not os.path.exists(output_path):
+             print(f"Error: Expected output video not found at {output_path}.")
+             return None
+
+         print(f"Output video generated at: {output_path}")
+         return output_path
+     except Exception as e:
+         print(f"Error occurred: {e}")
+         return None
+
+ def generate_spirit_animal(image, animal_type, background):
+     if animal_type == "Single Animal":
+         if background == "Preserve Background":
+             prompt, generated_images = spirit_animal_with_background(image)
+         else:
+             prompt, generated_images = spirit_animal_baseline(image)
+     elif animal_type == "Multiple Animals":
+         if background == "Preserve Background":
+             prompt, generated_images = generate_multiple_animals(image, keep_background=True)
+         else:
+             prompt, generated_images = generate_multiple_animals(image, keep_background=False)
+     return prompt, generated_images
+
+ def compress_video(input_path, output_path, target_size_mb):
+     target_size_bytes = target_size_mb * 1024 * 1024
+     temp_output = "./temp_compressed.mp4"
+
+     cap = cv2.VideoCapture(input_path)
+     fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Use the mp4 codec
+     fps = int(cap.get(cv2.CAP_PROP_FPS))
+     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+     writer = cv2.VideoWriter(temp_output, fourcc, fps, (width, height))
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+         writer.write(frame)
+
+     cap.release()
+     writer.release()
+
+     current_size = os.path.getsize(temp_output)
+     if current_size > target_size_bytes:
+         bitrate = int(target_size_bytes * 8 / (current_size / target_size_bytes))  # Scale the bitrate down proportionally
+         os.system(f"ffmpeg -i {temp_output} -b:v {bitrate} -y {output_path}")
+         os.remove(temp_output)
+     else:
+         shutil.move(temp_output, output_path)
+
+ def process_video(video_file):
+
+     # Initialize LivePortrait
+     try:
+         download_liveportrait()
+     except Exception as e:
+         print("Failed to initialize LivePortrait:", e)
+         return gr.update(value=None, visible=False)
+
+     # Download Hugging Face resources
+     try:
+         download_huggingface_resources()
+     except Exception as e:
+         print("Failed to download Hugging Face resources:", e)
+         return gr.update(value=None, visible=False)
+
+     compressed_path = "./uploaded_video_compressed.mp4"
+     compress_video(video_file, compressed_path, target_size_mb=1)
+     print(f"Compressed and moved video to: {compressed_path}")
+
+     output_video_path = generate_spirit_animal_video(compressed_path)
+
+     # Wait until the output video is generated
+     timeout = 6000  # Timeout in seconds
+     if not wait_for_file(output_video_path, timeout=timeout):
+         print("Timeout occurred while waiting for video generation.")
+         return gr.update(value=None, visible=False)  # Hide output if failed
+
+     # Return the generated video path
+     print(f"Output video is ready: {output_video_path}")
+     return gr.update(value=output_video_path, visible=True)  # Show video
+
+
+ # Custom CSS styling for the interface
+ css = """
+ #title-container {
+     font-family: 'Arial', sans-serif;
+     color: #4a4a4a;
+     text-align: center;
+     margin-bottom: 20px;
+ }
+ #title-container h1 {
+     font-size: 2.5em;
+     font-weight: bold;
+     color: #ff9900;
+ }
+ #title-container h2 {
+     font-size: 1.2em;
+     color: #6c757d;
+ }
+ #intro-text {
+     font-size: 1em;
+     color: #6c757d;
+     margin: 50px;
+     text-align: center;
+     font-style: italic;
+ }
+ #prompt-output {
+     font-family: 'Courier New', monospace;
+     color: #5a5a5a;
+     font-size: 1.1em;
+     padding: 10px;
+     background-color: #f9f9f9;
+     border: 1px solid #ddd;
+     border-radius: 5px;
+     margin-top: 10px;
+ }
+ """
+
+ # Title and description
+ title_html = """
+ <div id="title-container">
+     <h1>Spirit Animal Generator</h1>
+     <h2>Create your unique spirit animal with AI-assisted image generation.</h2>
+ </div>
+ """
+
+ description_text = """
+ ### Project Overview
+ Welcome to the Spirit Animal Generator! This tool leverages advanced AI technologies to create unique visualizations of spirit animals from both videos and images.
+
+ #### Key Features:
+ 1. **Video Transformation**: Upload a driving video to generate a creative spirit animal animation.
+ 2. **Image Creation**: Upload an image and customize the spirit animal type and background options.
+ 3. **AI-Powered Prompting**: OpenAI's GPT generates descriptive prompts for each input.
+ 4. **High-Quality Outputs**: Generated using Stable Diffusion and ControlNet for stunning visuals.
+
+ ---
+
+ ### How It Works:
+ 1. **Upload Your Media**:
+    - Videos: Ensure the file is in MP4 format.
+    - Images: Use clear, high-resolution photos for better results.
+ 2. **Customize Options**:
+    - For images, select the type of animal and background settings.
+ 3. **View Your Results**:
+    - Videos will be transformed into animations.
+    - Images will produce customized visual art along with a generated prompt.
+
+ Discover your spirit animal and let your imagination run wild!
+
+ ---
+ """
+
+ with gr.Blocks(css=css) as demo:
+     gr.HTML(title_html)
+     gr.Markdown(description_text)
+
+     with gr.Tabs():
+         with gr.Tab("Generate Spirit Animal Image"):
+             gr.Markdown("Upload an image to generate a spirit animal.")
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     image_input = gr.Image(type="filepath", label="Upload an image")
+                     animal_type = gr.Radio(choices=["Single Animal", "Multiple Animals"], label="Animal Type", value="Single Animal")
+                     background_option = gr.Radio(choices=["Preserve Background", "Don't Preserve Background"], label="Background Option", value="Preserve Background")
+                     generate_image_button = gr.Button("Generate Image")
+                 with gr.Column(scale=1):
+                     generated_prompt = gr.Textbox(label="Generated Prompt")
+                     generated_gallery = gr.Gallery(label="Generated Images")
+
+             generate_image_button.click(
+                 fn=generate_spirit_animal,
+                 inputs=[image_input, animal_type, background_option],
+                 outputs=[generated_prompt, generated_gallery],
+             )
+
+         with gr.Tab("Generate Spirit Animal Video"):
+             gr.Markdown("Upload a driving video to generate a spirit animal video.")
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     video_input = gr.Video(label="Upload a driving video (MP4 format)")
+                     generate_video_button = gr.Button("Generate Video")
+                 with gr.Column(scale=1):
+                     video_output = gr.Video(label="Generated Spirit Animal Video")
+
+             generate_video_button.click(
+                 fn=process_video,
+                 inputs=video_input,
+                 outputs=video_output,
+             )
+
+ demo.launch()