wzhouxiff commited on
Commit
1cc2bd8
1 Parent(s): 3a05432
Files changed (2) hide show
  1. app copy.py +0 -740
  2. app.py +1 -1
app copy.py DELETED
@@ -1,740 +0,0 @@
1
- try:
2
- import spaces
3
- except:
4
- pass
5
-
6
- import os
7
- import gradio as gr
8
-
9
- import torch
10
- from gradio_image_prompter import ImagePrompter
11
- from sam2.sam2_image_predictor import SAM2ImagePredictor
12
- from omegaconf import OmegaConf
13
- from PIL import Image
14
- import numpy as np
15
- from copy import deepcopy
16
- import cv2
17
-
18
- import torch.nn.functional as F
19
- import torchvision
20
- from einops import rearrange
21
- import tempfile
22
-
23
- from objctrl_2_5d.utils.ui_utils import process_image, get_camera_pose, get_subject_points, get_points, undo_points, mask_image
24
- from ZoeDepth.zoedepth.utils.misc import colorize
25
-
26
- from cameractrl.inference import get_pipeline
27
- from objctrl_2_5d.utils.examples import examples, sync_points
28
-
29
- from objctrl_2_5d.utils.objmask_util import RT2Plucker, Unprojected, roll_with_ignore_multidim, dilate_mask_pytorch
30
- from objctrl_2_5d.utils.filter_utils import get_freq_filter, freq_mix_3d
31
-
32
-
33
- ### Title and Description ###
34
- #### Description ####
35
- title = r"""<h1 align="center">ObjCtrl-2.5D: Training-free Object Control with Camera Poses</h1>"""
36
- # subtitle = r"""<h2 align="center">Deployed on SVD Generation</h2>"""
37
- important_link = r"""
38
- <div align='center'>
39
- <a href='https://wzhouxiff.github.io/projects/MotionCtrl/assets/paper/MotionCtrl.pdf'>[Paper]</a>
40
- &ensp; <a href='https://wzhouxiff.github.io/projects/MotionCtrl/'>[Project Page]</a>
41
- &ensp; <a href='https://github.com/TencentARC/MotionCtrl'>[Code]</a>
42
- </div>
43
- """
44
-
45
- authors = r"""
46
- <div align='center'>
47
- <a href='https://wzhouxiff.github.io/'>Zhouxia Wang</a>
48
- &ensp; <a href='https://nirvanalan.github.io/'>Yushi Lan</a>
49
- &ensp; <a href='https://shangchenzhou.com/'>Shanchen Zhou</a>
50
- &ensp; <a href='https://www.mmlab-ntu.com/person/ccloy/index.html'>Chen Change Loy</a>
51
- </div>
52
- """
53
-
54
- affiliation = r"""
55
- <div align='center'>
56
- <a href='https://www.mmlab-ntu.com/'>S-Lab, NTU Singapore</a>
57
- </div>
58
- """
59
-
60
- description = r"""
61
- <b>Official Gradio demo</b> for <a href='https://github.com/TencentARC/MotionCtrl' target='_blank'><b>ObjCtrl-2.5D: Training-free Object Control with Camera Poses</b></a>.<br>
62
- 🔥 ObjCtrl2.5D enables object motion control in a I2V generated video via transforming 2D trajectories to 3D using depth, subsequently converting them into camera poses,
63
- thereby leveraging the exisitng camera motion control module for object motion control without requiring additional training.<br>
64
- """
65
-
66
- article = r"""
67
- If ObjCtrl2.5D is helpful, please help to ⭐ the <a href='https://github.com/TencentARC/MotionCtrl' target='_blank'>Github Repo</a>. Thanks!
68
- [![GitHub Stars](https://img.shields.io/github/stars/TencentARC%2FMotionCtrl
69
- )](https://github.com/TencentARC/MotionCtrl)
70
-
71
- ---
72
-
73
- 📝 **Citation**
74
- <br>
75
- If our work is useful for your research, please consider citing:
76
- ```bibtex
77
- @inproceedings{wang2024motionctrl,
78
- title={Motionctrl: A unified and flexible motion controller for video generation},
79
- author={Wang, Zhouxia and Yuan, Ziyang and Wang, Xintao and Li, Yaowei and Chen, Tianshui and Xia, Menghan and Luo, Ping and Shan, Ying},
80
- booktitle={ACM SIGGRAPH 2024 Conference Papers},
81
- pages={1--11},
82
- year={2024}
83
- }
84
- ```
85
-
86
- 📧 **Contact**
87
- <br>
88
- If you have any questions, please feel free to reach me out at <b>zhouzi1212@gmail.com</b>.
89
-
90
- """
91
-
92
- # -------------- initialization --------------
93
-
94
- CAMERA_MODE = ["Traj2Cam", "Rotate", "Clockwise", "Translate"]
95
-
96
- # select the device for computation
97
- if torch.cuda.is_available():
98
- device = torch.device("cuda")
99
- elif torch.backends.mps.is_available():
100
- device = torch.device("mps")
101
- else:
102
- device = torch.device("cpu")
103
- print(f"using device: {device}")
104
-
105
- # segmentation model
106
- segmentor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-tiny", cache_dir="ckpt", device=device)
107
-
108
- # depth model
109
- d_model_NK = torch.hub.load('./ZoeDepth', 'ZoeD_NK', source='local', pretrained=True).to(device)
110
-
111
- # cameractrl model
112
- config = "configs/svd_320_576_cameractrl.yaml"
113
- model_id = "stabilityai/stable-video-diffusion-img2vid"
114
- ckpt = "checkpoints/CameraCtrl_svd.ckpt"
115
- if not os.path.exists(ckpt):
116
- os.makedirs("checkpoints", exist_ok=True)
117
- os.system("wget -c https://huggingface.co/hehao13/CameraCtrl_SVD_ckpts/resolve/main/CameraCtrl_svd.ckpt?download=true")
118
- os.system("mv CameraCtrl_svd.ckpt?download=true checkpoints/CameraCtrl_svd.ckpt")
119
- model_config = OmegaConf.load(config)
120
-
121
-
122
- pipeline = get_pipeline(model_id, "unet", model_config['down_block_types'], model_config['up_block_types'],
123
- model_config['pose_encoder_kwargs'], model_config['attention_processor_kwargs'],
124
- ckpt, True, device)
125
-
126
- # segmentor = None
127
- # d_model_NK = None
128
- # pipeline = None
129
-
130
- ### run the demo ##
131
- # @spaces.GPU(duration=5)
132
- def segment(canvas, image, logits):
133
- if logits is not None:
134
- logits *= 32.0
135
- _, points = get_subject_points(canvas)
136
- image = np.array(image)
137
-
138
- with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
139
- segmentor.set_image(image)
140
- input_points = []
141
- input_boxes = []
142
- for p in points:
143
- [x1, y1, _, x2, y2, _] = p
144
- if x2==0 and y2==0:
145
- input_points.append([x1, y1])
146
- else:
147
- input_boxes.append([x1, y1, x2, y2])
148
- if len(input_points) == 0:
149
- input_points = None
150
- input_labels = None
151
- else:
152
- input_points = np.array(input_points)
153
- input_labels = np.ones(len(input_points))
154
- if len(input_boxes) == 0:
155
- input_boxes = None
156
- else:
157
- input_boxes = np.array(input_boxes)
158
- masks, _, logits = segmentor.predict(
159
- point_coords=input_points,
160
- point_labels=input_labels,
161
- box=input_boxes,
162
- multimask_output=False,
163
- return_logits=True,
164
- mask_input=logits,
165
- )
166
- mask = masks > 0
167
- masked_img = mask_image(image, mask[0], color=[252, 140, 90], alpha=0.9)
168
- masked_img = Image.fromarray(masked_img)
169
-
170
- return mask[0], masked_img, masked_img, logits / 32.0
171
-
172
- # @spaces.GPU(duration=5)
173
- def get_depth(image, points):
174
-
175
- depth = d_model_NK.infer_pil(image)
176
- colored_depth = colorize(depth, cmap='gray_r') # [h, w, 4] 0-255
177
-
178
- depth_img = deepcopy(colored_depth[:, :, :3])
179
- if len(points) > 0:
180
- for idx, point in enumerate(points):
181
- if idx % 2 == 0:
182
- cv2.circle(depth_img, tuple(point), 10, (255, 0, 0), -1)
183
- else:
184
- cv2.circle(depth_img, tuple(point), 10, (0, 0, 255), -1)
185
- if idx > 0:
186
- cv2.arrowedLine(depth_img, points[idx-1], points[idx], (255, 255, 255), 4, tipLength=0.5)
187
-
188
- return depth, depth_img, colored_depth[:, :, :3]
189
-
190
-
191
- # @spaces.GPU(duration=80)
192
- def run_objctrl_2_5d(condition_image,
193
- mask,
194
- depth,
195
- RTs,
196
- bg_mode,
197
- shared_wapring_latents,
198
- scale_wise_masks,
199
- rescale,
200
- seed,
201
- ds, dt,
202
- num_inference_steps=25):
203
-
204
- DEBUG = False
205
-
206
- if DEBUG:
207
- cur_OUTPUT_PATH = 'outputs/tmp'
208
- os.makedirs(cur_OUTPUT_PATH, exist_ok=True)
209
-
210
- # num_inference_steps=25
211
- min_guidance_scale = 1.0
212
- max_guidance_scale = 3.0
213
-
214
- area_ratio = 0.3
215
- depth_scale_ = 5.2
216
- center_margin = 10
217
-
218
- height, width = 320, 576
219
- num_frames = 14
220
-
221
- intrinsics = np.array([[float(width), float(width), float(width) / 2, float(height) / 2]])
222
- intrinsics = np.repeat(intrinsics, num_frames, axis=0) # [n_frame, 4]
223
- fx = intrinsics[0, 0] / width
224
- fy = intrinsics[0, 1] / height
225
- cx = intrinsics[0, 2] / width
226
- cy = intrinsics[0, 3] / height
227
-
228
- down_scale = 8
229
- H, W = height // down_scale, width // down_scale
230
- K = np.array([[width / down_scale, 0, W / 2], [0, width / down_scale, H / 2], [0, 0, 1]])
231
-
232
- seed = int(seed)
233
-
234
- center_h_margin, center_w_margin = center_margin, center_margin
235
- depth_center = np.mean(depth[height//2-center_h_margin:height//2+center_h_margin, width//2-center_w_margin:width//2+center_w_margin])
236
-
237
- if rescale > 0:
238
- depth_rescale = round(depth_scale_ * rescale / depth_center, 2)
239
- else:
240
- depth_rescale = 1.0
241
-
242
- depth = depth * depth_rescale
243
-
244
- depth_down = F.interpolate(torch.tensor(depth).unsqueeze(0).unsqueeze(0),
245
- (H, W), mode='bilinear', align_corners=False).squeeze().numpy() # [H, W]
246
-
247
- ## latent
248
- generator = torch.Generator()
249
- generator.manual_seed(seed)
250
-
251
- latents_org = pipeline.prepare_latents(
252
- 1,
253
- 14,
254
- 8,
255
- height,
256
- width,
257
- pipeline.dtype,
258
- device,
259
- generator,
260
- None,
261
- )
262
- latents_org = latents_org / pipeline.scheduler.init_noise_sigma
263
-
264
- cur_plucker_embedding, _, _ = RT2Plucker(RTs, RTs.shape[0], (height, width), fx, fy, cx, cy) # 6, V, H, W
265
- cur_plucker_embedding = cur_plucker_embedding.to(device)
266
- cur_plucker_embedding = cur_plucker_embedding[None, ...] # b 6 f h w
267
- cur_plucker_embedding = cur_plucker_embedding.permute(0, 2, 1, 3, 4) # b f 6 h w
268
- cur_plucker_embedding = cur_plucker_embedding[:, :num_frames, ...]
269
- cur_pose_features = pipeline.pose_encoder(cur_plucker_embedding)
270
-
271
- # bg_mode = ["Fixed", "Reverse", "Free"]
272
- if bg_mode == "Fixed":
273
- fix_RTs = np.repeat(RTs[0][None, ...], num_frames, axis=0) # [n_frame, 4, 3]
274
- fix_plucker_embedding, _, _ = RT2Plucker(fix_RTs, num_frames, (height, width), fx, fy, cx, cy) # 6, V, H, W
275
- fix_plucker_embedding = fix_plucker_embedding.to(device)
276
- fix_plucker_embedding = fix_plucker_embedding[None, ...] # b 6 f h w
277
- fix_plucker_embedding = fix_plucker_embedding.permute(0, 2, 1, 3, 4) # b f 6 h w
278
- fix_plucker_embedding = fix_plucker_embedding[:, :num_frames, ...]
279
- fix_pose_features = pipeline.pose_encoder(fix_plucker_embedding)
280
-
281
- elif bg_mode == "Reverse":
282
- bg_plucker_embedding, _, _ = RT2Plucker(RTs[::-1], RTs.shape[0], (height, width), fx, fy, cx, cy) # 6, V, H, W
283
- bg_plucker_embedding = bg_plucker_embedding.to(device)
284
- bg_plucker_embedding = bg_plucker_embedding[None, ...] # b 6 f h w
285
- bg_plucker_embedding = bg_plucker_embedding.permute(0, 2, 1, 3, 4) # b f 6 h w
286
- bg_plucker_embedding = bg_plucker_embedding[:, :num_frames, ...]
287
- fix_pose_features = pipeline.pose_encoder(bg_plucker_embedding)
288
-
289
- else:
290
- fix_pose_features = None
291
-
292
- #### preparing mask
293
-
294
- mask = Image.fromarray(mask)
295
- mask = mask.resize((W, H))
296
- mask = np.array(mask).astype(np.float32)
297
- mask = np.expand_dims(mask, axis=-1)
298
-
299
- # visulize mask
300
- if DEBUG:
301
- mask_sum_vis = mask[..., 0]
302
- mask_sum_vis = (mask_sum_vis * 255.0).astype(np.uint8)
303
- mask_sum_vis = Image.fromarray(mask_sum_vis)
304
-
305
- mask_sum_vis.save(f'{cur_OUTPUT_PATH}/org_mask.png')
306
-
307
- try:
308
- warped_masks = Unprojected(mask, depth_down, RTs, H=H, W=W, K=K)
309
-
310
- warped_masks.insert(0, mask)
311
-
312
- except:
313
- # mask to bbox
314
- print(f'!!! Mask is too small to warp; mask to bbox')
315
- mask = mask[:, :, 0]
316
- coords = cv2.findNonZero(mask)
317
- x, y, w, h = cv2.boundingRect(coords)
318
- # mask[y:y+h, x:x+w] = 1.0
319
-
320
- center_x, center_y = x + w // 2, y + h // 2
321
- center_z = depth_down[center_y, center_x]
322
-
323
- # RTs [n_frame, 3, 4] to [n_frame, 4, 4] , add [0, 0, 0, 1]
324
- RTs = np.concatenate([RTs, np.array([[[0, 0, 0, 1]]] * num_frames)], axis=1)
325
-
326
- # RTs: world to camera
327
- P0 = np.array([center_x, center_y, 1])
328
- Pc0 = np.linalg.inv(K) @ P0 * center_z
329
- pw = np.linalg.inv(RTs[0]) @ np.array([Pc0[0], Pc0[1], center_z, 1]) # [4]
330
-
331
- P = [np.array([center_x, center_y])]
332
- for i in range(1, num_frames):
333
- Pci = RTs[i] @ pw
334
- Pi = K @ Pci[:3] / Pci[2]
335
- P.append(Pi[:2])
336
-
337
- warped_masks = [mask]
338
- for i in range(1, num_frames):
339
- shift_x = int(round(P[i][0] - P[0][0]))
340
- shift_y = int(round(P[i][1] - P[0][1]))
341
-
342
- cur_mask = roll_with_ignore_multidim(mask, [shift_y, shift_x])
343
- warped_masks.append(cur_mask)
344
-
345
-
346
- warped_masks = [v[..., None] for v in warped_masks]
347
-
348
- warped_masks = np.stack(warped_masks, axis=0) # [f, h, w]
349
- warped_masks = np.repeat(warped_masks, 3, axis=-1) # [f, h, w, 3]
350
-
351
- mask_sum = np.sum(warped_masks, axis=0, keepdims=True) # [1, H, W, 3]
352
- mask_sum[mask_sum > 1.0] = 1.0
353
- mask_sum = mask_sum[0,:,:, 0]
354
-
355
- if DEBUG:
356
- ## visulize warp mask
357
- warp_masks_vis = torch.tensor(warped_masks)
358
- warp_masks_vis = (warp_masks_vis * 255.0).to(torch.uint8)
359
- torchvision.io.write_video(f'{cur_OUTPUT_PATH}/warped_masks.mp4', warp_masks_vis, fps=10, video_codec='h264', options={'crf': '10'})
360
-
361
- # visulize mask
362
- mask_sum_vis = mask_sum
363
- mask_sum_vis = (mask_sum_vis * 255.0).astype(np.uint8)
364
- mask_sum_vis = Image.fromarray(mask_sum_vis)
365
-
366
- mask_sum_vis.save(f'{cur_OUTPUT_PATH}/merged_mask.png')
367
-
368
- if scale_wise_masks:
369
- min_area = H * W * area_ratio # cal in downscale
370
- non_zero_len = mask_sum.sum()
371
-
372
- print(f'non_zero_len: {non_zero_len}, min_area: {min_area}')
373
-
374
- if non_zero_len > min_area:
375
- kernel_sizes = [1, 1, 1, 3]
376
- elif non_zero_len > min_area * 0.5:
377
- kernel_sizes = [3, 1, 1, 5]
378
- else:
379
- kernel_sizes = [5, 3, 3, 7]
380
- else:
381
- kernel_sizes = [1, 1, 1, 1]
382
-
383
- mask = torch.from_numpy(mask_sum) # [h, w]
384
- mask = mask[None, None, ...] # [1, 1, h, w]
385
- mask = F.interpolate(mask, (height, width), mode='bilinear', align_corners=False) # [1, 1, H, W]
386
- # mask = mask.repeat(1, num_frames, 1, 1) # [1, f, H, W]
387
- mask = mask.to(pipeline.dtype).to(device)
388
-
389
- ##### Mask End ######
390
-
391
- ### Got blending pose features Start ###
392
-
393
- pose_features = []
394
- for i in range(0, len(cur_pose_features)):
395
- kernel_size = kernel_sizes[i]
396
- h, w = cur_pose_features[i].shape[-2:]
397
-
398
- if fix_pose_features is None:
399
- pose_features.append(torch.zeros_like(cur_pose_features[i]))
400
- else:
401
- pose_features.append(fix_pose_features[i])
402
-
403
- cur_mask = F.interpolate(mask, (h, w), mode='bilinear', align_corners=False)
404
- cur_mask = dilate_mask_pytorch(cur_mask, kernel_size=kernel_size) # [1, 1, H, W]
405
- cur_mask = cur_mask.repeat(1, num_frames, 1, 1) # [1, f, H, W]
406
-
407
- if DEBUG:
408
- # visulize mask
409
- mask_vis = cur_mask[0, 0].cpu().numpy() * 255.0
410
- mask_vis = Image.fromarray(mask_vis.astype(np.uint8))
411
- mask_vis.save(f'{cur_OUTPUT_PATH}/mask_k{kernel_size}_scale{i}.png')
412
-
413
- cur_mask = cur_mask[None, ...] # [1, 1, f, H, W]
414
- pose_features[-1] = cur_pose_features[i] * cur_mask + pose_features[-1] * (1 - cur_mask)
415
-
416
- ### Got blending pose features End ###
417
-
418
- ##### Warp Noise Start ######
419
-
420
- if shared_wapring_latents:
421
- noise = latents_org[0, 0].data.cpu().numpy().copy() #[14, 4, 40, 72]
422
- noise = np.transpose(noise, (1, 2, 0)) # [40, 72, 4]
423
-
424
- try:
425
- warp_noise = Unprojected(noise, depth_down, RTs, H=H, W=W, K=K)
426
- warp_noise.insert(0, noise)
427
- except:
428
- print(f'!!! Noise is too small to warp; mask to bbox')
429
-
430
- warp_noise = [noise]
431
- for i in range(1, num_frames):
432
- shift_x = int(round(P[i][0] - P[0][0]))
433
- shift_y = int(round(P[i][1] - P[0][1]))
434
-
435
- cur_noise= roll_with_ignore_multidim(noise, [shift_y, shift_x])
436
- warp_noise.append(cur_noise)
437
-
438
- warp_noise = np.stack(warp_noise, axis=0) # [f, h, w, 4]
439
-
440
- if DEBUG:
441
- ## visulize warp noise
442
- warp_noise_vis = torch.tensor(warp_noise)[..., :3] * torch.tensor(warped_masks)
443
- warp_noise_vis = (warp_noise_vis - warp_noise_vis.min()) / (warp_noise_vis.max() - warp_noise_vis.min())
444
- warp_noise_vis = (warp_noise_vis * 255.0).to(torch.uint8)
445
-
446
- torchvision.io.write_video(f'{cur_OUTPUT_PATH}/warp_noise.mp4', warp_noise_vis, fps=10, video_codec='h264', options={'crf': '10'})
447
-
448
-
449
- warp_latents = torch.tensor(warp_noise).permute(0, 3, 1, 2).to(latents_org.device).to(latents_org.dtype) # [frame, 4, H, W]
450
- warp_latents = warp_latents.unsqueeze(0) # [1, frame, 4, H, W]
451
-
452
- warped_masks = torch.tensor(warped_masks).permute(0, 3, 1, 2).unsqueeze(0) # [1, frame, 3, H, W]
453
- mask_extend = torch.concat([warped_masks, warped_masks[:,:,0:1]], dim=2) # [1, frame, 4, H, W]
454
- mask_extend = mask_extend.to(latents_org.device).to(latents_org.dtype)
455
-
456
- warp_latents = warp_latents * mask_extend + latents_org * (1 - mask_extend)
457
- warp_latents = warp_latents.permute(0, 2, 1, 3, 4)
458
- random_noise = latents_org.clone().permute(0, 2, 1, 3, 4)
459
-
460
- filter_shape = warp_latents.shape
461
-
462
- freq_filter = get_freq_filter(
463
- filter_shape,
464
- device = device,
465
- filter_type='butterworth',
466
- n=4,
467
- d_s=ds,
468
- d_t=dt
469
- )
470
-
471
- warp_latents = freq_mix_3d(warp_latents, random_noise, freq_filter)
472
- warp_latents = warp_latents.permute(0, 2, 1, 3, 4)
473
-
474
- else:
475
- warp_latents = latents_org.clone()
476
-
477
- generator.manual_seed(42)
478
-
479
- with torch.no_grad():
480
- result = pipeline(
481
- image=condition_image,
482
- pose_embedding=cur_plucker_embedding,
483
- height=height,
484
- width=width,
485
- num_frames=num_frames,
486
- num_inference_steps=num_inference_steps,
487
- min_guidance_scale=min_guidance_scale,
488
- max_guidance_scale=max_guidance_scale,
489
- do_image_process=True,
490
- generator=generator,
491
- output_type='pt',
492
- pose_features= pose_features,
493
- latents = warp_latents
494
- ).frames[0].cpu() #[f, c, h, w]
495
-
496
-
497
- result = rearrange(result, 'f c h w -> f h w c')
498
- result = (result * 255.0).to(torch.uint8)
499
-
500
- video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
501
- torchvision.io.write_video(video_path, result, fps=10, video_codec='h264', options={'crf': '8'})
502
-
503
- return video_path
504
-
505
- # -------------- UI definition --------------
506
- with gr.Blocks() as demo:
507
- # layout definition
508
- gr.Markdown(title)
509
- gr.Markdown(authors)
510
- gr.Markdown(affiliation)
511
- gr.Markdown(important_link)
512
- gr.Markdown(description)
513
-
514
-
515
- # with gr.Row():
516
- # gr.Markdown("""# <center>Repositioning the Subject within Image </center>""")
517
- mask = gr.State(value=None) # store mask
518
- removal_mask = gr.State(value=None) # store removal mask
519
- selected_points = gr.State([]) # store points
520
- selected_points_text = gr.Textbox(label="Selected Points", visible=False)
521
-
522
- original_image = gr.State(value=None) # store original input image
523
- masked_original_image = gr.State(value=None) # store masked input image
524
- mask_logits = gr.State(value=None) # store mask logits
525
-
526
- depth = gr.State(value=None) # store depth
527
- org_depth_image = gr.State(value=None) # store original depth image
528
-
529
- camera_pose = gr.State(value=None) # store camera pose
530
-
531
- with gr.Column():
532
-
533
- outlines = """
534
- <font size="5"><b>There are total 5 steps to complete the task.</b></font>
535
- - Step 1: Input an image and Crop it to a suitable size;
536
- - Step 2: Attain the subject mask;
537
- - Step 3: Get depth and Draw Trajectory;
538
- - Step 4: Get camera pose from trajectory or customize it;
539
- - Step 5: Generate the final video.
540
- """
541
-
542
- gr.Markdown(outlines)
543
-
544
-
545
- with gr.Row():
546
- with gr.Column():
547
- # Step 1: Input Image
548
- step1_dec = """
549
- <font size="4"><b>Step 1: Input Image</b></font>
550
- - Select the region using a <mark>bounding box</mark>, aiming for a ratio close to </mark>320:576</mark> (height:width).
551
- - All provided images in `Examples` are in 320 x 576 resolution. Simply press `Process` to proceed.
552
- """
553
- step1 = gr.Markdown(step1_dec)
554
- raw_input = ImagePrompter(type="pil", label="Raw Image", show_label=True, interactive=True)
555
- # left_up_point = gr.Textbox(value = "-1 -1", label="Left Up Point", interactive=True)
556
- process_button = gr.Button("Process")
557
-
558
- with gr.Column():
559
- # Step 2: Get Subject Mask
560
- step2_dec = """
561
- <font size="4"><b>Step 2: Get Subject Mask</b></font>
562
- - Use the <mark>bounding boxes</mark> or <mark>paints</mark> to select the subject.
563
- - Press `Segment Subject` to get the mask. <mark>Can be refined iteratively by updating points<mark>.
564
- """
565
- step2 = gr.Markdown(step2_dec)
566
- canvas = ImagePrompter(type="pil", label="Input Image", show_label=True, interactive=True) # for mask painting
567
-
568
- select_button = gr.Button("Segment Subject")
569
-
570
- with gr.Row():
571
- with gr.Column():
572
- mask_dec = """
573
- <font size="4"><b>Mask Result</b></font>
574
- - Just for visualization purpose. No need to interact.
575
- """
576
- mask_vis = gr.Markdown(mask_dec)
577
- mask_output = gr.Image(type="pil", label="Mask", show_label=True, interactive=False)
578
- with gr.Column():
579
- # Step 3: Get Depth and Draw Trajectory
580
- step3_dec = """
581
- <font size="4"><b>Step 3: Get Depth and Draw Trajectory</b></font>
582
- - Press `Get Depth` to get the depth image.
583
- - Draw the trajectory by selecting points on the depth image. <mark>No more than 14 points</mark>.
584
- - Press `Undo point` to remove all points.
585
- """
586
- step3 = gr.Markdown(step3_dec)
587
- depth_image = gr.Image(type="pil", label="Depth Image", show_label=True, interactive=False)
588
- with gr.Row():
589
- depth_button = gr.Button("Get Depth")
590
- undo_button = gr.Button("Undo point")
591
-
592
- with gr.Row():
593
- with gr.Column():
594
- # Step 4: Trajectory to Camera Pose or Get Camera Pose
595
- step4_dec = """
596
- <font size="4"><b>Step 4: Get camera pose from trajectory or customize it</b></font>
597
- - Option 1: Transform the 2D trajectory to camera poses with depth. <mark>`Rescale` is used for depth alignment. Larger value can speed up the object motion.</mark>
598
- - Option 2: Rotate the camera with a specific `Angle`.
599
- - Option 3: Rotate the camera clockwise or counterclockwise with a specific `Angle`.
600
- - Option 4: Translate the camera with `Tx` (<mark>Pan Left/Right</mark>), `Ty` (<mark>Pan Up/Down</mark>), `Tz` (<mark>Zoom In/Out</mark>) and `Speed`.
601
- """
602
- step4 = gr.Markdown(step4_dec)
603
- camera_pose_vis = gr.Plot(None, label='Camera Pose')
604
- with gr.Row():
605
- with gr.Column():
606
- speed = gr.Slider(minimum=0.1, maximum=10, step=0.1, value=1.0, label="Speed", interactive=True)
607
- rescale = gr.Slider(minimum=0.0, maximum=10, step=0.1, value=1.0, label="Rescale", interactive=True)
608
- # traj2pose_button = gr.Button("Option1: Trajectory to Camera Pose")
609
-
610
- angle = gr.Slider(minimum=-360, maximum=360, step=1, value=60, label="Angle", interactive=True)
611
- # rotation_button = gr.Button("Option2: Rotate")
612
- # clockwise_button = gr.Button("Option3: Clockwise")
613
- with gr.Column():
614
-
615
- Tx = gr.Slider(minimum=-1, maximum=1, step=1, value=0, label="Tx", interactive=True)
616
- Ty = gr.Slider(minimum=-1, maximum=1, step=1, value=0, label="Ty", interactive=True)
617
- Tz = gr.Slider(minimum=-1, maximum=1, step=1, value=0, label="Tz", interactive=True)
618
- # translation_button = gr.Button("Option4: Translate")
619
- with gr.Row():
620
- camera_option = gr.Radio(choices = CAMERA_MODE, label='Camera Options', value=CAMERA_MODE[0], interactive=True)
621
- with gr.Row():
622
- get_camera_pose_button = gr.Button("Get Camera Pose")
623
-
624
- with gr.Column():
625
- # Step 5: Get the final generated video
626
- step5_dec = """
627
- <font size="4"><b>Step 5: Get the final generated video</b></font>
628
- - 3 modes for background: <mark>Fixed</mark>, <mark>Reverse</mark>, <mark>Free</mark>.
629
- - Enable <mark>Scale-wise Masks</mark> for better object control.
630
- - Option to enable <mark>Shared Warping Latents</mark> and set <mark>stop frequency</mark> for spatial (`ds`) and temporal (`dt`) dimensions. Larger stop frequency will lead to artifacts.
631
- """
632
- step5 = gr.Markdown(step5_dec)
633
- generated_video = gr.Video(None, label='Generated Video')
634
-
635
- with gr.Row():
636
- seed = gr.Textbox(value = "42", label="Seed", interactive=True)
637
- # num_inference_steps = gr.Slider(minimum=1, maximum=100, step=1, value=25, label="Number of Inference Steps", interactive=True)
638
- bg_mode = gr.Radio(choices = ["Fixed", "Reverse", "Free"], label="Background Mode", value="Fixed", interactive=True)
639
- # swl_mode = gr.Radio(choices = ["Enable SWL", "Disable SWL"], label="Shared Warping Latent", value="Disable SWL", interactive=True)
640
- scale_wise_masks = gr.Checkbox(label="Enable Scale-wise Masks", interactive=True, value=True)
641
- with gr.Row():
642
- with gr.Column():
643
- shared_wapring_latents = gr.Checkbox(label="Enable Shared Warping Latents", interactive=True)
644
- with gr.Column():
645
- ds = gr.Slider(minimum=0.0, maximum=1, step=0.1, value=0.5, label="ds", interactive=True)
646
- dt = gr.Slider(minimum=0.0, maximum=1, step=0.1, value=0.5, label="dt", interactive=True)
647
-
648
- generated_button = gr.Button("Generate")
649
-
650
-
651
-
652
- # # event definition
653
- process_button.click(
654
- fn = process_image,
655
- inputs = [raw_input],
656
- outputs = [original_image, canvas]
657
- )
658
-
659
- select_button.click(
660
- segment,
661
- [canvas, original_image, mask_logits],
662
- [mask, mask_output, masked_original_image, mask_logits]
663
- )
664
-
665
- depth_button.click(
666
- get_depth,
667
- [original_image, selected_points],
668
- [depth, depth_image, org_depth_image]
669
- )
670
-
671
- depth_image.select(
672
- get_points,
673
- [depth_image, selected_points],
674
- [depth_image, selected_points],
675
- )
676
- undo_button.click(
677
- undo_points,
678
- [org_depth_image],
679
- [depth_image, selected_points]
680
- )
681
-
682
- get_camera_pose_button.click(
683
- get_camera_pose(CAMERA_MODE),
684
- [camera_option, selected_points, depth, mask, rescale, angle, Tx, Ty, Tz, speed],
685
- [camera_pose, camera_pose_vis, rescale]
686
- )
687
-
688
- generated_button.click(
689
- run_objctrl_2_5d,
690
- [
691
- original_image,
692
- mask,
693
- depth,
694
- camera_pose,
695
- bg_mode,
696
- shared_wapring_latents,
697
- scale_wise_masks,
698
- rescale,
699
- seed,
700
- ds,
701
- dt,
702
- # num_inference_steps
703
- ],
704
- [generated_video],
705
- )
706
-
707
- gr.Examples(
708
- examples=examples,
709
- inputs=[
710
- raw_input,
711
- rescale,
712
- speed,
713
- angle,
714
- Tx,
715
- Ty,
716
- Tz,
717
- camera_option,
718
- bg_mode,
719
- shared_wapring_latents,
720
- scale_wise_masks,
721
- ds,
722
- dt,
723
- seed,
724
- selected_points_text # selected_points
725
- ],
726
- outputs=[generated_video],
727
- examples_per_page=10
728
- )
729
-
730
- selected_points_text.change(
731
- sync_points,
732
- inputs=[selected_points_text],
733
- outputs=[selected_points]
734
- )
735
-
736
-
737
- gr.Markdown(article)
738
-
739
-
740
- demo.queue().launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -771,7 +771,7 @@ with gr.Blocks() as demo:
771
  ## Get examples
772
  with open('./assets/examples/examples.json', 'r') as f:
773
  examples = json.load(f)
774
- print(examples)
775
 
776
  # examples = [examples]
777
  examples = [v for k, v in examples.items()]
 
771
  ## Get examples
772
  with open('./assets/examples/examples.json', 'r') as f:
773
  examples = json.load(f)
774
+ # print(examples)
775
 
776
  # examples = [examples]
777
  examples = [v for k, v in examples.items()]