import os
import torch
import numpy as np
import torch.nn.functional as F
import cv2
import torchvision
from PIL import Image
from einops import rearrange
import tempfile

from objctrl_2_5d.utils.objmask_util import RT2Plucker, Unprojected, roll_with_ignore_multidim, dilate_mask_pytorch
from objctrl_2_5d.utils.filter_utils import get_freq_filter, freq_mix_3d

DEBUG = False

# Directory for debug artefacts. The name is always bound so that code
# referencing it never raises NameError; the directory itself is only created
# (and written to) when DEBUG is enabled.
cur_OUTPUT_PATH = 'outputs/tmp'
if DEBUG:
    os.makedirs(cur_OUTPUT_PATH, exist_ok=True)

# num_inference_steps=25
min_guidance_scale = 1.0
max_guidance_scale = 3.0

area_ratio = 0.3
depth_scale_ = 5.2
center_margin = 10

height, width = 320, 576
num_frames = 14

intrinsics = np.array([[float(width), float(width), float(width) / 2, float(height) / 2]])
intrinsics = np.repeat(intrinsics, num_frames, axis=0)  # [n_frame, 4]
fx = intrinsics[0, 0] / width
fy = intrinsics[0, 1] / height
cx = intrinsics[0, 2] / width
cy = intrinsics[0, 3] / height

down_scale = 8
H, W = height // down_scale, width // down_scale
K = np.array([[width / down_scale, 0, W / 2], [0, width / down_scale, H / 2], [0, 0, 1]])


def _plucker_embedding(RTs, n_views, device):
    """Build the Plucker embedding for ``RTs`` laid out as [b, f, 6, h, w].

    Wraps ``RT2Plucker`` with the module-level image size and normalised
    intrinsics, moves the result to ``device``, adds a batch axis, swaps the
    channel/frame axes and keeps at most ``num_frames`` frames.
    """
    embedding, _, _ = RT2Plucker(RTs, n_views, (height, width), fx, fy, cx, cy)  # 6, V, H, W
    embedding = embedding.to(device)
    embedding = embedding[None, ...]  # b 6 f h w
    embedding = embedding.permute(0, 2, 1, 3, 4)  # b f 6 h w
    return embedding[:, :num_frames, ...]


def _to_homogeneous(RTs):
    """Append a [0, 0, 0, 1] row to every pose: [n_frame, 3, 4] -> [n_frame, 4, 4].

    The appended block has ``num_frames`` entries, so ``RTs`` must contain
    exactly ``num_frames`` poses.
    """
    bottom_rows = np.array([[[0, 0, 0, 1]]] * num_frames)
    return np.concatenate([RTs, bottom_rows], axis=1)


def _track_mask_center(mask_2d, depth_down, RTs_h):
    """Project the mask's bounding-box centre into every frame.

    The centre pixel of the bounding rectangle of the non-zero mask region is
    lifted with the depth sampled at that pixel and the inverse of the first
    pose, then re-projected with ``K`` through each remaining pose.

    Args:
        mask_2d: single-channel [H, W] mask at the downscaled resolution.
        depth_down: [H, W] depth map at the downscaled resolution.
        RTs_h: [n_frame, 4, 4] poses (treated as world-to-camera).

    Returns:
        A list of ``num_frames`` 2-vectors (x, y); entry 0 is the centre itself.
    """
    coords = cv2.findNonZero(mask_2d)
    x, y, w, h = cv2.boundingRect(coords)
    # mask[y:y+h, x:x+w] = 1.0

    center_x, center_y = x + w // 2, y + h // 2
    center_z = depth_down[center_y, center_x]

    # RTs: world to camera
    P0 = np.array([center_x, center_y, 1])
    Pc0 = np.linalg.inv(K) @ P0 * center_z
    pw = np.linalg.inv(RTs_h[0]) @ np.array([Pc0[0], Pc0[1], center_z, 1])  # [4]

    track = [np.array([center_x, center_y])]
    for i in range(1, num_frames):
        Pci = RTs_h[i] @ pw
        Pi = K @ Pci[:3] / Pci[2]
        track.append(Pi[:2])
    return track


def _roll_along_track(array, track):
    """Return ``array`` followed by copies shifted by each track offset.

    Offsets are measured relative to ``track[0]``, rounded to whole pixels and
    applied with ``roll_with_ignore_multidim`` as ``[shift_y, shift_x]``.
    """
    origin = track[0]
    frames = [array]
    for point in track[1:]:
        shift_x = int(round(point[0] - origin[0]))
        shift_y = int(round(point[1] - origin[1]))
        frames.append(roll_with_ignore_multidim(array, [shift_y, shift_x]))
    return frames


def run(pipeline, device):
    """Bind ``pipeline`` and ``device`` and return the generation callable.

    Args:
        pipeline: object providing ``prepare_latents``, ``pose_encoder``,
            ``scheduler``, ``dtype`` and ``__call__`` as used below.
        device: torch device the embeddings, masks and latents are moved to.

    Returns:
        The ``run_objctrl_2_5d`` closure.
    """

    def run_objctrl_2_5d(condition_image, mask, depth, RTs, bg_mode, shared_wapring_latents,
                         scale_wise_masks, rescale, seed, ds, dt, num_inference_steps=25):
        """Generate a video in which the masked object follows the ``RTs`` poses.

        Args:
            condition_image: conditioning image forwarded to the pipeline.
            mask: object mask array accepted by ``PIL.Image.fromarray``.
            depth: depth map indexed as [row, col]; presumably [height, width]
                -- confirm against the caller.
            RTs: camera poses, [n_frame, 3, 4]; the fallback paths require
                ``n_frame == num_frames``.
            bg_mode: "Fixed" (background uses the first pose for all frames),
                "Reverse" (background uses the reversed pose sequence); any
                other value blends the object pose features against zeros.
            shared_wapring_latents: when truthy, initial noise inside the
                warped masks is replaced by warped noise and frequency-mixed
                with the original noise.
            scale_wise_masks: when truthy, per-scale dilation kernel sizes are
                chosen from the merged mask area; otherwise all are 1.
            rescale: if > 0, depth is rescaled so the image-centre depth maps
                to ``depth_scale_ * rescale``.
            seed: seed for the initial latents (cast to int).
            ds, dt: forwarded to ``get_freq_filter`` as ``d_s`` / ``d_t``.
            num_inference_steps: number of denoising steps.

        Returns:
            Path of the written ``.mp4`` file.
        """
        seed = int(seed)

        center_h_margin, center_w_margin = center_margin, center_margin
        depth_center = np.mean(depth[height//2-center_h_margin:height//2+center_h_margin,
                                     width//2-center_w_margin:width//2+center_w_margin])

        if rescale > 0:
            depth_rescale = round(depth_scale_ * rescale / depth_center, 2)
        else:
            depth_rescale = 1.0

        depth = depth * depth_rescale

        depth_down = F.interpolate(torch.tensor(depth).unsqueeze(0).unsqueeze(0),
                                   (H, W), mode='bilinear', align_corners=False).squeeze().numpy()  # [H, W]

        ## latent
        generator = torch.Generator()
        generator.manual_seed(seed)

        latents_org = pipeline.prepare_latents(
            1,
            num_frames,
            8,
            height,
            width,
            pipeline.dtype,
            device,
            generator,
            None,
        )
        latents_org = latents_org / pipeline.scheduler.init_noise_sigma

        cur_plucker_embedding = _plucker_embedding(RTs, RTs.shape[0], device)
        cur_pose_features = pipeline.pose_encoder(cur_plucker_embedding)

        # bg_mode = ["Fixed", "Reverse", "Free"]
        if bg_mode == "Fixed":
            # Repeat the first pose for every frame.
            fix_RTs = np.repeat(RTs[0][None, ...], num_frames, axis=0)
            fix_plucker_embedding = _plucker_embedding(fix_RTs, num_frames, device)
            fix_pose_features = pipeline.pose_encoder(fix_plucker_embedding)
        elif bg_mode == "Reverse":
            bg_plucker_embedding = _plucker_embedding(RTs[::-1], RTs.shape[0], device)
            fix_pose_features = pipeline.pose_encoder(bg_plucker_embedding)
        else:
            fix_pose_features = None

        #### preparing mask
        mask = Image.fromarray(mask)
        if DEBUG:
            # Dump only; the in-memory image is used directly, so no output
            # directory is needed when DEBUG is off.
            mask.save(f'{cur_OUTPUT_PATH}/org_mask_big.png')
        mask = mask.resize((W, H))
        mask = np.array(mask).astype(np.float32)
        mask = np.expand_dims(mask, axis=-1)
        # Kept for the noise-warp fallback; `mask` is rebound further down.
        mask_down = mask

        # visulize mask
        if DEBUG:
            mask_sum_vis = mask[..., 0]
            mask_sum_vis = (mask_sum_vis * 255.0).astype(np.uint8)
            mask_sum_vis = Image.fromarray(mask_sum_vis)
            mask_sum_vis.save(f'{cur_OUTPUT_PATH}/org_mask.png')

        # Projected mask-centre track; only computed if a fallback needs it.
        center_track = None
        try:
            warped_masks = Unprojected(mask, depth_down, RTs, H=H, W=W, K=K)
            warped_masks.insert(0, mask)
        except Exception:
            # mask to bbox
            print(f'!!! Mask is too small to warp; mask to bbox')
            mask = mask[:, :, 0]

            # RTs [n_frame, 3, 4] to [n_frame, 4, 4] , add [0, 0, 0, 1]
            # NOTE: RTs stays homogeneous for the rest of this call.
            RTs = _to_homogeneous(RTs)
            center_track = _track_mask_center(mask, depth_down, RTs)

            warped_masks = _roll_along_track(mask, center_track)
            # Restore the trailing channel axis dropped above.
            warped_masks = [v[..., None] for v in warped_masks]

        warped_masks = np.stack(warped_masks, axis=0)  # [f, h, w]
        warped_masks = np.repeat(warped_masks, 3, axis=-1)  # [f, h, w, 3]

        # Union of the per-frame masks, clipped to 1.
        mask_sum = np.sum(warped_masks, axis=0, keepdims=True)  # [1, H, W, 3]
        mask_sum[mask_sum > 1.0] = 1.0
        mask_sum = mask_sum[0, :, :, 0]

        if DEBUG:
            ## visulize warp mask
            warp_masks_vis = torch.tensor(warped_masks)
            warp_masks_vis = (warp_masks_vis * 255.0).to(torch.uint8)
            torchvision.io.write_video(f'{cur_OUTPUT_PATH}/warped_masks.mp4', warp_masks_vis,
                                       fps=10, video_codec='h264', options={'crf': '10'})

            # visulize mask
            mask_sum_vis = mask_sum
            mask_sum_vis = (mask_sum_vis * 255.0).astype(np.uint8)
            mask_sum_vis = Image.fromarray(mask_sum_vis)
            mask_sum_vis.save(f'{cur_OUTPUT_PATH}/merged_mask.png')

        if scale_wise_masks:
            min_area = H * W * area_ratio  # cal in downscale
            non_zero_len = mask_sum.sum()

            print(f'non_zero_len: {non_zero_len}, min_area: {min_area}')

            # Smaller merged masks get larger dilation kernels.
            if non_zero_len > min_area:
                kernel_sizes = [1, 1, 1, 3]
            elif non_zero_len > min_area * 0.5:
                kernel_sizes = [3, 1, 1, 5]
            else:
                kernel_sizes = [5, 3, 3, 7]
        else:
            kernel_sizes = [1, 1, 1, 1]

        mask = torch.from_numpy(mask_sum)  # [h, w]
        mask = mask[None, None, ...]  # [1, 1, h, w]
        mask = F.interpolate(mask, (height, width), mode='bilinear', align_corners=False)  # [1, 1, H, W]
        # mask = mask.repeat(1, num_frames, 1, 1) # [1, f, H, W]
        mask = mask.to(pipeline.dtype).to(device)

        ##### Mask End ######

        ### Got blending pose features Start ###
        pose_features = []
        for i in range(0, len(cur_pose_features)):
            kernel_size = kernel_sizes[i]
            h, w = cur_pose_features[i].shape[-2:]

            if fix_pose_features is None:
                pose_features.append(torch.zeros_like(cur_pose_features[i]))
            else:
                pose_features.append(fix_pose_features[i])

            cur_mask = F.interpolate(mask, (h, w), mode='bilinear', align_corners=False)
            cur_mask = dilate_mask_pytorch(cur_mask, kernel_size=kernel_size)  # [1, 1, H, W]
            cur_mask = cur_mask.repeat(1, num_frames, 1, 1)  # [1, f, H, W]

            if DEBUG:
                # visulize mask
                mask_vis = cur_mask[0, 0].cpu().numpy() * 255.0
                mask_vis = Image.fromarray(mask_vis.astype(np.uint8))
                mask_vis.save(f'{cur_OUTPUT_PATH}/mask_k{kernel_size}_scale{i}.png')

            cur_mask = cur_mask[None, ...]  # [1, 1, f, H, W]
            # Object pose features inside the mask, background ones outside.
            pose_features[-1] = cur_pose_features[i] * cur_mask + pose_features[-1] * (1 - cur_mask)

        ### Got blending pose features End ###

        ##### Warp Noise Start ######
        if shared_wapring_latents:
            noise = latents_org[0, 0].data.cpu().numpy().copy()  # [14, 4, 40, 72]
            noise = np.transpose(noise, (1, 2, 0))  # [40, 72, 4]

            try:
                warp_noise = Unprojected(noise, depth_down, RTs, H=H, W=W, K=K)
                warp_noise.insert(0, noise)
            except Exception:
                print(f'!!! Noise is too small to warp; mask to bbox')

                if center_track is None:
                    # The mask warp succeeded, so the track was never built
                    # and RTs is still [n_frame, 3, 4].
                    center_track = _track_mask_center(mask_down[:, :, 0], depth_down,
                                                      _to_homogeneous(RTs))
                warp_noise = _roll_along_track(noise, center_track)

            warp_noise = np.stack(warp_noise, axis=0)  # [f, h, w, 4]

            if DEBUG:
                ## visulize warp noise
                warp_noise_vis = torch.tensor(warp_noise)[..., :3] * torch.tensor(warped_masks)
                warp_noise_vis = (warp_noise_vis - warp_noise_vis.min()) / (warp_noise_vis.max() - warp_noise_vis.min())
                warp_noise_vis = (warp_noise_vis * 255.0).to(torch.uint8)

                torchvision.io.write_video(f'{cur_OUTPUT_PATH}/warp_noise.mp4', warp_noise_vis,
                                           fps=10, video_codec='h264', options={'crf': '10'})

            warp_latents = torch.tensor(warp_noise).permute(0, 3, 1, 2).to(latents_org.device).to(latents_org.dtype)  # [frame, 4, H, W]
            warp_latents = warp_latents.unsqueeze(0)  # [1, frame, 4, H, W]

            warped_masks = torch.tensor(warped_masks).permute(0, 3, 1, 2).unsqueeze(0)  # [1, frame, 3, H, W]
            mask_extend = torch.concat([warped_masks, warped_masks[:, :, 0:1]], dim=2)  # [1, frame, 4, H, W]
            mask_extend = mask_extend.to(latents_org.device).to(latents_org.dtype)

            # Warped noise inside the masks, original noise elsewhere.
            warp_latents = warp_latents * mask_extend + latents_org * (1 - mask_extend)
            warp_latents = warp_latents.permute(0, 2, 1, 3, 4)
            random_noise = latents_org.clone().permute(0, 2, 1, 3, 4)

            filter_shape = warp_latents.shape

            freq_filter = get_freq_filter(
                filter_shape,
                device=device,
                filter_type='butterworth',
                n=4,
                d_s=ds,
                d_t=dt
            )

            warp_latents = freq_mix_3d(warp_latents, random_noise, freq_filter)
            warp_latents = warp_latents.permute(0, 2, 1, 3, 4)
        else:
            warp_latents = latents_org.clone()

        # The sampling generator is reseeded with a fixed value; `seed` only
        # affects the initial latents drawn above.
        generator.manual_seed(42)

        with torch.no_grad():
            result = pipeline(
                image=condition_image,
                pose_embedding=cur_plucker_embedding,
                height=height,
                width=width,
                num_frames=num_frames,
                num_inference_steps=num_inference_steps,
                min_guidance_scale=min_guidance_scale,
                max_guidance_scale=max_guidance_scale,
                do_image_process=True,
                generator=generator,
                output_type='pt',
                pose_features=pose_features,
                latents=warp_latents
            ).frames[0].cpu()  # [f, c, h, w]

        result = rearrange(result, 'f c h w -> f h w c')
        result = (result * 255.0).to(torch.uint8)

        # Reserve a unique path and close the handle before the encoder
        # writes to it; delete=False keeps the file for the caller.
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
            video_path = tmp_file.name
        torchvision.io.write_video(video_path, result, fps=10, video_codec='h264', options={'crf': '8'})

        return video_path

    return run_objctrl_2_5d