svjack committed on
Commit
4c91eed
1 Parent(s): 304245e

Create select_image_app.py

Files changed (1)
  1. select_image_app.py +356 -0
select_image_app.py ADDED
@@ -0,0 +1,356 @@
+ import os
+ import random
+ from datetime import datetime
+
+ import gradio as gr
+ import numpy as np
+ import torch
+ from diffusers import AutoencoderKL, DDIMScheduler
+ from einops import repeat
+ from huggingface_hub import hf_hub_download, snapshot_download
+ from omegaconf import OmegaConf
+ from PIL import Image
+ from torchvision import transforms
+ from transformers import CLIPVisionModelWithProjection
+
+ from src.models.pose_guider import PoseGuider
+ from src.models.unet_2d_condition import UNet2DConditionModel
+ from src.models.unet_3d import UNet3DConditionModel
+ from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
+ from src.utils.download_models import prepare_base_model, prepare_image_encoder
+ from src.utils.util import get_fps, read_frames, save_videos_grid
+
+ # Partial download
+ prepare_base_model()
+ prepare_image_encoder()
+
+ snapshot_download(
+     repo_id="stabilityai/sd-vae-ft-mse", local_dir="./pretrained_weights/sd-vae-ft-mse"
+ )
+ snapshot_download(
+     repo_id="patrolli/AnimateAnyone",
+     local_dir="./pretrained_weights",
+ )
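+ # Note: the downloaded weights land under ./pretrained_weights, which the paths in
+ # configs/prompts/animation.yaml are expected to point at.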
+
+
+ class AnimateController:
+     def __init__(
+         self,
+         config_path="./configs/prompts/animation.yaml",
+         weight_dtype=torch.float16,
+     ):
+         # Read pretrained weights path from config
+         self.config = OmegaConf.load(config_path)
+         self.pipeline = None
+         self.weight_dtype = weight_dtype
+
+     def animate(
+         self,
+         ref_image,
+         pose_video_path,
+         width=512,
+         height=768,
+         length=24,
+         num_inference_steps=25,
+         cfg=3.5,
+         seed=123,
+     ):
+         generator = torch.manual_seed(seed)
+         if isinstance(ref_image, np.ndarray):
+             ref_image = Image.fromarray(ref_image)
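+         # Lazily build the Pose2Video pipeline on the first call and cache it on
+         # the controller, so model loading happens only once per process.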
+         if self.pipeline is None:
+             vae = AutoencoderKL.from_pretrained(
+                 self.config.pretrained_vae_path,
+             ).to("cuda", dtype=self.weight_dtype)
+
+             reference_unet = UNet2DConditionModel.from_pretrained(
+                 self.config.pretrained_base_model_path,
+                 subfolder="unet",
+             ).to(dtype=self.weight_dtype, device="cuda")
+
+             inference_config_path = self.config.inference_config
+             infer_config = OmegaConf.load(inference_config_path)
+             denoising_unet = UNet3DConditionModel.from_pretrained_2d(
+                 self.config.pretrained_base_model_path,
+                 self.config.motion_module_path,
+                 subfolder="unet",
+                 unet_additional_kwargs=infer_config.unet_additional_kwargs,
+             ).to(dtype=self.weight_dtype, device="cuda")
+
+             pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
+                 dtype=self.weight_dtype, device="cuda"
+             )
+
+             image_enc = CLIPVisionModelWithProjection.from_pretrained(
+                 self.config.image_encoder_path
+             ).to(dtype=self.weight_dtype, device="cuda")
+             sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
+             scheduler = DDIMScheduler(**sched_kwargs)
+
+             # load pretrained weights
+             denoising_unet.load_state_dict(
+                 torch.load(self.config.denoising_unet_path, map_location="cpu"),
+                 strict=False,
+             )
+             reference_unet.load_state_dict(
+                 torch.load(self.config.reference_unet_path, map_location="cpu"),
+             )
+             pose_guider.load_state_dict(
+                 torch.load(self.config.pose_guider_path, map_location="cpu"),
+             )
+
+             pipe = Pose2VideoPipeline(
+                 vae=vae,
+                 image_encoder=image_enc,
+                 reference_unet=reference_unet,
+                 denoising_unet=denoising_unet,
+                 pose_guider=pose_guider,
+                 scheduler=scheduler,
+             )
+             pipe = pipe.to("cuda", dtype=self.weight_dtype)
+             self.pipeline = pipe
+
+         pose_images = read_frames(pose_video_path)
+         src_fps = get_fps(pose_video_path)
+
+         pose_list = []
+         total_length = min(length, len(pose_images))
+         for pose_image_pil in pose_images[:total_length]:
+             pose_list.append(pose_image_pil)
+
+         video = self.pipeline(
+             ref_image,
+             pose_list,
+             width=width,
+             height=height,
+             video_length=total_length,
+             num_inference_steps=num_inference_steps,
+             guidance_scale=cfg,
+             generator=generator,
+         ).videos
+
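+         # Assemble a three-row grid (reference image | pose frames | generated frames)
+         # so the saved video shows inputs and output side by side.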
+         new_h, new_w = video.shape[-2:]
+         pose_transform = transforms.Compose(
+             [transforms.Resize((new_h, new_w)), transforms.ToTensor()]
+         )
+         pose_tensor_list = []
+         for pose_image_pil in pose_images[:total_length]:
+             pose_tensor_list.append(pose_transform(pose_image_pil))
+
+         ref_image_tensor = pose_transform(ref_image)  # (c, h, w)
+         ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0)  # (1, c, 1, h, w)
+         ref_image_tensor = repeat(
+             ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=total_length
+         )
+         pose_tensor = torch.stack(pose_tensor_list, dim=0)  # (f, c, h, w)
+         pose_tensor = pose_tensor.transpose(0, 1)
+         pose_tensor = pose_tensor.unsqueeze(0)
+         video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0)
+
+         save_dir = "./output/gradio"
+         if not os.path.exists(save_dir):
+             os.makedirs(save_dir, exist_ok=True)
+         date_str = datetime.now().strftime("%Y%m%d")
+         time_str = datetime.now().strftime("%H%M")
+         out_path = os.path.join(save_dir, f"{date_str}T{time_str}.mp4")
+         save_videos_grid(
+             video,
+             out_path,
+             n_rows=3,
+             fps=src_fps,
+         )
+
+         torch.cuda.empty_cache()
+
+         return out_path
+
+
+ controller = AnimateController()
+
+
+ def ui():
+     from datasets import load_dataset
+     import io
+     from PIL import Image
+
+     # Load dataset and filter images
+     image_ds = load_dataset("svjack/Genshin-Impact-Item-Image")
+     image_df = image_ds["train"].to_pandas()
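+     # Keep only rows whose tag contains both 肖像 (portrait) and 角色 (character),
+     # i.e. character portraits to offer as reference images in the gallery.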
+     image_df = image_df[
+         image_df["tag"].map(
+             lambda x: "肖像" in x and "角色" in x
+         )
+     ]
+
+     def bytes_to_pil_image(byte_data):
+         """
+         Convert a byte array to a PIL Image.
+
+         :param byte_data: A byte array containing image data.
+         :return: A PIL Image object.
+         """
+         # Create a BytesIO object from the byte data
+         image_stream = io.BytesIO(byte_data)
+
+         # Open the image using PIL
+         pil_image = Image.open(image_stream)
+
+         return pil_image
+
+     image_df["image"] = image_df["image"].map(lambda x: bytes_to_pil_image(x["bytes"]))
+
+     with gr.Blocks() as demo:
+         gr.HTML(
+             """
+             <h1 style="color:#dc5b1c;text-align:center">
+             Moore-AnimateAnyone Gradio Demo
+             </h1>
+             <div style="text-align:center">
+             <div style="display: inline-block; text-align: left;">
+             <p> This is a quick preview demo of Moore-AnimateAnyone. We appreciate the assistance provided by the HuggingFace team in setting up this demo. </p>
+             <p> If you like this project, please consider giving a star to <a href="https://github.com/MooreThreads/Moore-AnimateAnyone"> our GitHub repo </a> 🤗. </p>
+             </div>
+             </div>
+             """
+         )
+
+         # Add Gallery for selecting images
+         with gr.Row():
+             gallery = gr.Gallery(
+                 image_df["image"].tolist(),
+                 label="Select Reference Image",
+                 show_label=True,
+                 elem_id="gallery",
+                 columns=[2, 3, 4, 5, 6, 6],  # Number of columns for different screen sizes
+                 rows=[2, 2, 2, 2, 2, 2],  # Number of rows for different screen sizes
+                 height="400px",  # Height of the gallery
+                 object_fit="contain",  # How images should be fit in the grid
+             )
+
+         with gr.Row():
+             reference_image = gr.Image(label="Reference Image")
+             motion_sequence = gr.Video(
+                 format="mp4", label="Motion Sequence", height=512
+             )
+
+             with gr.Column():
+                 width_slider = gr.Slider(
+                     label="Width", minimum=448, maximum=768, value=512, step=64
+                 )
+                 height_slider = gr.Slider(
+                     label="Height", minimum=512, maximum=960, value=768, step=64
+                 )
+                 length_slider = gr.Slider(
+                     label="Video Length", minimum=24, maximum=128, value=72, step=24
+                 )
+                 with gr.Row():
+                     seed_textbox = gr.Textbox(label="Seed", value=-1)
+                     seed_button = gr.Button(
+                         value="\U0001F3B2", elem_classes="toolbutton"
+                     )
+                     seed_button.click(
+                         fn=lambda: gr.update(value=random.randint(1, int(1e8))),
+                         inputs=[],
+                         outputs=[seed_textbox],
+                     )
+                 with gr.Row():
+                     sampling_steps = gr.Slider(
+                         label="Sampling steps",
+                         value=15,
+                         info="default: 15",
+                         step=5,
+                         maximum=20,
+                         minimum=10,
+                     )
+                     guidance_scale = gr.Slider(
+                         label="Guidance scale",
+                         value=3.5,
+                         info="default: 3.5",
+                         step=0.5,
+                         maximum=6.5,
+                         minimum=2.0,
+                     )
+                 submit = gr.Button("Animate")
+
+         # Populate gallery with images from the dataset
+         # gallery.update(value=image_df["image"].tolist())
+         with gr.Row():
+             animation = gr.Video(
+                 format="mp4",
+                 label="Animation Results",
+                 height=448,
+                 autoplay=True,
+             )
+
+         def read_video(video):
+             return video
+
+         def read_image(image):
+             return Image.fromarray(image)
+
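+         # Gallery select handler: Gradio passes the clicked item's metadata in
+         # SelectData.value; return the underlying file path so the image loads
+         # into the Reference Image component.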
+         def select_image(selection: gr.SelectData):
+             print(selection.value["image"])
+             return selection.value["image"]["path"]
+
+         # when user uploads a new video
+         motion_sequence.upload(
+             read_video, motion_sequence, motion_sequence, queue=False
+         )
+         # when `first_frame` is updated
+         reference_image.upload(
+             read_image, reference_image, reference_image, queue=False
+         )
+         # when the `submit` button is clicked
+         submit.click(
+             controller.animate,
+             [
+                 reference_image,
+                 motion_sequence,
+                 width_slider,
+                 height_slider,
+                 length_slider,
+                 sampling_steps,
+                 guidance_scale,
+                 seed_textbox,
+             ],
+             animation,
+         )
+
+         gallery.select(fn=select_image, inputs=None, outputs=[reference_image])
+
+         # Examples
+         gr.Markdown("## Examples")
+         gr.Examples(
+             examples=[
+                 [
+                     "./configs/inference/ref_images/anyone-5.png",
+                     "./configs/inference/pose_videos/anyone-video-2_kps.mp4",
+                     512,
+                     768,
+                     72,
+                 ],
+                 [
+                     "./configs/inference/ref_images/anyone-10.png",
+                     "./configs/inference/pose_videos/anyone-video-1_kps.mp4",
+                     512,
+                     768,
+                     72,
+                 ],
+                 [
+                     "./configs/inference/ref_images/anyone-2.png",
+                     "./configs/inference/pose_videos/anyone-video-5_kps.mp4",
+                     512,
+                     768,
+                     72,
+                 ],
+             ],
+             inputs=[reference_image, motion_sequence, width_slider, height_slider, length_slider],
+             outputs=animation,
+         )
+
+     return demo
+
+
+ demo = ui()
+ demo.queue(max_size=10)
+ demo.launch(share=True, show_api=False)