{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import cv2\n",
"import torch\n",
"import numpy as np\n",
"import PIL\n",
"from PIL import Image\n",
"from einops import rearrange\n",
"from video_vae import CausalVideoVAELossWrapper\n",
"from torchvision import transforms as pth_transforms\n",
"from torchvision.transforms.functional import InterpolationMode\n",
"from IPython.display import Image as ipython_image\n",
"from diffusers.utils import load_image, export_to_video, export_to_gif\n",
"from IPython.display import HTML"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_path = \"pyramid-flow-miniflux/causal_video_vae\" # The video-vae checkpoint dir\n",
"model_dtype = 'bf16'\n",
"\n",
"device_id = 3\n",
"torch.cuda.set_device(device_id)\n",
"\n",
"model = CausalVideoVAELossWrapper(\n",
" model_path,\n",
" model_dtype,\n",
" interpolate=False, \n",
" add_discriminator=False,\n",
")\n",
"model = model.to(\"cuda\")\n",
"\n",
"if model_dtype == \"bf16\":\n",
" torch_dtype = torch.bfloat16 \n",
"elif model_dtype == \"fp16\":\n",
" torch_dtype = torch.float16\n",
"else:\n",
" torch_dtype = torch.float32\n",
"\n",
"def image_transform(images, resize_width, resize_height):\n",
" transform_list = pth_transforms.Compose([\n",
" pth_transforms.Resize((resize_height, resize_width), InterpolationMode.BICUBIC, antialias=True),\n",
" pth_transforms.ToTensor(),\n",
" pth_transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
" ])\n",
" return torch.stack([transform_list(image) for image in images])\n",
"\n",
"\n",
"def get_transform(width, height, new_width=None, new_height=None, resize=False,):\n",
" transform_list = []\n",
"\n",
" if resize:\n",
" if new_width is None:\n",
" new_width = width // 8 * 8\n",
" if new_height is None:\n",
" new_height = height // 8 * 8\n",
" transform_list.append(pth_transforms.Resize((new_height, new_width), InterpolationMode.BICUBIC, antialias=True))\n",
" \n",
" transform_list.extend([\n",
" pth_transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
" ])\n",
" transform_list = pth_transforms.Compose(transform_list)\n",
"\n",
" return transform_list\n",
"\n",
"\n",
"def load_video_and_transform(video_path, frame_number, new_width=None, new_height=None, max_frames=600, sample_fps=24, resize=False):\n",
" try:\n",
" video_capture = cv2.VideoCapture(video_path)\n",
" fps = video_capture.get(cv2.CAP_PROP_FPS)\n",
" frames = []\n",
" pil_frames = []\n",
" while True:\n",
" flag, frame = video_capture.read()\n",
" if not flag:\n",
" break\n",
" \n",
" pil_frames.append(np.ascontiguousarray(frame[:, :, ::-1]))\n",
" frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
" frame = torch.from_numpy(frame)\n",
" frame = frame.permute(2, 0, 1)\n",
" frames.append(frame)\n",
" if len(frames) >= max_frames:\n",
" break\n",
"\n",
" video_capture.release()\n",
" interval = max(int(fps / sample_fps), 1)\n",
" pil_frames = pil_frames[::interval][:frame_number]\n",
" frames = frames[::interval][:frame_number]\n",
" frames = torch.stack(frames).float() / 255\n",
" width = frames.shape[-1]\n",
" height = frames.shape[-2]\n",
" video_transform = get_transform(width, height, new_width, new_height, resize=resize)\n",
" frames = video_transform(frames)\n",
" pil_frames = [Image.fromarray(frame).convert(\"RGB\") for frame in pil_frames]\n",
"\n",
" if resize:\n",
" if new_width is None:\n",
" new_width = width // 32 * 32\n",
" if new_height is None:\n",
" new_height = height // 32 * 32\n",
" pil_frames = [frame.resize((new_width or width, new_height or height), PIL.Image.BICUBIC) for frame in pil_frames]\n",
" return frames, pil_frames\n",
" except Exception:\n",
" return None\n",
"\n",
"\n",
"def show_video(ori_path, rec_path, width=\"100%\"):\n",
" html = ''\n",
" if ori_path is not None:\n",
" html += f\"\"\"\n",
" \"\"\"\n",
" \n",
" html += f\"\"\"\n",
" \"\"\"\n",
" return HTML(html)"
]
},
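{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before touching the model, a quick sanity check on the helpers (an addition to this demo, not part of the original pipeline): `image_transform` should map a PIL image to a `(T, C, H, W)` float tensor normalized to `[-1, 1]`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: image_transform should yield a (T, C, H, W) tensor in [-1, 1].\n",
"dummy = Image.new(\"RGB\", (200, 120), color=(128, 128, 128))\n",
"t = image_transform([dummy], resize_width=192, resize_height=112)\n",
"print(t.shape)  # torch.Size([1, 3, 112, 192])\n",
"print(t.min().item(), t.max().item())  # ~0.004 for a mid-gray image"
]
},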
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Image Reconstruction"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"image_path = 'image_path'\n",
"\n",
"image = Image.open(image_path).convert(\"RGB\")\n",
"resize_width = image.width // 8 * 8\n",
"resize_height = image.height // 8 * 8\n",
"input_image_tensor = image_transform([image], resize_width, resize_height)\n",
"input_image_tensor = input_image_tensor.permute(1, 0, 2, 3).unsqueeze(0)\n",
"\n",
"with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):\n",
" latent = model.encode_latent(input_image_tensor.to(\"cuda\"), sample=True)\n",
" rec_images = model.decode_latent(latent)\n",
"\n",
"display(image)\n",
"display(rec_images[0])"
]
},
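{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, quantify the round trip with PSNR. This cell is an addition to the demo and assumes `rec_images[0]` is a PIL image of size `(resize_width, resize_height)`, as the `display` call above suggests."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rough quality check (assumes rec_images[0] is a PIL image matching the resized input).\n",
"ori = np.asarray(image.resize((resize_width, resize_height), PIL.Image.BICUBIC), dtype=np.float64)\n",
"rec = np.asarray(rec_images[0], dtype=np.float64)\n",
"mse = np.mean((ori - rec) ** 2)\n",
"psnr = float('inf') if mse == 0 else 20 * np.log10(255.0) - 10 * np.log10(mse)\n",
"print(f\"PSNR: {psnr:.2f} dB\")"
]
},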
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Video Reconstruction"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"video_path = 'video_path'\n",
"\n",
"frame_number = 57 # x*8 + 1\n",
"width = 640\n",
"height = 384\n",
"\n",
"video_frames_tensor, pil_video_frames = load_video_and_transform(video_path, frame_number, new_width=width, new_height=height, resize=True)\n",
"video_frames_tensor = video_frames_tensor.permute(1, 0, 2, 3).unsqueeze(0)\n",
"print(video_frames_tensor.shape)\n",
"\n",
"with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):\n",
" latent = model.encode_latent(video_frames_tensor.to(\"cuda\"), sample=False, window_size=8, temporal_chunk=True)\n",
" rec_frames = model.decode_latent(latent.float(), window_size=2, temporal_chunk=True)\n",
"\n",
"export_to_video(pil_video_frames, './ori_video.mp4', fps=24)\n",
"export_to_video(rec_frames, \"./rec_video.mp4\", fps=24)\n",
"show_video('./ori_video.mp4', \"./rec_video.mp4\", \"60%\")"
]
}
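,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a final check (an addition to the demo), compare the latent shape against the input shape; the exact downsampling factors depend on the checkpoint, but printing both makes the spatial and temporal compression visible."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect compression using `latent` and `video_frames_tensor` from the cell above.\n",
"print('input :', tuple(video_frames_tensor.shape))\n",
"print('latent:', tuple(latent.shape))"
]
}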
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}