wzhouxiff committed on
Commit
399e1c1
1 Parent(s): d5f3775

refined gradio

Files changed (3)
  1. app copy.py +740 -0
  2. app.py +222 -151
  3. objctrl_2_5d/utils/ui_utils.py +108 -22
app copy.py ADDED
@@ -0,0 +1,740 @@
1
+ try:
2
+ import spaces
3
+ except:
4
+ pass
5
+
6
+ import os
7
+ import gradio as gr
8
+
9
+ import torch
10
+ from gradio_image_prompter import ImagePrompter
11
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
12
+ from omegaconf import OmegaConf
13
+ from PIL import Image
14
+ import numpy as np
15
+ from copy import deepcopy
16
+ import cv2
17
+
18
+ import torch.nn.functional as F
19
+ import torchvision
20
+ from einops import rearrange
21
+ import tempfile
22
+
23
+ from objctrl_2_5d.utils.ui_utils import process_image, get_camera_pose, get_subject_points, get_points, undo_points, mask_image
24
+ from ZoeDepth.zoedepth.utils.misc import colorize
25
+
26
+ from cameractrl.inference import get_pipeline
27
+ from objctrl_2_5d.utils.examples import examples, sync_points
28
+
29
+ from objctrl_2_5d.utils.objmask_util import RT2Plucker, Unprojected, roll_with_ignore_multidim, dilate_mask_pytorch
30
+ from objctrl_2_5d.utils.filter_utils import get_freq_filter, freq_mix_3d
31
+
32
+
33
+ ### Title and Description ###
34
+ #### Description ####
35
+ title = r"""<h1 align="center">ObjCtrl-2.5D: Training-free Object Control with Camera Poses</h1>"""
36
+ # subtitle = r"""<h2 align="center">Deployed on SVD Generation</h2>"""
37
+ important_link = r"""
38
+ <div align='center'>
39
+ <a href='https://wzhouxiff.github.io/projects/MotionCtrl/assets/paper/MotionCtrl.pdf'>[Paper]</a>
40
+ &ensp; <a href='https://wzhouxiff.github.io/projects/MotionCtrl/'>[Project Page]</a>
41
+ &ensp; <a href='https://github.com/TencentARC/MotionCtrl'>[Code]</a>
42
+ </div>
43
+ """
44
+
45
+ authors = r"""
46
+ <div align='center'>
47
+ <a href='https://wzhouxiff.github.io/'>Zhouxia Wang</a>
48
+ &ensp; <a href='https://nirvanalan.github.io/'>Yushi Lan</a>
49
+ &ensp; <a href='https://shangchenzhou.com/'>Shangchen Zhou</a>
50
+ &ensp; <a href='https://www.mmlab-ntu.com/person/ccloy/index.html'>Chen Change Loy</a>
51
+ </div>
52
+ """
53
+
54
+ affiliation = r"""
55
+ <div align='center'>
56
+ <a href='https://www.mmlab-ntu.com/'>S-Lab, NTU Singapore</a>
57
+ </div>
58
+ """
59
+
60
+ description = r"""
61
+ <b>Official Gradio demo</b> for <a href='https://github.com/TencentARC/MotionCtrl' target='_blank'><b>ObjCtrl-2.5D: Training-free Object Control with Camera Poses</b></a>.<br>
62
+ 🔥 ObjCtrl-2.5D enables object motion control in an I2V-generated video by transforming 2D trajectories to 3D using depth and then converting them into camera poses,
63
+ thereby leveraging the existing camera motion control module for object motion control without requiring additional training.<br>
64
+ """
65
+
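The 2D-to-3D lifting mentioned in the description is standard pinhole back-projection: a trajectory point (u, v) with depth z is unprojected with the inverse intrinsics and then moved to world space with the inverse pose, mirroring what the generation code below does for the mask center. A minimal sketch, assuming a 3x3 intrinsics matrix and a 4x4 world-to-camera matrix (names are illustrative, not the repo's API):

```python
import numpy as np

def lift_point(u, v, z, K, w2c):
    """Back-project pixel (u, v) with depth z to a world-space point."""
    p_cam = np.linalg.inv(K) @ np.array([u, v, 1.0]) * z   # camera-space point
    p_cam_h = np.array([p_cam[0], p_cam[1], z, 1.0])       # homogeneous coordinates
    return np.linalg.inv(w2c) @ p_cam_h                    # world-space point

# Re-projecting this world point with per-frame world-to-camera poses reproduces the
# drawn 2D trajectory, which is how trajectories and camera poses are tied together.
```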
66
+ article = r"""
67
+ If ObjCtrl-2.5D is helpful, please ⭐ the <a href='https://github.com/TencentARC/MotionCtrl' target='_blank'>GitHub Repo</a>. Thanks!
68
+ [![GitHub Stars](https://img.shields.io/github/stars/TencentARC%2FMotionCtrl)](https://github.com/TencentARC/MotionCtrl)
69
+
70
+
71
+ ---
72
+
73
+ 📝 **Citation**
74
+ <br>
75
+ If our work is useful for your research, please consider citing:
76
+ ```bibtex
77
+ @inproceedings{wang2024motionctrl,
78
+ title={Motionctrl: A unified and flexible motion controller for video generation},
79
+ author={Wang, Zhouxia and Yuan, Ziyang and Wang, Xintao and Li, Yaowei and Chen, Tianshui and Xia, Menghan and Luo, Ping and Shan, Ying},
80
+ booktitle={ACM SIGGRAPH 2024 Conference Papers},
81
+ pages={1--11},
82
+ year={2024}
83
+ }
84
+ ```
85
+
86
+ 📧 **Contact**
87
+ <br>
88
+ If you have any questions, please feel free to reach me out at <b>zhouzi1212@gmail.com</b>.
89
+
90
+ """
91
+
92
+ # -------------- initialization --------------
93
+
94
+ CAMERA_MODE = ["Traj2Cam", "Rotate", "Clockwise", "Translate"]
95
+
96
+ # select the device for computation
97
+ if torch.cuda.is_available():
98
+ device = torch.device("cuda")
99
+ elif torch.backends.mps.is_available():
100
+ device = torch.device("mps")
101
+ else:
102
+ device = torch.device("cpu")
103
+ print(f"using device: {device}")
104
+
105
+ # segmentation model
106
+ segmentor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-tiny", cache_dir="ckpt", device=device)
107
+
108
+ # depth model
109
+ d_model_NK = torch.hub.load('./ZoeDepth', 'ZoeD_NK', source='local', pretrained=True).to(device)
110
+
111
+ # cameractrl model
112
+ config = "configs/svd_320_576_cameractrl.yaml"
113
+ model_id = "stabilityai/stable-video-diffusion-img2vid"
114
+ ckpt = "checkpoints/CameraCtrl_svd.ckpt"
115
+ if not os.path.exists(ckpt):
116
+ os.makedirs("checkpoints", exist_ok=True)
117
+ os.system("wget -c https://huggingface.co/hehao13/CameraCtrl_SVD_ckpts/resolve/main/CameraCtrl_svd.ckpt?download=true")
118
+ os.system("mv CameraCtrl_svd.ckpt?download=true checkpoints/CameraCtrl_svd.ckpt")
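The wget-and-rename above works, but the same checkpoint could also be fetched with `huggingface_hub`. This is only an alternative sketch, not part of the commit; the repo id and filename are taken from the URL above:

```python
# Alternative to the os.system() calls above: download via huggingface_hub.
from huggingface_hub import hf_hub_download

ckpt = hf_hub_download(
    repo_id="hehao13/CameraCtrl_SVD_ckpts",
    filename="CameraCtrl_svd.ckpt",
    local_dir="checkpoints",
)
```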
119
+ model_config = OmegaConf.load(config)
120
+
121
+
122
+ pipeline = get_pipeline(model_id, "unet", model_config['down_block_types'], model_config['up_block_types'],
123
+ model_config['pose_encoder_kwargs'], model_config['attention_processor_kwargs'],
124
+ ckpt, True, device)
125
+
126
+ # segmentor = None
127
+ # d_model_NK = None
128
+ # pipeline = None
129
+
130
+ ### run the demo ##
131
+ # @spaces.GPU(duration=5)
132
+ def segment(canvas, image, logits):
133
+ if logits is not None:
134
+ logits *= 32.0
135
+ _, points = get_subject_points(canvas)
136
+ image = np.array(image)
137
+
138
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
139
+ segmentor.set_image(image)
140
+ input_points = []
141
+ input_boxes = []
142
+ for p in points:
143
+ [x1, y1, _, x2, y2, _] = p
144
+ if x2==0 and y2==0:
145
+ input_points.append([x1, y1])
146
+ else:
147
+ input_boxes.append([x1, y1, x2, y2])
148
+ if len(input_points) == 0:
149
+ input_points = None
150
+ input_labels = None
151
+ else:
152
+ input_points = np.array(input_points)
153
+ input_labels = np.ones(len(input_points))
154
+ if len(input_boxes) == 0:
155
+ input_boxes = None
156
+ else:
157
+ input_boxes = np.array(input_boxes)
158
+ masks, _, logits = segmentor.predict(
159
+ point_coords=input_points,
160
+ point_labels=input_labels,
161
+ box=input_boxes,
162
+ multimask_output=False,
163
+ return_logits=True,
164
+ mask_input=logits,
165
+ )
166
+ mask = masks > 0
167
+ masked_img = mask_image(image, mask[0], color=[252, 140, 90], alpha=0.9)
168
+ masked_img = Image.fromarray(masked_img)
169
+
170
+ return mask[0], masked_img, masked_img, logits / 32.0
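The unpacking `[x1, y1, _, x2, y2, _] = p` in `segment` assumes the 6-element point format produced by `gradio_image_prompter`, where a plain click leaves the second corner at (0, 0) and a box prompt fills both corners. A small illustrative check; the coordinate values and type flags below are hypothetical:

```python
example_points = [
    [120, 80, 1, 0, 0, 0],     # click prompt at (120, 80): second corner stays (0, 0)
    [40, 30, 2, 300, 260, 3],  # box prompt from (40, 30) to (300, 260)
]
for p in example_points:
    x1, y1, _, x2, y2, _ = p
    kind = "point" if (x2 == 0 and y2 == 0) else "box"
    print(kind, (x1, y1), (x2, y2))
```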
171
+
172
+ # @spaces.GPU(duration=5)
173
+ def get_depth(image, points):
174
+
175
+ depth = d_model_NK.infer_pil(image)
176
+ colored_depth = colorize(depth, cmap='gray_r') # [h, w, 4] 0-255
177
+
178
+ depth_img = deepcopy(colored_depth[:, :, :3])
179
+ if len(points) > 0:
180
+ for idx, point in enumerate(points):
181
+ if idx % 2 == 0:
182
+ cv2.circle(depth_img, tuple(point), 10, (255, 0, 0), -1)
183
+ else:
184
+ cv2.circle(depth_img, tuple(point), 10, (0, 0, 255), -1)
185
+ if idx > 0:
186
+ cv2.arrowedLine(depth_img, points[idx-1], points[idx], (255, 255, 255), 4, tipLength=0.5)
187
+
188
+ return depth, depth_img, colored_depth[:, :, :3]
189
+
190
+
191
+ # @spaces.GPU(duration=80)
192
+ def run_objctrl_2_5d(condition_image,
193
+ mask,
194
+ depth,
195
+ RTs,
196
+ bg_mode,
197
+ shared_wapring_latents,
198
+ scale_wise_masks,
199
+ rescale,
200
+ seed,
201
+ ds, dt,
202
+ num_inference_steps=25):
203
+
204
+ DEBUG = False
205
+
206
+ if DEBUG:
207
+ cur_OUTPUT_PATH = 'outputs/tmp'
208
+ os.makedirs(cur_OUTPUT_PATH, exist_ok=True)
209
+
210
+ # num_inference_steps=25
211
+ min_guidance_scale = 1.0
212
+ max_guidance_scale = 3.0
213
+
214
+ area_ratio = 0.3
215
+ depth_scale_ = 5.2
216
+ center_margin = 10
217
+
218
+ height, width = 320, 576
219
+ num_frames = 14
220
+
221
+ intrinsics = np.array([[float(width), float(width), float(width) / 2, float(height) / 2]])
222
+ intrinsics = np.repeat(intrinsics, num_frames, axis=0) # [n_frame, 4]
223
+ fx = intrinsics[0, 0] / width
224
+ fy = intrinsics[0, 1] / height
225
+ cx = intrinsics[0, 2] / width
226
+ cy = intrinsics[0, 3] / height
227
+
228
+ down_scale = 8
229
+ H, W = height // down_scale, width // down_scale
230
+ K = np.array([[width / down_scale, 0, W / 2], [0, width / down_scale, H / 2], [0, 0, 1]])
231
+
232
+ seed = int(seed)
233
+
234
+ center_h_margin, center_w_margin = center_margin, center_margin
235
+ depth_center = np.mean(depth[height//2-center_h_margin:height//2+center_h_margin, width//2-center_w_margin:width//2+center_w_margin])
236
+
237
+ if rescale > 0:
238
+ depth_rescale = round(depth_scale_ * rescale / depth_center, 2)
239
+ else:
240
+ depth_rescale = 1.0
241
+
242
+ depth = depth * depth_rescale
243
+
244
+ depth_down = F.interpolate(torch.tensor(depth).unsqueeze(0).unsqueeze(0),
245
+ (H, W), mode='bilinear', align_corners=False).squeeze().numpy() # [H, W]
246
+
247
+ ## latent
248
+ generator = torch.Generator()
249
+ generator.manual_seed(seed)
250
+
251
+ latents_org = pipeline.prepare_latents(
252
+ 1,
253
+ 14,
254
+ 8,
255
+ height,
256
+ width,
257
+ pipeline.dtype,
258
+ device,
259
+ generator,
260
+ None,
261
+ )
262
+ latents_org = latents_org / pipeline.scheduler.init_noise_sigma
263
+
264
+ cur_plucker_embedding, _, _ = RT2Plucker(RTs, RTs.shape[0], (height, width), fx, fy, cx, cy) # 6, V, H, W
265
+ cur_plucker_embedding = cur_plucker_embedding.to(device)
266
+ cur_plucker_embedding = cur_plucker_embedding[None, ...] # b 6 f h w
267
+ cur_plucker_embedding = cur_plucker_embedding.permute(0, 2, 1, 3, 4) # b f 6 h w
268
+ cur_plucker_embedding = cur_plucker_embedding[:, :num_frames, ...]
269
+ cur_pose_features = pipeline.pose_encoder(cur_plucker_embedding)
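For context, `RT2Plucker` builds the per-pixel ray embedding that CameraCtrl-style pose encoders consume. In its standard form each pixel gets the 6-vector (o × d, d) from the camera center o and the viewing-ray direction d; a sketch of that formula follows (the exact convention and normalization inside `RT2Plucker` may differ):

```python
import numpy as np

def plucker_for_pixel(u, v, K, R, t):
    """Standard Plücker embedding of the viewing ray through pixel (u, v)."""
    d = R.T @ np.linalg.inv(K) @ np.array([u, v, 1.0])  # ray direction in world frame
    d = d / np.linalg.norm(d)
    o = -R.T @ t                                        # camera center in world frame
    return np.concatenate([np.cross(o, d), d])          # 6-dimensional embedding
```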
270
+
271
+ # bg_mode = ["Fixed", "Reverse", "Free"]
272
+ if bg_mode == "Fixed":
273
+ fix_RTs = np.repeat(RTs[0][None, ...], num_frames, axis=0) # [n_frame, 4, 3]
274
+ fix_plucker_embedding, _, _ = RT2Plucker(fix_RTs, num_frames, (height, width), fx, fy, cx, cy) # 6, V, H, W
275
+ fix_plucker_embedding = fix_plucker_embedding.to(device)
276
+ fix_plucker_embedding = fix_plucker_embedding[None, ...] # b 6 f h w
277
+ fix_plucker_embedding = fix_plucker_embedding.permute(0, 2, 1, 3, 4) # b f 6 h w
278
+ fix_plucker_embedding = fix_plucker_embedding[:, :num_frames, ...]
279
+ fix_pose_features = pipeline.pose_encoder(fix_plucker_embedding)
280
+
281
+ elif bg_mode == "Reverse":
282
+ bg_plucker_embedding, _, _ = RT2Plucker(RTs[::-1], RTs.shape[0], (height, width), fx, fy, cx, cy) # 6, V, H, W
283
+ bg_plucker_embedding = bg_plucker_embedding.to(device)
284
+ bg_plucker_embedding = bg_plucker_embedding[None, ...] # b 6 f h w
285
+ bg_plucker_embedding = bg_plucker_embedding.permute(0, 2, 1, 3, 4) # b f 6 h w
286
+ bg_plucker_embedding = bg_plucker_embedding[:, :num_frames, ...]
287
+ fix_pose_features = pipeline.pose_encoder(bg_plucker_embedding)
288
+
289
+ else:
290
+ fix_pose_features = None
291
+
292
+ #### preparing mask
293
+
294
+ mask = Image.fromarray(mask)
295
+ mask = mask.resize((W, H))
296
+ mask = np.array(mask).astype(np.float32)
297
+ mask = np.expand_dims(mask, axis=-1)
298
+
299
+ # visualize mask
300
+ if DEBUG:
301
+ mask_sum_vis = mask[..., 0]
302
+ mask_sum_vis = (mask_sum_vis * 255.0).astype(np.uint8)
303
+ mask_sum_vis = Image.fromarray(mask_sum_vis)
304
+
305
+ mask_sum_vis.save(f'{cur_OUTPUT_PATH}/org_mask.png')
306
+
307
+ try:
308
+ warped_masks = Unprojected(mask, depth_down, RTs, H=H, W=W, K=K)
309
+
310
+ warped_masks.insert(0, mask)
311
+
312
+ except:
313
+ # mask to bbox
314
+ print(f'!!! Mask is too small to warp; mask to bbox')
315
+ mask = mask[:, :, 0]
316
+ coords = cv2.findNonZero(mask)
317
+ x, y, w, h = cv2.boundingRect(coords)
318
+ # mask[y:y+h, x:x+w] = 1.0
319
+
320
+ center_x, center_y = x + w // 2, y + h // 2
321
+ center_z = depth_down[center_y, center_x]
322
+
323
+ # RTs [n_frame, 3, 4] to [n_frame, 4, 4] , add [0, 0, 0, 1]
324
+ RTs = np.concatenate([RTs, np.array([[[0, 0, 0, 1]]] * num_frames)], axis=1)
325
+
326
+ # RTs: world to camera
327
+ P0 = np.array([center_x, center_y, 1])
328
+ Pc0 = np.linalg.inv(K) @ P0 * center_z
329
+ pw = np.linalg.inv(RTs[0]) @ np.array([Pc0[0], Pc0[1], center_z, 1]) # [4]
330
+
331
+ P = [np.array([center_x, center_y])]
332
+ for i in range(1, num_frames):
333
+ Pci = RTs[i] @ pw
334
+ Pi = K @ Pci[:3] / Pci[2]
335
+ P.append(Pi[:2])
336
+
337
+ warped_masks = [mask]
338
+ for i in range(1, num_frames):
339
+ shift_x = int(round(P[i][0] - P[0][0]))
340
+ shift_y = int(round(P[i][1] - P[0][1]))
341
+
342
+ cur_mask = roll_with_ignore_multidim(mask, [shift_y, shift_x])
343
+ warped_masks.append(cur_mask)
344
+
345
+
346
+ warped_masks = [v[..., None] for v in warped_masks]
347
+
348
+ warped_masks = np.stack(warped_masks, axis=0) # [f, h, w]
349
+ warped_masks = np.repeat(warped_masks, 3, axis=-1) # [f, h, w, 3]
350
+
351
+ mask_sum = np.sum(warped_masks, axis=0, keepdims=True) # [1, H, W, 3]
352
+ mask_sum[mask_sum > 1.0] = 1.0
353
+ mask_sum = mask_sum[0,:,:, 0]
354
+
355
+ if DEBUG:
356
+ ## visualize warp mask
357
+ warp_masks_vis = torch.tensor(warped_masks)
358
+ warp_masks_vis = (warp_masks_vis * 255.0).to(torch.uint8)
359
+ torchvision.io.write_video(f'{cur_OUTPUT_PATH}/warped_masks.mp4', warp_masks_vis, fps=10, video_codec='h264', options={'crf': '10'})
360
+
361
+ # visualize mask
362
+ mask_sum_vis = mask_sum
363
+ mask_sum_vis = (mask_sum_vis * 255.0).astype(np.uint8)
364
+ mask_sum_vis = Image.fromarray(mask_sum_vis)
365
+
366
+ mask_sum_vis.save(f'{cur_OUTPUT_PATH}/merged_mask.png')
367
+
368
+ if scale_wise_masks:
369
+ min_area = H * W * area_ratio # cal in downscale
370
+ non_zero_len = mask_sum.sum()
371
+
372
+ print(f'non_zero_len: {non_zero_len}, min_area: {min_area}')
373
+
374
+ if non_zero_len > min_area:
375
+ kernel_sizes = [1, 1, 1, 3]
376
+ elif non_zero_len > min_area * 0.5:
377
+ kernel_sizes = [3, 1, 1, 5]
378
+ else:
379
+ kernel_sizes = [5, 3, 3, 7]
380
+ else:
381
+ kernel_sizes = [1, 1, 1, 1]
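`dilate_mask_pytorch`, used with these kernel sizes just below, is imported from `objctrl_2_5d.utils.objmask_util` and not shown in this diff. A common way to implement such a dilation, offered here only as a hypothetical stand-in, is max-pooling with stride 1 and "same" padding:

```python
import torch.nn.functional as F

def dilate_mask_sketch(mask, kernel_size=3):
    """Dilate a binary [B, C, H, W] mask; kernel_size is expected to be odd."""
    if kernel_size <= 1:
        return mask
    return F.max_pool2d(mask, kernel_size, stride=1, padding=kernel_size // 2)
```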
382
+
383
+ mask = torch.from_numpy(mask_sum) # [h, w]
384
+ mask = mask[None, None, ...] # [1, 1, h, w]
385
+ mask = F.interpolate(mask, (height, width), mode='bilinear', align_corners=False) # [1, 1, H, W]
386
+ # mask = mask.repeat(1, num_frames, 1, 1) # [1, f, H, W]
387
+ mask = mask.to(pipeline.dtype).to(device)
388
+
389
+ ##### Mask End ######
390
+
391
+ ### Got blending pose features Start ###
392
+
393
+ pose_features = []
394
+ for i in range(0, len(cur_pose_features)):
395
+ kernel_size = kernel_sizes[i]
396
+ h, w = cur_pose_features[i].shape[-2:]
397
+
398
+ if fix_pose_features is None:
399
+ pose_features.append(torch.zeros_like(cur_pose_features[i]))
400
+ else:
401
+ pose_features.append(fix_pose_features[i])
402
+
403
+ cur_mask = F.interpolate(mask, (h, w), mode='bilinear', align_corners=False)
404
+ cur_mask = dilate_mask_pytorch(cur_mask, kernel_size=kernel_size) # [1, 1, H, W]
405
+ cur_mask = cur_mask.repeat(1, num_frames, 1, 1) # [1, f, H, W]
406
+
407
+ if DEBUG:
408
+ # visualize mask
409
+ mask_vis = cur_mask[0, 0].cpu().numpy() * 255.0
410
+ mask_vis = Image.fromarray(mask_vis.astype(np.uint8))
411
+ mask_vis.save(f'{cur_OUTPUT_PATH}/mask_k{kernel_size}_scale{i}.png')
412
+
413
+ cur_mask = cur_mask[None, ...] # [1, 1, f, H, W]
414
+ pose_features[-1] = cur_pose_features[i] * cur_mask + pose_features[-1] * (1 - cur_mask)
415
+
416
+ ### Got blending pose features End ###
417
+
418
+ ##### Warp Noise Start ######
419
+
420
+ if shared_wapring_latents:
421
+ noise = latents_org[0, 0].data.cpu().numpy().copy() #[14, 4, 40, 72]
422
+ noise = np.transpose(noise, (1, 2, 0)) # [40, 72, 4]
423
+
424
+ try:
425
+ warp_noise = Unprojected(noise, depth_down, RTs, H=H, W=W, K=K)
426
+ warp_noise.insert(0, noise)
427
+ except:
428
+ print(f'!!! Noise is too small to warp; falling back to bbox shift')
429
+
430
+ warp_noise = [noise]
431
+ for i in range(1, num_frames):
432
+ shift_x = int(round(P[i][0] - P[0][0]))
433
+ shift_y = int(round(P[i][1] - P[0][1]))
434
+
435
+ cur_noise= roll_with_ignore_multidim(noise, [shift_y, shift_x])
436
+ warp_noise.append(cur_noise)
437
+
438
+ warp_noise = np.stack(warp_noise, axis=0) # [f, h, w, 4]
439
+
440
+ if DEBUG:
441
+ ## visualize warp noise
442
+ warp_noise_vis = torch.tensor(warp_noise)[..., :3] * torch.tensor(warped_masks)
443
+ warp_noise_vis = (warp_noise_vis - warp_noise_vis.min()) / (warp_noise_vis.max() - warp_noise_vis.min())
444
+ warp_noise_vis = (warp_noise_vis * 255.0).to(torch.uint8)
445
+
446
+ torchvision.io.write_video(f'{cur_OUTPUT_PATH}/warp_noise.mp4', warp_noise_vis, fps=10, video_codec='h264', options={'crf': '10'})
447
+
448
+
449
+ warp_latents = torch.tensor(warp_noise).permute(0, 3, 1, 2).to(latents_org.device).to(latents_org.dtype) # [frame, 4, H, W]
450
+ warp_latents = warp_latents.unsqueeze(0) # [1, frame, 4, H, W]
451
+
452
+ warped_masks = torch.tensor(warped_masks).permute(0, 3, 1, 2).unsqueeze(0) # [1, frame, 3, H, W]
453
+ mask_extend = torch.concat([warped_masks, warped_masks[:,:,0:1]], dim=2) # [1, frame, 4, H, W]
454
+ mask_extend = mask_extend.to(latents_org.device).to(latents_org.dtype)
455
+
456
+ warp_latents = warp_latents * mask_extend + latents_org * (1 - mask_extend)
457
+ warp_latents = warp_latents.permute(0, 2, 1, 3, 4)
458
+ random_noise = latents_org.clone().permute(0, 2, 1, 3, 4)
459
+
460
+ filter_shape = warp_latents.shape
461
+
462
+ freq_filter = get_freq_filter(
463
+ filter_shape,
464
+ device = device,
465
+ filter_type='butterworth',
466
+ n=4,
467
+ d_s=ds,
468
+ d_t=dt
469
+ )
470
+
471
+ warp_latents = freq_mix_3d(warp_latents, random_noise, freq_filter)
472
+ warp_latents = warp_latents.permute(0, 2, 1, 3, 4)
473
+
474
+ else:
475
+ warp_latents = latents_org.clone()
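`freq_mix_3d` with the Butterworth filter above keeps the low-frequency content of the warped latents and takes the high frequencies from fresh noise; `ds` and `dt` are the spatial and temporal cutoffs. The repo's filter is not reproduced here; the following is only a crude sketch of the idea, using a hard spherical cutoff instead of a Butterworth response:

```python
import torch

def simple_freq_mix(structured, random, cutoff=0.25):
    """Mix low frequencies of `structured` with high frequencies of `random` (last 3 dims)."""
    dims = (-3, -2, -1)
    S = torch.fft.fftshift(torch.fft.fftn(structured, dim=dims), dim=dims)
    R = torch.fft.fftshift(torch.fft.fftn(random, dim=dims), dim=dims)
    f, h, w = structured.shape[-3:]
    fz = torch.linspace(-1, 1, f).view(-1, 1, 1)
    fy = torch.linspace(-1, 1, h).view(1, -1, 1)
    fx = torch.linspace(-1, 1, w).view(1, 1, -1)
    lowpass = ((fz**2 + fy**2 + fx**2).sqrt() < cutoff).to(S.dtype)  # hard low-pass mask
    mixed = S * lowpass + R * (1 - lowpass)
    mixed = torch.fft.ifftn(torch.fft.ifftshift(mixed, dim=dims), dim=dims)
    return mixed.real
```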
476
+
477
+ generator.manual_seed(42)
478
+
479
+ with torch.no_grad():
480
+ result = pipeline(
481
+ image=condition_image,
482
+ pose_embedding=cur_plucker_embedding,
483
+ height=height,
484
+ width=width,
485
+ num_frames=num_frames,
486
+ num_inference_steps=num_inference_steps,
487
+ min_guidance_scale=min_guidance_scale,
488
+ max_guidance_scale=max_guidance_scale,
489
+ do_image_process=True,
490
+ generator=generator,
491
+ output_type='pt',
492
+ pose_features= pose_features,
493
+ latents = warp_latents
494
+ ).frames[0].cpu() #[f, c, h, w]
495
+
496
+
497
+ result = rearrange(result, 'f c h w -> f h w c')
498
+ result = (result * 255.0).to(torch.uint8)
499
+
500
+ video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
501
+ torchvision.io.write_video(video_path, result, fps=10, video_codec='h264', options={'crf': '8'})
502
+
503
+ return video_path
504
+
505
+ # -------------- UI definition --------------
506
+ with gr.Blocks() as demo:
507
+ # layout definition
508
+ gr.Markdown(title)
509
+ gr.Markdown(authors)
510
+ gr.Markdown(affiliation)
511
+ gr.Markdown(important_link)
512
+ gr.Markdown(description)
513
+
514
+
515
+ # with gr.Row():
516
+ # gr.Markdown("""# <center>Repositioning the Subject within Image </center>""")
517
+ mask = gr.State(value=None) # store mask
518
+ removal_mask = gr.State(value=None) # store removal mask
519
+ selected_points = gr.State([]) # store points
520
+ selected_points_text = gr.Textbox(label="Selected Points", visible=False)
521
+
522
+ original_image = gr.State(value=None) # store original input image
523
+ masked_original_image = gr.State(value=None) # store masked input image
524
+ mask_logits = gr.State(value=None) # store mask logits
525
+
526
+ depth = gr.State(value=None) # store depth
527
+ org_depth_image = gr.State(value=None) # store original depth image
528
+
529
+ camera_pose = gr.State(value=None) # store camera pose
530
+
531
+ with gr.Column():
532
+
533
+ outlines = """
534
+ <font size="5"><b>There are a total of 5 steps to complete the task.</b></font>
535
+ - Step 1: Input an image and crop it to a suitable size;
536
+ - Step 2: Attain the subject mask;
537
+ - Step 3: Get the depth and draw a trajectory;
538
+ - Step 4: Get camera pose from trajectory or customize it;
539
+ - Step 5: Generate the final video.
540
+ """
541
+
542
+ gr.Markdown(outlines)
543
+
544
+
545
+ with gr.Row():
546
+ with gr.Column():
547
+ # Step 1: Input Image
548
+ step1_dec = """
549
+ <font size="4"><b>Step 1: Input Image</b></font>
550
+ - Select the region using a <mark>bounding box</mark>, aiming for a ratio close to <mark>320:576</mark> (height:width).
551
+ - All provided images in `Examples` are in 320 x 576 resolution. Simply press `Process` to proceed.
552
+ """
553
+ step1 = gr.Markdown(step1_dec)
554
+ raw_input = ImagePrompter(type="pil", label="Raw Image", show_label=True, interactive=True)
555
+ # left_up_point = gr.Textbox(value = "-1 -1", label="Left Up Point", interactive=True)
556
+ process_button = gr.Button("Process")
557
+
558
+ with gr.Column():
559
+ # Step 2: Get Subject Mask
560
+ step2_dec = """
561
+ <font size="4"><b>Step 2: Get Subject Mask</b></font>
562
+ - Use <mark>bounding boxes</mark> or <mark>points</mark> to select the subject.
563
+ - Press `Segment Subject` to get the mask. <mark>It can be refined iteratively by updating the points</mark>.
564
+ """
565
+ step2 = gr.Markdown(step2_dec)
566
+ canvas = ImagePrompter(type="pil", label="Input Image", show_label=True, interactive=True) # for mask painting
567
+
568
+ select_button = gr.Button("Segment Subject")
569
+
570
+ with gr.Row():
571
+ with gr.Column():
572
+ mask_dec = """
573
+ <font size="4"><b>Mask Result</b></font>
574
+ - For visualization purposes only. No need to interact.
575
+ """
576
+ mask_vis = gr.Markdown(mask_dec)
577
+ mask_output = gr.Image(type="pil", label="Mask", show_label=True, interactive=False)
578
+ with gr.Column():
579
+ # Step 3: Get Depth and Draw Trajectory
580
+ step3_dec = """
581
+ <font size="4"><b>Step 3: Get Depth and Draw Trajectory</b></font>
582
+ - Press `Get Depth` to get the depth image.
583
+ - Draw the trajectory by selecting points on the depth image. <mark>No more than 14 points</mark>.
584
+ - Press `Undo point` to remove all points.
585
+ """
586
+ step3 = gr.Markdown(step3_dec)
587
+ depth_image = gr.Image(type="pil", label="Depth Image", show_label=True, interactive=False)
588
+ with gr.Row():
589
+ depth_button = gr.Button("Get Depth")
590
+ undo_button = gr.Button("Undo point")
591
+
592
+ with gr.Row():
593
+ with gr.Column():
594
+ # Step 4: Trajectory to Camera Pose or Get Camera Pose
595
+ step4_dec = """
596
+ <font size="4"><b>Step 4: Get camera pose from trajectory or customize it</b></font>
597
+ - Option 1: Transform the 2D trajectory into camera poses using depth. <mark>`Rescale` is used for depth alignment; a larger value speeds up the object motion.</mark>
598
+ - Option 2: Rotate the camera with a specific `Angle`.
599
+ - Option 3: Rotate the camera clockwise or counterclockwise with a specific `Angle`.
600
+ - Option 4: Translate the camera with `Tx` (<mark>Pan Left/Right</mark>), `Ty` (<mark>Pan Up/Down</mark>), `Tz` (<mark>Zoom In/Out</mark>) and `Speed`.
601
+ """
602
+ step4 = gr.Markdown(step4_dec)
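For intuition about Option 4 in the list above: a translation-only camera path can be written as a sequence of [R|t] extrinsics with identity rotation and a translation that grows with the chosen speed. This is only an illustrative sketch; the repo's `pan_and_zoom` helper may construct the poses differently:

```python
import numpy as np

def translate_poses(tx, ty, tz, speed, n=14):
    """Illustrative [n, 3, 4] world-to-camera poses for a pure translation."""
    poses = []
    for i in range(n):
        t = np.array([tx, ty, tz], dtype=np.float32) * speed * i / (n - 1)
        poses.append(np.concatenate([np.eye(3, dtype=np.float32), t[:, None]], axis=1))
    return np.stack(poses)
```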
603
+ camera_pose_vis = gr.Plot(None, label='Camera Pose')
604
+ with gr.Row():
605
+ with gr.Column():
606
+ speed = gr.Slider(minimum=0.1, maximum=10, step=0.1, value=1.0, label="Speed", interactive=True)
607
+ rescale = gr.Slider(minimum=0.0, maximum=10, step=0.1, value=1.0, label="Rescale", interactive=True)
608
+ # traj2pose_button = gr.Button("Option1: Trajectory to Camera Pose")
609
+
610
+ angle = gr.Slider(minimum=-360, maximum=360, step=1, value=60, label="Angle", interactive=True)
611
+ # rotation_button = gr.Button("Option2: Rotate")
612
+ # clockwise_button = gr.Button("Option3: Clockwise")
613
+ with gr.Column():
614
+
615
+ Tx = gr.Slider(minimum=-1, maximum=1, step=1, value=0, label="Tx", interactive=True)
616
+ Ty = gr.Slider(minimum=-1, maximum=1, step=1, value=0, label="Ty", interactive=True)
617
+ Tz = gr.Slider(minimum=-1, maximum=1, step=1, value=0, label="Tz", interactive=True)
618
+ # translation_button = gr.Button("Option4: Translate")
619
+ with gr.Row():
620
+ camera_option = gr.Radio(choices = CAMERA_MODE, label='Camera Options', value=CAMERA_MODE[0], interactive=True)
621
+ with gr.Row():
622
+ get_camera_pose_button = gr.Button("Get Camera Pose")
623
+
624
+ with gr.Column():
625
+ # Step 5: Get the final generated video
626
+ step5_dec = """
627
+ <font size="4"><b>Step 5: Get the final generated video</b></font>
628
+ - 3 modes for background: <mark>Fixed</mark>, <mark>Reverse</mark>, <mark>Free</mark>.
629
+ - Enable <mark>Scale-wise Masks</mark> for better object control.
630
+ - Optionally enable <mark>Shared Warping Latents</mark> and set the <mark>stop frequency</mark> for the spatial (`ds`) and temporal (`dt`) dimensions. A larger stop frequency can lead to artifacts.
631
+ """
632
+ step5 = gr.Markdown(step5_dec)
633
+ generated_video = gr.Video(None, label='Generated Video')
634
+
635
+ with gr.Row():
636
+ seed = gr.Textbox(value = "42", label="Seed", interactive=True)
637
+ # num_inference_steps = gr.Slider(minimum=1, maximum=100, step=1, value=25, label="Number of Inference Steps", interactive=True)
638
+ bg_mode = gr.Radio(choices = ["Fixed", "Reverse", "Free"], label="Background Mode", value="Fixed", interactive=True)
639
+ # swl_mode = gr.Radio(choices = ["Enable SWL", "Disable SWL"], label="Shared Warping Latent", value="Disable SWL", interactive=True)
640
+ scale_wise_masks = gr.Checkbox(label="Enable Scale-wise Masks", interactive=True, value=True)
641
+ with gr.Row():
642
+ with gr.Column():
643
+ shared_wapring_latents = gr.Checkbox(label="Enable Shared Warping Latents", interactive=True)
644
+ with gr.Column():
645
+ ds = gr.Slider(minimum=0.0, maximum=1, step=0.1, value=0.5, label="ds", interactive=True)
646
+ dt = gr.Slider(minimum=0.0, maximum=1, step=0.1, value=0.5, label="dt", interactive=True)
647
+
648
+ generated_button = gr.Button("Generate")
649
+
650
+
651
+
652
+ # # event definition
653
+ process_button.click(
654
+ fn = process_image,
655
+ inputs = [raw_input],
656
+ outputs = [original_image, canvas]
657
+ )
658
+
659
+ select_button.click(
660
+ segment,
661
+ [canvas, original_image, mask_logits],
662
+ [mask, mask_output, masked_original_image, mask_logits]
663
+ )
664
+
665
+ depth_button.click(
666
+ get_depth,
667
+ [original_image, selected_points],
668
+ [depth, depth_image, org_depth_image]
669
+ )
670
+
671
+ depth_image.select(
672
+ get_points,
673
+ [depth_image, selected_points],
674
+ [depth_image, selected_points],
675
+ )
676
+ undo_button.click(
677
+ undo_points,
678
+ [org_depth_image],
679
+ [depth_image, selected_points]
680
+ )
681
+
682
+ get_camera_pose_button.click(
683
+ get_camera_pose(CAMERA_MODE),
684
+ [camera_option, selected_points, depth, mask, rescale, angle, Tx, Ty, Tz, speed],
685
+ [camera_pose, camera_pose_vis, rescale]
686
+ )
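Note that `get_camera_pose(CAMERA_MODE)` is called here rather than passed by name: it is a factory that captures the mode list and returns the actual event handler (see `objctrl_2_5d/utils/ui_utils.py` below). A minimal, self-contained illustration of the same pattern with hypothetical names:

```python
import gradio as gr

def make_handler(modes):
    def handler(option):
        return f"selected {option} out of {len(modes)} modes"
    return handler

with gr.Blocks() as tiny_demo:
    choice = gr.Radio(choices=["A", "B"], value="A", label="Option")
    out = gr.Textbox(label="Result")
    gr.Button("Go").click(make_handler(["A", "B"]), [choice], [out])
```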
687
+
688
+ generated_button.click(
689
+ run_objctrl_2_5d,
690
+ [
691
+ original_image,
692
+ mask,
693
+ depth,
694
+ camera_pose,
695
+ bg_mode,
696
+ shared_wapring_latents,
697
+ scale_wise_masks,
698
+ rescale,
699
+ seed,
700
+ ds,
701
+ dt,
702
+ # num_inference_steps
703
+ ],
704
+ [generated_video],
705
+ )
706
+
707
+ gr.Examples(
708
+ examples=examples,
709
+ inputs=[
710
+ raw_input,
711
+ rescale,
712
+ speed,
713
+ angle,
714
+ Tx,
715
+ Ty,
716
+ Tz,
717
+ camera_option,
718
+ bg_mode,
719
+ shared_wapring_latents,
720
+ scale_wise_masks,
721
+ ds,
722
+ dt,
723
+ seed,
724
+ selected_points_text # selected_points
725
+ ],
726
+ outputs=[generated_video],
727
+ examples_per_page=10
728
+ )
729
+
730
+ selected_points_text.change(
731
+ sync_points,
732
+ inputs=[selected_points_text],
733
+ outputs=[selected_points]
734
+ )
735
+
736
+
737
+ gr.Markdown(article)
738
+
739
+
740
+ demo.queue().launch(share=True)
app.py CHANGED
@@ -1,12 +1,18 @@
1
- import spaces
2
  import os
3
  import gradio as gr
 
 
4
 
5
  import torch
6
  from gradio_image_prompter import ImagePrompter
7
  from sam2.sam2_image_predictor import SAM2ImagePredictor
8
  from omegaconf import OmegaConf
9
- from PIL import Image
10
  import numpy as np
11
  from copy import deepcopy
12
  import cv2
@@ -16,7 +22,7 @@ import torchvision
16
  from einops import rearrange
17
  import tempfile
18
 
19
- from objctrl_2_5d.utils.ui_utils import process_image, get_camera_pose, get_subject_points, get_points, undo_points, mask_image
20
  from ZoeDepth.zoedepth.utils.misc import colorize
21
 
22
  from cameractrl.inference import get_pipeline
@@ -25,7 +31,6 @@ from objctrl_2_5d.utils.examples import examples, sync_points
25
  from objctrl_2_5d.utils.objmask_util import RT2Plucker, Unprojected, roll_with_ignore_multidim, dilate_mask_pytorch
26
  from objctrl_2_5d.utils.filter_utils import get_freq_filter, freq_mix_3d
27
 
28
-
29
  ### Title and Description ###
30
  #### Description ####
31
  title = r"""<h1 align="center">ObjCtrl-2.5D: Training-free Object Control with Camera Poses</h1>"""
@@ -85,9 +90,40 @@ If you have any questions, please feel free to reach me out at <b>zhouzi1212@gma
85
 
86
  """
87
 
88
  # -------------- initialization --------------
89
 
90
- CAMERA_MODE = ["Traj2Cam", "Rotate", "Clockwise", "Translate"]
 
91
 
92
  # select the device for computation
93
  if torch.cuda.is_available():
@@ -96,11 +132,9 @@ elif torch.backends.mps.is_available():
96
  device = torch.device("mps")
97
  else:
98
  device = torch.device("cpu")
99
- device = torch.device("cuda")
100
- print(f"Force device to {device} due to ZeroGPU")
101
  print(f"using device: {device}")
102
 
103
- # segmentation model
104
  segmentor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-tiny", cache_dir="ckpt", device=device)
105
 
106
  # depth model
@@ -126,7 +160,7 @@ pipeline = get_pipeline(model_id, "unet", model_config['down_block_types'], mode
126
  # pipeline = None
127
 
128
  ### run the demo ##
129
- @spaces.GPU(duration=5)
130
  def segment(canvas, image, logits):
131
  if logits is not None:
132
  logits *= 32.0
@@ -165,28 +199,9 @@ def segment(canvas, image, logits):
165
  masked_img = mask_image(image, mask[0], color=[252, 140, 90], alpha=0.9)
166
  masked_img = Image.fromarray(masked_img)
167
 
168
- return mask[0], masked_img, masked_img, logits / 32.0
169
-
170
- @spaces.GPU(duration=5)
171
- def get_depth(image, points):
172
-
173
- depth = d_model_NK.infer_pil(image)
174
- colored_depth = colorize(depth, cmap='gray_r') # [h, w, 4] 0-255
175
-
176
- depth_img = deepcopy(colored_depth[:, :, :3])
177
- if len(points) > 0:
178
- for idx, point in enumerate(points):
179
- if idx % 2 == 0:
180
- cv2.circle(depth_img, tuple(point), 10, (255, 0, 0), -1)
181
- else:
182
- cv2.circle(depth_img, tuple(point), 10, (0, 0, 255), -1)
183
- if idx > 0:
184
- cv2.arrowedLine(depth_img, points[idx-1], points[idx], (255, 255, 255), 4, tipLength=0.5)
185
-
186
- return depth, depth_img, colored_depth[:, :, :3]
187
 
188
-
189
- @spaces.GPU(duration=80)
190
  def run_objctrl_2_5d(condition_image,
191
  mask,
192
  depth,
@@ -198,35 +213,6 @@ def run_objctrl_2_5d(condition_image,
198
  seed,
199
  ds, dt,
200
  num_inference_steps=25):
201
-
202
- DEBUG = False
203
-
204
- if DEBUG:
205
- cur_OUTPUT_PATH = 'outputs/tmp'
206
- os.makedirs(cur_OUTPUT_PATH, exist_ok=True)
207
-
208
- # num_inference_steps=25
209
- min_guidance_scale = 1.0
210
- max_guidance_scale = 3.0
211
-
212
- area_ratio = 0.3
213
- depth_scale_ = 5.2
214
- center_margin = 10
215
-
216
- height, width = 320, 576
217
- num_frames = 14
218
-
219
- intrinsics = np.array([[float(width), float(width), float(width) / 2, float(height) / 2]])
220
- intrinsics = np.repeat(intrinsics, num_frames, axis=0) # [n_frame, 4]
221
- fx = intrinsics[0, 0] / width
222
- fy = intrinsics[0, 1] / height
223
- cx = intrinsics[0, 2] / width
224
- cy = intrinsics[0, 3] / height
225
-
226
- down_scale = 8
227
- H, W = height // down_scale, width // down_scale
228
- K = np.array([[width / down_scale, 0, W / 2], [0, width / down_scale, H / 2], [0, 0, 1]])
229
-
230
  seed = int(seed)
231
 
232
  center_h_margin, center_w_margin = center_margin, center_margin
@@ -288,7 +274,7 @@ def run_objctrl_2_5d(condition_image,
288
  fix_pose_features = None
289
 
290
  #### preparing mask
291
-
292
  mask = Image.fromarray(mask)
293
  mask = mask.resize((W, H))
294
  mask = np.array(mask).astype(np.float32)
@@ -500,6 +486,97 @@ def run_objctrl_2_5d(condition_image,
500
 
501
  return video_path
502
 
503
  # -------------- UI definition --------------
504
  with gr.Blocks() as demo:
505
  # layout definition
@@ -513,12 +590,16 @@ with gr.Blocks() as demo:
513
  # with gr.Row():
514
  # gr.Markdown("""# <center>Repositioning the Subject within Image </center>""")
515
  mask = gr.State(value=None) # store mask
 
 
516
  removal_mask = gr.State(value=None) # store removal mask
517
  selected_points = gr.State([]) # store points
518
  selected_points_text = gr.Textbox(label="Selected Points", visible=False)
 
 
519
 
520
  original_image = gr.State(value=None) # store original input image
521
- masked_original_image = gr.State(value=None) # store masked input image
522
  mask_logits = gr.State(value=None) # store mask logits
523
 
524
  depth = gr.State(value=None) # store depth
@@ -526,14 +607,22 @@ with gr.Blocks() as demo:
526
 
527
  camera_pose = gr.State(value=None) # store camera pose
528
 
  with gr.Column():
530
 
531
  outlines = """
532
  <font size="5"><b>There are a total of 5 steps to complete the task.</b></font>
533
- - Step 1: Input an image and Crop it to a suitable size;
534
  - Step 2: Attain the subject mask;
535
- - Step 3: Get depth and Draw Trajectory;
536
- - Step 4: Get camera pose from trajectory or customize it;
537
  - Step 5: Generate the final video.
538
  """
539
 
@@ -545,125 +634,92 @@ with gr.Blocks() as demo:
545
  # Step 1: Input Image
546
  step1_dec = """
547
  <font size="4"><b>Step 1: Input Image</b></font>
548
- - Select the region using a <mark>bounding box</mark>, aiming for a ratio close to </mark>320:576</mark> (height:width).
549
- - All provided images in `Examples` are in 320 x 576 resolution. Simply press `Process` to proceed.
550
  """
551
  step1 = gr.Markdown(step1_dec)
552
  raw_input = ImagePrompter(type="pil", label="Raw Image", show_label=True, interactive=True)
553
- # left_up_point = gr.Textbox(value = "-1 -1", label="Left Up Point", interactive=True)
554
  process_button = gr.Button("Process")
555
 
556
  with gr.Column():
557
  # Step 2: Get Subject Mask
558
  step2_dec = """
559
  <font size="4"><b>Step 2: Get Subject Mask</b></font>
560
- - Use the <mark>bounding boxes</mark> or <mark>paints</mark> to select the subject.
561
- - Press `Segment Subject` to get the mask. <mark>Can be refined iteratively by updating points<mark>.
562
  """
563
  step2 = gr.Markdown(step2_dec)
564
  canvas = ImagePrompter(type="pil", label="Input Image", show_label=True, interactive=True) # for mask painting
565
 
 
  select_button = gr.Button("Segment Subject")
567
 
568
- with gr.Row():
569
- with gr.Column():
570
- mask_dec = """
571
- <font size="4"><b>Mask Result</b></font>
572
- - Just for visualization purpose. No need to interact.
573
- """
574
- mask_vis = gr.Markdown(mask_dec)
575
- mask_output = gr.Image(type="pil", label="Mask", show_label=True, interactive=False)
576
  with gr.Column():
577
  # Step 3: Get Depth and Draw Trajectory
578
  step3_dec = """
579
- <font size="4"><b>Step 3: Get Depth and Draw Trajectory</b></font>
580
- - Press `Get Depth` to get the depth image.
581
- - Draw the trajectory by selecting points on the depth image. <mark>No more than 14 points</mark>.
582
- - Press `Undo point` to remove all points.
583
  """
584
  step3 = gr.Markdown(step3_dec)
585
  depth_image = gr.Image(type="pil", label="Depth Image", show_label=True, interactive=False)
586
- with gr.Row():
587
- depth_button = gr.Button("Get Depth")
588
- undo_button = gr.Button("Undo point")
589
-
 
  with gr.Row():
 
591
  with gr.Column():
592
  # Step 4: Trajectory to Camera Pose or Get Camera Pose
593
  step4_dec = """
594
- <font size="4"><b>Step 4: Get camera pose from trajectory or customize it</b></font>
595
- - Option 1: Transform the 2D trajectory to camera poses with depth. <mark>`Rescale` is used for depth alignment. Larger value can speed up the object motion.</mark>
596
- - Option 2: Rotate the camera with a specific `Angle`.
597
- - Option 3: Rotate the camera clockwise or counterclockwise with a specific `Angle`.
598
- - Option 4: Translate the camera with `Tx` (<mark>Pan Left/Right</mark>), `Ty` (<mark>Pan Up/Down</mark>), `Tz` (<mark>Zoom In/Out</mark>) and `Speed`.
599
  """
600
  step4 = gr.Markdown(step4_dec)
601
  camera_pose_vis = gr.Plot(None, label='Camera Pose')
602
- with gr.Row():
603
- with gr.Column():
604
- speed = gr.Slider(minimum=0.1, maximum=10, step=0.1, value=1.0, label="Speed", interactive=True)
605
- rescale = gr.Slider(minimum=0.0, maximum=10, step=0.1, value=1.0, label="Rescale", interactive=True)
606
- # traj2pose_button = gr.Button("Option1: Trajectory to Camera Pose")
607
-
608
- angle = gr.Slider(minimum=-360, maximum=360, step=1, value=60, label="Angle", interactive=True)
609
- # rotation_button = gr.Button("Option2: Rotate")
610
- # clockwise_button = gr.Button("Option3: Clockwise")
611
- with gr.Column():
612
-
613
- Tx = gr.Slider(minimum=-1, maximum=1, step=1, value=0, label="Tx", interactive=True)
614
- Ty = gr.Slider(minimum=-1, maximum=1, step=1, value=0, label="Ty", interactive=True)
615
- Tz = gr.Slider(minimum=-1, maximum=1, step=1, value=0, label="Tz", interactive=True)
616
- # translation_button = gr.Button("Option4: Translate")
617
- with gr.Row():
618
- camera_option = gr.Radio(choices = CAMERA_MODE, label='Camera Options', value=CAMERA_MODE[0], interactive=True)
619
- with gr.Row():
620
- get_camera_pose_button = gr.Button("Get Camera Pose")
621
 
622
  with gr.Column():
623
  # Step 5: Get the final generated video
624
  step5_dec = """
625
  <font size="4"><b>Step 5: Get the final generated video</b></font>
626
- - 3 modes for background: <mark>Fixed</mark>, <mark>Reverse</mark>, <mark>Free</mark>.
627
- - Enable <mark>Scale-wise Masks</mark> for better object control.
628
- - Option to enable <mark>Shared Warping Latents</mark> and set <mark>stop frequency</mark> for spatial (`ds`) and temporal (`dt`) dimensions. Larger stop frequency will lead to artifacts.
629
  """
630
  step5 = gr.Markdown(step5_dec)
631
  generated_video = gr.Video(None, label='Generated Video')
632
 
633
- with gr.Row():
634
- seed = gr.Textbox(value = "42", label="Seed", interactive=True)
635
- # num_inference_steps = gr.Slider(minimum=1, maximum=100, step=1, value=25, label="Number of Inference Steps", interactive=True)
636
- bg_mode = gr.Radio(choices = ["Fixed", "Reverse", "Free"], label="Background Mode", value="Fixed", interactive=True)
637
- # swl_mode = gr.Radio(choices = ["Enable SWL", "Disable SWL"], label="Shared Warping Latent", value="Disable SWL", interactive=True)
638
- scale_wise_masks = gr.Checkbox(label="Enable Scale-wise Masks", interactive=True, value=True)
639
- with gr.Row():
640
- with gr.Column():
641
- shared_wapring_latents = gr.Checkbox(label="Enable Shared Warping Latents", interactive=True)
642
- with gr.Column():
643
- ds = gr.Slider(minimum=0.0, maximum=1, step=0.1, value=0.5, label="ds", interactive=True)
644
- dt = gr.Slider(minimum=0.0, maximum=1, step=0.1, value=0.5, label="dt", interactive=True)
645
 
646
  generated_button = gr.Button("Generate")
647
 
 
648
 
649
 
650
  # # event definition
651
  process_button.click(
652
  fn = process_image,
653
- inputs = [raw_input],
654
- outputs = [original_image, canvas]
655
  )
656
 
657
  select_button.click(
658
  segment,
659
  [canvas, original_image, mask_logits],
660
- [mask, mask_output, masked_original_image, mask_logits]
661
- )
662
-
663
- depth_button.click(
664
- get_depth,
665
- [original_image, selected_points],
666
- [depth, depth_image, org_depth_image]
667
  )
668
 
669
  depth_image.select(
@@ -677,9 +733,15 @@ with gr.Blocks() as demo:
677
  [depth_image, selected_points]
678
  )
679
 
680
- get_camera_pose_button.click(
  get_camera_pose(CAMERA_MODE),
682
- [camera_option, selected_points, depth, mask, rescale, angle, Tx, Ty, Tz, speed],
683
  [camera_pose, camera_pose_vis, rescale]
684
  )
685
 
@@ -701,35 +763,44 @@ with gr.Blocks() as demo:
701
  ],
702
  [generated_video],
703
  )
704
 
705
  gr.Examples(
706
  examples=examples,
707
  inputs=[
708
  raw_input,
709
- rescale,
710
- speed,
711
- angle,
712
- Tx,
713
- Ty,
714
- Tz,
715
  camera_option,
716
  bg_mode,
717
  shared_wapring_latents,
718
- scale_wise_masks,
719
- ds,
720
- dt,
721
- seed,
722
- selected_points_text # selected_points
723
  ],
724
- outputs=[generated_video],
725
- examples_per_page=10
726
  )
727
 
728
  selected_points_text.change(
729
- sync_points,
730
- inputs=[selected_points_text],
731
- outputs=[selected_points]
732
  )
 
 
733
 
734
 
735
  gr.Markdown(article)
 
1
+ try:
2
+ import spaces
3
+ except:
4
+ pass
5
+
6
  import os
7
  import gradio as gr
8
+ import json
9
+ import ast
10
 
11
  import torch
12
  from gradio_image_prompter import ImagePrompter
13
  from sam2.sam2_image_predictor import SAM2ImagePredictor
14
  from omegaconf import OmegaConf
15
+ from PIL import Image, ImageDraw
16
  import numpy as np
17
  from copy import deepcopy
18
  import cv2
 
22
  from einops import rearrange
23
  import tempfile
24
 
25
+ from objctrl_2_5d.utils.ui_utils import process_image, get_camera_pose, get_subject_points, get_points, undo_points, mask_image, traj2cam, get_mid_params
26
  from ZoeDepth.zoedepth.utils.misc import colorize
27
 
28
  from cameractrl.inference import get_pipeline
 
31
  from objctrl_2_5d.utils.objmask_util import RT2Plucker, Unprojected, roll_with_ignore_multidim, dilate_mask_pytorch
32
  from objctrl_2_5d.utils.filter_utils import get_freq_filter, freq_mix_3d
33
 
 
34
  ### Title and Description ###
35
  #### Description ####
36
  title = r"""<h1 align="center">ObjCtrl-2.5D: Training-free Object Control with Camera Poses</h1>"""
 
90
 
91
  """
92
 
93
+ # pre-defined parameters
94
+ DEBUG = False
95
+
96
+ if DEBUG:
97
+ cur_OUTPUT_PATH = 'outputs/tmp'
98
+ os.makedirs(cur_OUTPUT_PATH, exist_ok=True)
99
+
100
+ # num_inference_steps=25
101
+ min_guidance_scale = 1.0
102
+ max_guidance_scale = 3.0
103
+
104
+ area_ratio = 0.3
105
+ depth_scale_ = 5.2
106
+ center_margin = 10
107
+
108
+ height, width = 320, 576
109
+ num_frames = 14
110
+
111
+ intrinsics = np.array([[float(width), float(width), float(width) / 2, float(height) / 2]])
112
+ intrinsics = np.repeat(intrinsics, num_frames, axis=0) # [n_frame, 4]
113
+ fx = intrinsics[0, 0] / width
114
+ fy = intrinsics[0, 1] / height
115
+ cx = intrinsics[0, 2] / width
116
+ cy = intrinsics[0, 3] / height
117
+
118
+ down_scale = 8
119
+ H, W = height // down_scale, width // down_scale
120
+ K = np.array([[width / down_scale, 0, W / 2], [0, width / down_scale, H / 2], [0, 0, 1]])
121
+
122
+
123
  # -------------- initialization --------------
124
 
125
+ # CAMERA_MODE = ["Traj2Cam", "Rotate", "Clockwise", "Translate"]
126
+ CAMERA_MODE = ["None", "ZoomIn", "ZoomOut", "PanRight", "PanLeft", "TiltUp", "TiltDown", "ClockWise", "Anti-CW", "Rotate60"]
127
 
128
  # select the device for computation
129
  if torch.cuda.is_available():
 
132
  device = torch.device("mps")
133
  else:
134
  device = torch.device("cpu")
 
 
135
  print(f"using device: {device}")
136
 
137
+ # # segmentation model
138
  segmentor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-tiny", cache_dir="ckpt", device=device)
139
 
140
  # depth model
 
160
  # pipeline = None
161
 
162
  ### run the demo ##
163
+ # @spaces.GPU(duration=5)
164
  def segment(canvas, image, logits):
165
  if logits is not None:
166
  logits *= 32.0
 
199
  masked_img = mask_image(image, mask[0], color=[252, 140, 90], alpha=0.9)
200
  masked_img = Image.fromarray(masked_img)
201
 
202
+ return mask[0], {'image': masked_img, 'points': points}, logits / 32.0
203
 
204
+ # @spaces.GPU(duration=80)
 
205
  def run_objctrl_2_5d(condition_image,
206
  mask,
207
  depth,
 
213
  seed,
214
  ds, dt,
215
  num_inference_steps=25):
216
  seed = int(seed)
217
 
218
  center_h_margin, center_w_margin = center_margin, center_margin
 
274
  fix_pose_features = None
275
 
276
  #### preparing mask
277
+
278
  mask = Image.fromarray(mask)
279
  mask = mask.resize((W, H))
280
  mask = np.array(mask).astype(np.float32)
 
486
 
487
  return video_path
488
 
489
+
490
+ # UI function
491
+ # @spaces.GPU(duration=5)
492
+ def process_image(raw_image, trajectory_points):
493
+
494
+ image, points = raw_image['image'], raw_image['points']
495
+
496
+ print(points)
497
+
498
+ try:
499
+ assert(len(points)) == 1, "Please draw only one bbox"
500
+ [x1, y1, _, x2, y2, _] = points[0]
501
+
502
+ image = image.crop((x1, y1, x2, y2))
503
+ image = image.resize((width, height))
504
+ except:
505
+ image = image.resize((width, height))
506
+
507
+ depth = d_model_NK.infer_pil(image)
508
+ colored_depth = colorize(depth, cmap='gray_r') # [h, w, 4] 0-255
509
+
510
+ depth_img = deepcopy(colored_depth[:, :, :3])
511
+ if len(trajectory_points) > 0:
512
+ for idx, point in enumerate(trajectory_points):
513
+ if idx % 2 == 0:
514
+ cv2.circle(depth_img, tuple(point), 10, (255, 0, 0), -1)
515
+ else:
516
+ cv2.circle(depth_img, tuple(point), 10, (0, 0, 255), -1)
517
+ if idx > 0:
518
+ line_length = np.sqrt((trajectory_points[idx][0] - trajectory_points[idx-1][0])**2 + (trajectory_points[idx][1] - trajectory_points[idx-1][1])**2)
519
+ arrow_head_length = 10
520
+ tip_length = arrow_head_length / max(line_length, 1e-6)  # guard against zero-length segments
521
+ cv2.arrowedLine(depth_img, trajectory_points[idx-1], trajectory_points[idx], (0, 255, 0), 4, tipLength=tip_length)
522
+
523
+ return image, {'image': image}, depth, depth_img, colored_depth[:, :, :3]
524
+
525
+
526
+
527
+ def draw_points_on_image(img, points):
528
+ # img = Image.fromarray(np.array(image))
529
+ draw = ImageDraw.Draw(img)
530
+
531
+ for p in points:
532
+ x1, y1, _, x2, y2, _ = p
533
+
534
+ if x2 == 0 and y2 == 0:
535
+ # Point: cyan dot with a black outline
536
+ point_radius = 4
537
+ draw.ellipse(
538
+ (x1 - point_radius, y1 - point_radius, x1 + point_radius, y1 + point_radius),
539
+ fill="cyan", outline="black", width=1
540
+ )
541
+ else:
542
+ # Bounding Box: black rectangle outline
543
+ draw.rectangle([x1, y1, x2, y2], outline="black", width=3)
544
+
545
+ return img
546
+
547
+ # @spaces.GPU(duration=10)
548
+ def from_examples(raw_input, raw_image_points, canvas, seg_image_points, selected_points_text, camera_option, mask_bk):
549
+
550
+ selected_points = ast.literal_eval(selected_points_text)
551
+ mask = np.array(mask_bk)
552
+ mask = mask[:,:,0] > 0
553
+ selected_points = ast.literal_eval(selected_points_text)
554
+
555
+ image, _, depth, depth_img, colored_depth = process_image(raw_input, selected_points)
556
+
557
+ # get camera pose
558
+ if camera_option == "None":
559
+ # traj2came
560
+ rescale = 1.0
561
+ camera_pose, camera_pose_vis, rescale, _ = traj2cam(selected_points, depth , rescale)
562
+ else:
563
+ rescale = 0.0
564
+ angle = 60
565
+ speed = 4.0
566
+ camera_pose, camera_pose_vis, rescale = get_camera_pose(CAMERA_MODE)(camera_option, depth, mask, rescale, angle, speed)
567
+
568
+ raw_image_points = ast.literal_eval(raw_image_points)
569
+ seg_image_points = ast.literal_eval(seg_image_points)
570
+
571
+ raw_image = draw_points_on_image(raw_input['image'], raw_image_points)
572
+ seg_image = draw_points_on_image(canvas['image'], seg_image_points)
573
+
574
+ return image, mask, depth, depth_img, colored_depth, camera_pose, \
575
+ camera_pose_vis, rescale, selected_points, \
576
+ gr.update(value={'image': raw_image, 'points': raw_image_points}), \
577
+ gr.update(value={'image': seg_image, 'points': seg_image_points}), \
578
+
579
+
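`from_examples` above reconstructs UI state from hidden textboxes: the point lists arrive as Python-literal strings and are parsed with `ast.literal_eval`. A tiny illustration with hypothetical values:

```python
import ast

selected_points_text = "[[120, 200], [180, 160], [240, 140]]"  # hypothetical trajectory
selected_points = ast.literal_eval(selected_points_text)       # back to a list of [x, y]
assert selected_points[0] == [120, 200]
```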
580
  # -------------- UI definition --------------
581
  with gr.Blocks() as demo:
582
  # layout definition
 
590
  # with gr.Row():
591
  # gr.Markdown("""# <center>Repositioning the Subject within Image </center>""")
592
  mask = gr.State(value=None) # store mask
593
+ mask_bk = gr.Image(type="pil", label="Mask", show_label=True, interactive=False, visible=False)
594
+
595
  removal_mask = gr.State(value=None) # store removal mask
596
  selected_points = gr.State([]) # store points
597
  selected_points_text = gr.Textbox(label="Selected Points", visible=False)
598
+ raw_image_points = gr.Textbox(label="Raw Image Points", visible=False)
599
+ seg_image_points = gr.Textbox(label="Segment Image Points", visible=False)
600
 
601
  original_image = gr.State(value=None) # store original input image
602
+ # masked_original_image = gr.State(value=None) # store masked input image
603
  mask_logits = gr.State(value=None) # store mask logits
604
 
605
  depth = gr.State(value=None) # store depth
 
607
 
608
  camera_pose = gr.State(value=None) # store camera pose
609
 
610
+ rescale = gr.Slider(minimum=0.0, maximum=10, step=0.1, value=1.0, label="Rescale", interactive=True, visible=False)
611
+ angle = gr.Slider(minimum=-360, maximum=360, step=1, value=60, label="Angle", interactive=True, visible=False)
612
+
613
+ seed = gr.Textbox(value = "42", label="Seed", interactive=True, visible=False)
614
+ scale_wise_masks = gr.Checkbox(label="Enable Scale-wise Masks", interactive=True, value=True, visible=False)
615
+ ds = gr.Slider(minimum=0.0, maximum=1, step=0.1, value=0.25, label="ds", interactive=True, visible=False)
616
+ dt = gr.Slider(minimum=0.0, maximum=1, step=0.1, value=0.1, label="dt", interactive=True, visible=False)
617
+
618
  with gr.Column():
619
 
620
  outlines = """
621
  <font size="5"><b>There are a total of 5 steps to complete the task.</b></font>
622
+ - Step 1: Input an image, crop it to a suitable size, and obtain its depth;
623
  - Step 2: Attain the subject mask;
624
+ - Step 3: Draw a trajectory on the depth map, or skip this step to use a preset camera pose;
625
+ - Step 4: Select a customized camera pose, or skip;
626
  - Step 5: Generate the final video.
627
  """
628
 
 
634
  # Step 1: Input Image
635
  step1_dec = """
636
  <font size="4"><b>Step 1: Input Image</b></font>
 
 
637
  """
638
  step1 = gr.Markdown(step1_dec)
639
  raw_input = ImagePrompter(type="pil", label="Raw Image", show_label=True, interactive=True)
640
+
641
+ step1_notes = """
642
+ - Select the region using a <mark>bounding box</mark>, aiming for a ratio close to <mark>320:576</mark> (height:width).
643
+ - If the input is already 320 x 576, press `Process` directly.
644
+ """
645
+ notes = gr.Markdown(step1_notes)
646
+
647
  process_button = gr.Button("Process")
648
 
649
  with gr.Column():
650
  # Step 2: Get Subject Mask
651
  step2_dec = """
652
  <font size="4"><b>Step 2: Get Subject Mask</b></font>
 
 
653
  """
654
  step2 = gr.Markdown(step2_dec)
655
  canvas = ImagePrompter(type="pil", label="Input Image", show_label=True, interactive=True) # for mask painting
656
 
657
+ step2_notes = """
658
+ - Use the <mark>bounding boxes</mark> or <mark>points</mark> to select the subject.
659
+ - Press `Segment Subject` to get the mask. <mark>It can be refined iteratively by updating the points</mark>.
660
+ """
661
+ notes = gr.Markdown(step2_notes)
662
+
663
  select_button = gr.Button("Segment Subject")
664
 
  with gr.Column():
666
  # Step 3: Get Depth and Draw Trajectory
667
  step3_dec = """
668
+ <font size="4"><b>Step 3: Draw Trajectory on Depth or <mark>SKIP</mark></b></font>
669
+
 
 
670
  """
671
  step3 = gr.Markdown(step3_dec)
672
  depth_image = gr.Image(type="pil", label="Depth Image", show_label=True, interactive=False)
673
+
674
+ step3_dec = """
675
+ - Select points on the depth image to draw the trajectory. <mark>No more than 14 points</mark>.
676
+ - Press `Undo point` to remove all points. Press `Traj2Cam` to get camera poses.
677
+ """
678
+ notes = gr.Markdown(step3_dec)
679
+
680
+ undo_button = gr.Button("Undo point")
681
+ traj2cam_button = gr.Button("Traj2Cam")
682
+
683
  with gr.Row():
684
+
685
  with gr.Column():
686
  # Step 4: Trajectory to Camera Pose or Get Camera Pose
687
  step4_dec = """
688
+ <font size="4"><b>Step 4: Get Customized Camera Poses or <mark>Skip</mark></b></font>
  """
690
  step4 = gr.Markdown(step4_dec)
691
  camera_pose_vis = gr.Plot(None, label='Camera Pose')
692
+ camera_option = gr.Radio(choices = CAMERA_MODE, label='Camera Options', value=CAMERA_MODE[0], interactive=True)
693
+ speed = gr.Slider(minimum=0.1, maximum=10, step=0.1, value=4.0, label="Speed", interactive=True, visible=True)
694
 
695
  with gr.Column():
696
  # Step 5: Get the final generated video
697
  step5_dec = """
698
  <font size="4"><b>Step 5: Get the final generated video</b></font>
  """
700
  step5 = gr.Markdown(step5_dec)
701
  generated_video = gr.Video(None, label='Generated Video')
702
 
703
+ # with gr.Row():
704
+ bg_mode = gr.Radio(choices = ["Fixed", "Reverse", "Free"], label="Background Mode", value="Fixed", interactive=True)
705
+ shared_wapring_latents = gr.Checkbox(label="Enable Shared Warping Latents", interactive=True, value=False, visible=True)
706
 
707
  generated_button = gr.Button("Generate")
708
 
709
+ get_mid_params_button = gr.Button("Get Mid Params")
710
 
711
 
712
  # # event definition
713
  process_button.click(
714
  fn = process_image,
715
+ inputs = [raw_input, selected_points],
716
+ outputs = [original_image, canvas, depth, depth_image, org_depth_image]
717
  )
718
 
719
  select_button.click(
720
  segment,
721
  [canvas, original_image, mask_logits],
722
+ [mask, canvas, mask_logits]
723
  )
724
 
725
  depth_image.select(
 
733
  [depth_image, selected_points]
734
  )
735
 
736
+ traj2cam_button.click(
737
+ traj2cam,
738
+ [selected_points, depth, rescale],
739
+ [camera_pose, camera_pose_vis, rescale, camera_option]
740
+ )
741
+
742
+ camera_option.change(
743
  get_camera_pose(CAMERA_MODE),
744
+ [camera_option, depth, mask, rescale, angle, speed],
745
  [camera_pose, camera_pose_vis, rescale]
746
  )
747
 
 
763
  ],
764
  [generated_video],
765
  )
766
+
767
+ get_mid_params_button.click(
768
+ get_mid_params,
769
+ [raw_input, canvas, mask, selected_points, camera_option, bg_mode, shared_wapring_latents, generated_video]
770
+ )
771
+
772
+ ## Get examples
773
+ with open('./assets/examples/examples.json', 'r') as f:
774
+ examples = json.load(f)
775
+ print(examples)
776
+
777
+ # examples = [examples]
778
+ examples = [v for k, v in examples.items()]
 
 gr.Examples(
 examples=examples,
 inputs=[
 raw_input,
+ raw_image_points,
+ canvas,
+ seg_image_points,
+ mask_bk,
+ selected_points_text, # selected_points
 camera_option,
 bg_mode,
 shared_wapring_latents,
+ generated_video
 ],
+ examples_per_page=20
 )
 
 selected_points_text.change(
+ from_examples,
+ inputs=[raw_input, raw_image_points, canvas, seg_image_points, selected_points_text, camera_option, mask_bk],
+ outputs=[original_image, mask, depth, depth_image, org_depth_image, camera_pose, camera_pose_vis, rescale, selected_points, raw_input, canvas]
 )
+
+
 
 gr.Markdown(article)
objctrl_2_5d/utils/ui_utils.py CHANGED
@@ -9,6 +9,7 @@ from objctrl_2_5d.utils.vis_camera import vis_camera_rescale
 from objctrl_2_5d.utils.objmask_util import trajectory_to_camera_poses_v1
 from objctrl_2_5d.utils.customized_cam import rotation, clockwise, pan_and_zoom
 
+ CAMERA_MODE = ["None", "ZoomIn", "ZoomOut", "PanRight", "PanLeft", "TiltUp", "TiltDown", "ClockWise", "Anti-CW", "Rotate60"]
 
 zc_threshold = 0.2
 depth_scale_ = 5.2
@@ -29,8 +30,6 @@ def process_image(raw_image):
 
 image, points = raw_image['image'], raw_image['points']
 
- print(points)
-
 try:
 assert(len(points)) == 1, "Please select only one point"
 [x1, y1, _, x2, y2, _] = points[0]
@@ -88,7 +87,10 @@ def get_points(img,
 # draw an arrow from handle point to target point
 # if len(points) == idx + 1:
 if idx > 0:
- cv2.arrowedLine(img, points[idx-1], points[idx], (255, 255, 255), 4, tipLength=0.5)
+ line_length = np.sqrt((points[idx][0] - points[idx-1][0])**2 + (points[idx][1] - points[idx-1][1])**2)
+ arrow_head_length = 10
+ tip_length = arrow_head_length / line_length
+ cv2.arrowedLine(img, points[idx-1], points[idx], (0, 255, 0), 4, tipLength=tip_length)
 # points = []
 
 return img if isinstance(img, np.ndarray) else np.array(img), sel_pix
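# Aside (not part of this commit): why tipLength is now computed per segment.
# cv2.arrowedLine treats tipLength as a *fraction* of the segment length, so the old fixed
# 0.5 drew oversized heads on long strokes; dividing a fixed pixel size by the segment
# length keeps the head at roughly `arrow_head_length` pixels. A guard for coincident
# points (my addition, not in the diff) avoids a division by zero:
import numpy as np

def arrow_tip_length(p_prev, p_cur, arrow_head_length=10.0):
    # distance between consecutive trajectory points, in pixels
    line_length = np.hypot(p_cur[0] - p_prev[0], p_cur[1] - p_prev[1])
    return 0.5 if line_length < 1e-6 else arrow_head_length / line_length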
@@ -113,6 +115,9 @@ def interpolate_points(points, num_points):
 
 def traj2cam(traj, depth, rescale):
 
+ if len(traj) == 0:
+ return None, None, 0.0, gr.update(value=CAMERA_MODE[0])
+
 traj = np.array(traj)
 trajectory = interpolate_points(traj, num_frames)
 
@@ -148,13 +153,13 @@ def traj2cam(traj, depth, rescale):
 RTs = traj_w2c[:, :3]
 fig = vis_camera_rescale(RTs)
 
- return RTs, fig, rescale
+ return RTs, fig, rescale, gr.update(value=CAMERA_MODE[0])
 
 def get_rotate_cam(angle, depth):
 # mean_depth = np.mean(depth * mask)
 center_h_margin, center_w_margin = center_margin, center_margin
 depth_center = np.mean(depth[height//2-center_h_margin:height//2+center_h_margin, width//2-center_w_margin:width//2+center_w_margin])
- print(f'rotate depth_center: {depth_center}')
+ # print(f'rotate depth_center: {depth_center}')
 
 RTs = rotation(num_frames, angle, depth_center, depth_center)
 fig = vis_camera_rescale(RTs)
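# Aside (not part of this commit): traj2cam now returns a fourth value that feeds the
# camera_option Radio (see the traj2cam_button.click wiring in app.py above). In Gradio,
# returning gr.update(value=...) for an output component changes its value without
# recreating it, which is how the radio is snapped back to CAMERA_MODE[0] ("None") after a
# trajectory is converted. A minimal, self-contained illustration (component names are mine):
import gradio as gr

def reset_choice():
    # update only the value of the bound Radio component
    return gr.update(value="None")

with gr.Blocks() as demo:
    choice = gr.Radio(["None", "ZoomIn", "ZoomOut"], value="ZoomIn")
    gr.Button("Reset").click(reset_choice, inputs=None, outputs=choice)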
@@ -162,47 +167,128 @@ def get_rotate_cam(angle, depth):
 return RTs, fig
 
 def get_clockwise_cam(angle, depth, mask):
- mask = mask.astype(np.float32) # [0, 1]
- mean_depth = np.mean(depth * mask)
+ # mask = mask.astype(np.float32) # [0, 1]
+ # mean_depth = np.mean(depth * mask)
 # center_h_margin, center_w_margin = center_margin, center_margin
 # depth_center = np.mean(depth[height//2-center_h_margin:height//2+center_h_margin, width//2-center_w_margin:width//2+center_w_margin])
 
 RTs = clockwise(angle, num_frames)
 
- RTs[:, -1, -1] = mean_depth
+ # RTs[:, -1, -1] = mean_depth
 fig = vis_camera_rescale(RTs)
 
 return RTs, fig
 
 def get_translate_cam(Tx, Ty, Tz, depth, mask, speed):
- mask = mask.astype(np.float32) # [0, 1]
+ # mask = mask.astype(np.float32) # [0, 1]
 
- mean_depth = np.mean(depth * mask)
+ # mean_depth = np.mean(depth * mask)
 
 T = np.array([Tx, Ty, Tz])
 T = T.reshape(3, 1)
 T = T[None, ...].repeat(num_frames, axis=0)
 
 RTs = pan_and_zoom(T, speed, n=num_frames)
- RTs[:, -1, -1] += mean_depth
+ # RTs[:, -1, -1] += mean_depth
 fig = vis_camera_rescale(RTs)
 
 return RTs, fig
 
+
 def get_camera_pose(camera_mode):
- def trigger_camera_pose(camera_option, selected_points, depth, mask, rescale, angle, Tx, Ty, Tz, speed):
- if camera_option == camera_mode[0]: # traj2cam
- RTs, fig, rescale = traj2cam(selected_points, depth, rescale)
- elif camera_option == camera_mode[1]: # rotate
- RTs, fig = get_rotate_cam(angle, depth)
- rescale = 0.0
- elif camera_option == camera_mode[2]: # clockwise
+ # camera_mode = ["None", "ZoomIn", "ZoomOut", "PanLeft", "PanRight", "TiltUp", "TiltDown", "ClockWise", "Anti-CW", "Rotate60"]
+ def trigger_camera_pose(camera_option, depth, mask, rescale, angle, speed):
+ if camera_option == camera_mode[0]: # None
+ RTs = None
+ fig = None
+ elif camera_option == camera_mode[1]: # ZoomIn
+ RTs, fig = get_translate_cam(0, 0, -1, depth, mask, speed)
+
+ elif camera_option == camera_mode[2]: # ZoomOut
+ RTs, fig = get_translate_cam(0, 0, 1, depth, mask, speed)
+
+ elif camera_option == camera_mode[3]: # PanLeft
+ RTs, fig = get_translate_cam(-1, 0, 0, depth, mask, speed)
+
+ elif camera_option == camera_mode[4]: # PanRight
+ RTs, fig = get_translate_cam(1, 0, 0, depth, mask, speed)
+
+ elif camera_option == camera_mode[5]: # TiltUp
+ RTs, fig = get_translate_cam(0, 1, 0, depth, mask, speed)
+
+ elif camera_option == camera_mode[6]: # TiltDown
+ RTs, fig = get_translate_cam(0, -1, 0, depth, mask, speed)
+
+ elif camera_option == camera_mode[7]: # ClockWise
+ RTs, fig = get_clockwise_cam(-angle, depth, mask)
+
+ elif camera_option == camera_mode[8]: # Anti-CW
 RTs, fig = get_clockwise_cam(angle, depth, mask)
- rescale = 0.0
- elif camera_option == camera_mode[3]: # translate
- RTs, fig = get_translate_cam(Tx, Ty, Tz, depth, mask, speed)
- rescale = 0.0
+
+ else: # Rotate60
+ RTs, fig = get_rotate_cam(angle, depth)
 
+ rescale = 0.0
 return RTs, fig, rescale
 
 return trigger_camera_pose
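# Aside (not part of this commit): how the refactored closure is meant to be used. The
# factory is called once with the mode list and the resulting callback is handed to Gradio,
# so its parameter order must match the event's `inputs` and its three return values the
# `outputs`. Values below are placeholders:
import numpy as np
from objctrl_2_5d.utils.ui_utils import get_camera_pose, CAMERA_MODE

depth_map = np.ones((256, 256), dtype=np.float32)   # placeholder depth
subject_mask = np.ones((256, 256), dtype=np.uint8)  # placeholder mask

trigger = get_camera_pose(CAMERA_MODE)
RTs, fig, rescale = trigger("ZoomIn", depth_map, subject_mask, rescale=1.0, angle=60, speed=4.0)
# RTs: per-frame camera extrinsics, fig: the visualization shown in gr.Plot, rescale: forced to 0.0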
+
+ import os
+ from glob import glob
+ import json
+
+ def get_mid_params(raw_input, canvas, mask, selected_points, camera_option, bg_mode, shared_wapring_latents, generated_video):
+ output_dir = "./assets/examples"
+ os.makedirs(output_dir, exist_ok=True)
+
+ # folders = sorted(glob(output_dir + "/*"))
+ folders = os.listdir(output_dir)
+ folders = [int(folder) for folder in folders if os.path.isdir(os.path.join(output_dir, folder))]
+ num = sorted(folders)[-1] + 1 if folders else 0
+
+ fout = open(os.path.join(output_dir, f'examples.json'), 'a+')
+
+ cur_folder = os.path.join(output_dir, f'{num:05d}')
+ os.makedirs(cur_folder, exist_ok=True)
+
+ raw_image = raw_input['image']
+ raw_points = raw_input['points']
+ seg_image = canvas['image']
+ seg_points = canvas['points']
+
+ mask = Image.fromarray(mask)
+ mask_path = os.path.join(cur_folder, 'mask.png')
+ mask.save(mask_path)
+
+ raw_image_path = os.path.join(cur_folder, 'raw_image.png')
+ seg_image_path = os.path.join(cur_folder, 'seg_image.png')
+
+ raw_image.save(os.path.join(cur_folder, 'raw_image.png'))
+ seg_image.save(os.path.join(cur_folder, 'seg_image.png'))
+
+ gen_path = os.path.join(cur_folder, 'generated_video.mp4')
+ cmd = f"cp {generated_video} {gen_path}"
+ os.system(cmd)
+
+ # data = [{'image': raw_image_path, 'points': raw_points},
+ # {'image': seg_image_path, 'points': seg_points},
+ # mask_path,
+ # str(selected_points),
+ # camera_option,
+ # bg_mode,
+ # gen_path]
+ data = {f'{num:05d}': [{'image': raw_image_path},
+ str(raw_points),
+ {'image': seg_image_path},
+ str(seg_points),
+ mask_path,
+ str(selected_points),
+ camera_option,
+ bg_mode,
+ shared_wapring_latents,
+ gen_path]}
+ fout.write(json.dumps(data) + '\n')
+
+ fout.close()
+
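# Aside (not part of this commit): the shell `cp` used above assumes a POSIX environment.
# A platform-independent sketch using only the standard library (function name is mine):
import shutil

def copy_generated_video(generated_video: str, gen_path: str) -> None:
    # shutil.copyfile copies the file contents and overwrites gen_path if it already exists
    shutil.copyfile(generated_video, gen_path)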