import os

import PIL.Image
import numpy as np
import torch
import torchvision
from torchvision.transforms import Resize, InterpolationMode
import imageio
from einops import rearrange
import cv2
from PIL import Image
import decord
from controlnet_aux import OpenposeDetector

# Shared pose detector used by pre_process_pose(). It is constructed when this
# module is imported, so importing the module triggers from_pretrained().
apply_openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")


def prepare_video(video_path: str, resolution: int, device, dtype, normalize=True,
                  start_t: float = 0, end_t: float = -1, output_fps: int = -1):
    """Read a time window of a video and return it as a resized frame tensor.

    Args:
        video_path: Path handed to ``decord.VideoReader``.
        resolution: Target size, in pixels, of the larger frame side. Both
            sides are afterwards rounded to the nearest multiple of 64.
        device: Device the frame tensor is moved to.
        dtype: Dtype the frame tensor is cast to.
        normalize: If true, pixel values ``v`` are mapped to ``v / 127.5 - 1``.
        start_t: Start of the window, in seconds.
        end_t: End of the window, in seconds; ``-1`` means the end of the
            video. Values past the end of the video are clamped.
        output_fps: Sampling rate of the returned frames; ``-1`` means the
            integer part of the source's average fps.

    Returns:
        A tuple ``(video, output_fps)`` where ``video`` has layout
        ``(frames, channels, height, width)``.

    Raises:
        AssertionError: If the window is empty/negative or ``output_fps`` is
            not positive.
    """
    vr = decord.VideoReader(video_path)
    initial_fps = vr.get_avg_fps()
    duration = len(vr) / initial_fps

    if output_fps == -1:
        output_fps = int(initial_fps)
    end_t = duration if end_t == -1 else min(duration, end_t)

    assert 0 <= start_t < end_t, (
        f"invalid time window: start_t={start_t}, end_t={end_t}"
    )
    assert output_fps > 0, f"output_fps must be positive, got {output_fps}"

    start_f_ind = int(start_t * initial_fps)
    end_f_ind = int(end_t * initial_fps)
    # A window shorter than one output frame would otherwise yield zero
    # samples; always take at least the first frame of the window.
    num_f = max(1, int((end_t - start_t) * output_fps))
    sample_idx = np.linspace(start_f_ind, end_f_ind, num_f, endpoint=False).astype(int)

    video = vr.get_batch(sample_idx)
    # get_batch may hand back a torch tensor or an object exposing asnumpy();
    # either way continue with a NumPy array.
    if torch.is_tensor(video):
        video = video.detach().cpu().numpy()
    else:
        video = video.asnumpy()

    _, h, w, _ = video.shape
    video = rearrange(video, "f h w c -> f c h w")
    video = torch.Tensor(video).to(device).to(dtype)

    # The larger side is scaled to `resolution`. Use min(h, w) instead of
    # max(h, w) if the smaller side should be equal to `resolution` (e.g. 512).
    k = float(resolution) / max(h, w)
    # Snap both sides to a multiple of 64, never letting a side collapse to 0
    # (which would happen for a scaled side shorter than 32 pixels).
    h = max(64, int(np.round(h * k / 64.0)) * 64)
    w = max(64, int(np.round(w * k / 64.0)) * 64)

    video = Resize((h, w), interpolation=InterpolationMode.BILINEAR, antialias=True)(video)
    if normalize:
        video = video / 127.5 - 1.0
    return video, output_fps


def pre_process_pose(input_video, apply_pose_detect: bool = True):
    """Turn video frames into a control tensor, optionally via pose detection.

    Args:
        input_video: Iterable of frame tensors laid out as ``(c, h, w)``.
            Frames are cast to ``uint8`` before use, so they are expected to
            hold 0..255 pixel values (i.e. not the normalized output of
            ``prepare_video``) -- TODO confirm against callers.
        apply_pose_detect: If true, each frame is passed through
            ``apply_openpose``; otherwise the frame itself is used.

    Returns:
        A float tensor laid out as ``(f, c, h, w)`` with values divided by 255.
    """
    detected_maps = []
    for frame in input_video:
        img = rearrange(frame, 'c h w -> h w c').cpu().numpy().astype(np.uint8)
        if apply_pose_detect:
            result = apply_openpose(img)
            # NOTE(review): the original code unpacked a 2-tuple here. Accept
            # both a tuple (map first) and a bare image, since the detector's
            # return type appears to differ between implementations.
            if isinstance(result, tuple):
                result = result[0]
            # np.array also converts a PIL image into an ndarray for cv2.
            detected_map = np.array(result)
        else:
            detected_map = img
        height, width = img.shape[:2]
        # Bring the map back to the frame's size; cv2 takes (width, height).
        detected_map = cv2.resize(detected_map, (width, height),
                                  interpolation=cv2.INTER_NEAREST)
        detected_maps.append(detected_map)

    control = torch.from_numpy(np.stack(detected_maps)).float() / 255.0
    return rearrange(control, 'f h w c -> f c h w')


def create_gif(frames, fps, rescale=False, path=None, watermark=None):
    """Write a sequence of frames to an animated GIF and return its path.

    Args:
        frames: Iterable of frames; each is converted with ``torch.Tensor``
            and passed through ``torchvision.utils.make_grid(nrow=4)``.
        fps: Frame rate forwarded to ``imageio.mimsave``.
        rescale: If true, values are mapped from [-1, 1] to [0, 1] first.
        path: Output file. Defaults to ``temporal/canny_db.gif`` (the
            directory is created if missing).
        watermark: Accepted for interface compatibility; currently unused.

    Returns:
        The path the GIF was written to.
    """
    if path is None:
        out_dir = "temporal"
        os.makedirs(out_dir, exist_ok=True)
        path = os.path.join(out_dir, 'canny_db.gif')

    outputs = []
    for x in frames:
        x = torchvision.utils.make_grid(torch.Tensor(x), nrow=4)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        # Clamp before the uint8 cast so slightly out-of-range values saturate
        # instead of wrapping around.
        x = (x * 255).clamp(0, 255).numpy().astype(np.uint8)
        outputs.append(x)

    imageio.mimsave(path, outputs, fps=fps)
    return path


def post_process_gif(list_of_results, image_resolution,
                     output_file="/tmp/ddxk.gif", fps=4):
    """Save a list of frames as a GIF and return the file path.

    Args:
        list_of_results: Frames forwarded to ``imageio.mimsave``.
        image_resolution: Accepted for interface compatibility; unused.
        output_file: Destination path. Defaults to the previously hard-coded
            ``/tmp/ddxk.gif``; pass a unique path to avoid overwriting the
            result of an earlier call.
        fps: Frame rate of the GIF (previously fixed at 4).

    Returns:
        ``output_file``.
    """
    imageio.mimsave(output_file, list_of_results, fps=fps)
    return output_file