{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import json\n", "import cv2\n", "import torch\n", "import numpy as np\n", "import PIL\n", "from PIL import Image\n", "from einops import rearrange\n", "from video_vae import CausalVideoVAELossWrapper\n", "from torchvision import transforms as pth_transforms\n", "from torchvision.transforms.functional import InterpolationMode\n", "from IPython.display import Image as ipython_image\n", "from diffusers.utils import load_image, export_to_video, export_to_gif\n", "from IPython.display import HTML" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model_path = \"pyramid-flow-miniflux/causal_video_vae\" # The video-vae checkpoint dir\n", "model_dtype = 'bf16'\n", "\n", "device_id = 3\n", "torch.cuda.set_device(device_id)\n", "\n", "model = CausalVideoVAELossWrapper(\n", " model_path,\n", " model_dtype,\n", " interpolate=False, \n", " add_discriminator=False,\n", ")\n", "model = model.to(\"cuda\")\n", "\n", "if model_dtype == \"bf16\":\n", " torch_dtype = torch.bfloat16 \n", "elif model_dtype == \"fp16\":\n", " torch_dtype = torch.float16\n", "else:\n", " torch_dtype = torch.float32\n", "\n", "def image_transform(images, resize_width, resize_height):\n", " transform_list = pth_transforms.Compose([\n", " pth_transforms.Resize((resize_height, resize_width), InterpolationMode.BICUBIC, antialias=True),\n", " pth_transforms.ToTensor(),\n", " pth_transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n", " ])\n", " return torch.stack([transform_list(image) for image in images])\n", "\n", "\n", "def get_transform(width, height, new_width=None, new_height=None, resize=False,):\n", " transform_list = []\n", "\n", " if resize:\n", " if new_width is None:\n", " new_width = width // 8 * 8\n", " if new_height is None:\n", " new_height = height // 8 * 8\n", " transform_list.append(pth_transforms.Resize((new_height, new_width), InterpolationMode.BICUBIC, antialias=True))\n", " \n", " transform_list.extend([\n", " pth_transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n", " ])\n", " transform_list = pth_transforms.Compose(transform_list)\n", "\n", " return transform_list\n", "\n", "\n", "def load_video_and_transform(video_path, frame_number, new_width=None, new_height=None, max_frames=600, sample_fps=24, resize=False):\n", " try:\n", " video_capture = cv2.VideoCapture(video_path)\n", " fps = video_capture.get(cv2.CAP_PROP_FPS)\n", " frames = []\n", " pil_frames = []\n", " while True:\n", " flag, frame = video_capture.read()\n", " if not flag:\n", " break\n", " \n", " pil_frames.append(np.ascontiguousarray(frame[:, :, ::-1]))\n", " frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n", " frame = torch.from_numpy(frame)\n", " frame = frame.permute(2, 0, 1)\n", " frames.append(frame)\n", " if len(frames) >= max_frames:\n", " break\n", "\n", " video_capture.release()\n", " interval = max(int(fps / sample_fps), 1)\n", " pil_frames = pil_frames[::interval][:frame_number]\n", " frames = frames[::interval][:frame_number]\n", " frames = torch.stack(frames).float() / 255\n", " width = frames.shape[-1]\n", " height = frames.shape[-2]\n", " video_transform = get_transform(width, height, new_width, new_height, resize=resize)\n", " frames = video_transform(frames)\n", " pil_frames = [Image.fromarray(frame).convert(\"RGB\") for frame in pil_frames]\n", "\n", " if resize:\n", " if new_width is None:\n", " new_width = width // 32 * 32\n", " if new_height is 
None:\n", " new_height = height // 32 * 32\n", " pil_frames = [frame.resize((new_width or width, new_height or height), PIL.Image.BICUBIC) for frame in pil_frames]\n", " return frames, pil_frames\n", " except Exception:\n", " return None\n", "\n", "\n", "def show_video(ori_path, rec_path, width=\"100%\"):\n", " html = ''\n", " if ori_path is not None:\n", " html += f\"\"\"\n", " \"\"\"\n", " \n", " html += f\"\"\"\n", " \"\"\"\n", " return HTML(html)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Image Reconstruction" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "image_path = 'image_path'\n", "\n", "image = Image.open(image_path).convert(\"RGB\")\n", "resize_width = image.width // 8 * 8\n", "resize_height = image.height // 8 * 8\n", "input_image_tensor = image_transform([image], resize_width, resize_height)\n", "input_image_tensor = input_image_tensor.permute(1, 0, 2, 3).unsqueeze(0)\n", "\n", "with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):\n", " latent = model.encode_latent(input_image_tensor.to(\"cuda\"), sample=True)\n", " rec_images = model.decode_latent(latent)\n", "\n", "display(image)\n", "display(rec_images[0])" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Video Reconstruction" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "video_path = 'video_path'\n", "\n", "frame_number = 57 # x*8 + 1\n", "width = 640\n", "height = 384\n", "\n", "video_frames_tensor, pil_video_frames = load_video_and_transform(video_path, frame_number, new_width=width, new_height=height, resize=True)\n", "video_frames_tensor = video_frames_tensor.permute(1, 0, 2, 3).unsqueeze(0)\n", "print(video_frames_tensor.shape)\n", "\n", "with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):\n", " latent = model.encode_latent(video_frames_tensor.to(\"cuda\"), sample=False, window_size=8, temporal_chunk=True)\n", " rec_frames = model.decode_latent(latent.float(), window_size=2, temporal_chunk=True)\n", "\n", "export_to_video(pil_video_frames, './ori_video.mp4', fps=24)\n", "export_to_video(rec_frames, \"./rec_video.mp4\", fps=24)\n", "show_video('./ori_video.mp4', \"./rec_video.mp4\", \"60%\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }