import os import subprocess import tempfile import cv2 import torch from PIL import Image from typing import Mapping from einops import rearrange import numpy as np import torchvision.transforms.functional as transforms_F from video_to_video.utils.logger import get_logger logger = get_logger() def tensor2vid(video, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) video = video.mul_(std).add_(mean) video.clamp_(0, 1) video = video * 255.0 images = rearrange(video, 'b c f h w -> b f h w c')[0] return images def preprocess(input_frames): out_frame_list = [] for pointer in range(len(input_frames)): frame = input_frames[pointer] frame = frame[:, :, ::-1] frame = Image.fromarray(frame.astype('uint8')).convert('RGB') frame = transforms_F.to_tensor(frame) out_frame_list.append(frame) out_frames = torch.stack(out_frame_list, dim=0) out_frames.clamp_(0, 1) mean = out_frames.new_tensor([0.5, 0.5, 0.5]).view(-1) std = out_frames.new_tensor([0.5, 0.5, 0.5]).view(-1) out_frames.sub_(mean.view(1, -1, 1, 1)).div_(std.view(1, -1, 1, 1)) return out_frames def adjust_resolution(h, w, up_scale): if h*up_scale < 720: up_s = 720/h target_h = int(up_s*h//2*2) target_w = int(up_s*w//2*2) elif h*w*up_scale*up_scale > 1280*2048: up_s = np.sqrt(1280*2048/(h*w)) target_h = int(up_s*h//2*2) target_w = int(up_s*w//2*2) else: target_h = int(up_scale*h//2*2) target_w = int(up_scale*w//2*2) return (target_h, target_w) def make_mask_cond(in_f_num, interp_f_num): mask_cond = [] interp_cond = [-1 for _ in range(interp_f_num)] for i in range(in_f_num): mask_cond.append(i) if i != in_f_num - 1: mask_cond += interp_cond return mask_cond def load_video(vid_path): capture = cv2.VideoCapture(vid_path) _fps = capture.get(cv2.CAP_PROP_FPS) _total_frame_num = capture.get(cv2.CAP_PROP_FRAME_COUNT) pointer = 0 frame_list = [] stride = 1 while len(frame_list) < _total_frame_num: ret, frame = capture.read() pointer += 1 if (not ret) or (frame is None): break if pointer >= _total_frame_num + 1: break if pointer % stride == 0: frame_list.append(frame) capture.release() return frame_list, _fps def save_video(video, save_dir, file_name, fps=16.0): output_path = os.path.join(save_dir, file_name) images = [(img.numpy()).astype('uint8') for img in video] temp_dir = tempfile.mkdtemp() for fid, frame in enumerate(images): tpth = os.path.join(temp_dir, '%06d.png' % (fid + 1)) cv2.imwrite(tpth, frame[:, :, ::-1]) tmp_path = os.path.join(save_dir, 'tmp.mp4') cmd = f'ffmpeg -y -f image2 -framerate {fps} -i {temp_dir}/%06d.png \ -vcodec libx264 -preset ultrafast -crf 0 -pix_fmt yuv420p {tmp_path}' status, output = subprocess.getstatusoutput(cmd) if status != 0: logger.error('Save Video Error with {}'.format(output)) os.system(f'rm -rf {temp_dir}') os.rename(tmp_path, output_path) def collate_fn(data, device): """Prepare the input just before the forward function. This method will move the tensors to the right device. Usually this method does not need to be overridden. Args: data: The data out of the dataloader. device: The device to move data to. Returns: The processed data. """ from torch.utils.data.dataloader import default_collate def get_class_name(obj): return obj.__class__.__name__ if isinstance(data, dict) or isinstance(data, Mapping): return type(data)({ k: collate_fn(v, device) if k != 'img_metas' else v for k, v in data.items() }) elif isinstance(data, (tuple, list)): if 0 == len(data): return torch.Tensor([]) if isinstance(data[0], (int, float)): return default_collate(data).to(device) else: return type(data)(collate_fn(v, device) for v in data) elif isinstance(data, np.ndarray): if data.dtype.type is np.str_: return data else: return collate_fn(torch.from_numpy(data), device) elif isinstance(data, torch.Tensor): return data.to(device) elif isinstance(data, (bytes, str, int, float, bool, type(None))): return data else: raise ValueError(f'Unsupported data type {type(data)}')