xierui.0097 committed on
Commit
f0e9666
·
1 Parent(s): 470b11c

Add application file

Files changed (39)
  1. README.md +106 -13
  2. __pycache__/inference_utils.cpython-39.pyc +0 -0
  3. inference_utils.py +148 -0
  4. requirements.txt +15 -0
  5. video_super_resolution/__pycache__/color_fix.cpython-39.pyc +0 -0
  6. video_super_resolution/color_fix.py +122 -0
  7. video_super_resolution/dataset.py +113 -0
  8. video_super_resolution/scripts/inference_sr.py +140 -0
  9. video_super_resolution/scripts/inference_sr.sh +56 -0
  10. video_to_video/__init__.py +0 -0
  11. video_to_video/__pycache__/__init__.cpython-39.pyc +0 -0
  12. video_to_video/__pycache__/video_to_video_model.cpython-39.pyc +0 -0
  13. video_to_video/diffusion/__init__.py +0 -0
  14. video_to_video/diffusion/__pycache__/__init__.cpython-39.pyc +0 -0
  15. video_to_video/diffusion/__pycache__/diffusion_sdedit.cpython-39.pyc +0 -0
  16. video_to_video/diffusion/__pycache__/schedules_sdedit.cpython-39.pyc +0 -0
  17. video_to_video/diffusion/__pycache__/solvers_sdedit.cpython-39.pyc +0 -0
  18. video_to_video/diffusion/diffusion_sdedit.py +443 -0
  19. video_to_video/diffusion/schedules_sdedit.py +85 -0
  20. video_to_video/diffusion/solvers_sdedit.py +204 -0
  21. video_to_video/modules/__init__.py +3 -0
  22. video_to_video/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  23. video_to_video/modules/__pycache__/embedder.cpython-39.pyc +0 -0
  24. video_to_video/modules/__pycache__/t5.cpython-39.pyc +0 -0
  25. video_to_video/modules/__pycache__/unet_v2v.cpython-39.pyc +0 -0
  26. video_to_video/modules/__pycache__/unet_v2v_LocalConv.cpython-39.pyc +0 -0
  27. video_to_video/modules/__pycache__/unet_v2v_deform.cpython-39.pyc +0 -0
  28. video_to_video/modules/embedder.py +75 -0
  29. video_to_video/modules/t5.py +335 -0
  30. video_to_video/modules/unet_v2v.py +2332 -0
  31. video_to_video/utils/__init__.py +0 -0
  32. video_to_video/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  33. video_to_video/utils/__pycache__/config.cpython-39.pyc +0 -0
  34. video_to_video/utils/__pycache__/logger.cpython-39.pyc +0 -0
  35. video_to_video/utils/__pycache__/seed.cpython-39.pyc +0 -0
  36. video_to_video/utils/config.py +169 -0
  37. video_to_video/utils/logger.py +94 -0
  38. video_to_video/utils/seed.py +14 -0
  39. video_to_video/video_to_video_model.py +210 -0
README.md CHANGED
@@ -1,13 +1,106 @@
- ---
- title: STAR
- emoji: 👁
- colorFrom: purple
- colorTo: indigo
- sdk: gradio
- sdk_version: 5.6.0
- app_file: app.py
- pinned: false
- short_description: 'STAR: Spatial-Temporal Augmentation with Text-to-Video Model'
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ <div align="center">
+ <h1>
+ STAR: Spatial-Temporal Augmentation with Text-to-Video Models for Real-World Video Super-Resolution
+ </h1>
+ <div>
+ <a href='https://github.com/CSRuiXie' target='_blank'>Rui Xie<sup>1*</sup></a>,&emsp;
+ <a href='https://github.com/yhliu04' target='_blank'>Yinhong Liu<sup>1*</sup></a>,&emsp;
+ <a href='https://scholar.google.com/citations?user=Uhp3JKgAAAAJ&hl=zh-CN&oi=sra' target='_blank'>Chen Zhao<sup>1</sup></a>,&emsp;
+ <a href='https://scholar.google.com/citations?hl=zh-CN&user=yWq1Fd4AAAAJ' target='_blank'>Penghao Zhou<sup>2</sup></a>,&emsp;
+ <a href='https://scholar.google.com/citations?hl=zh-CN&user=Ds5wwRoAAAAJ' target='_blank'>Zhenheng Yang<sup>2</sup></a><br>
+ <a href='https://scholar.google.com/citations?hl=zh-CN&user=w03CHFwAAAAJ' target='_blank'>Jun Zhou<sup>3</sup></a>,&emsp;
+ <a href='https://cszn.github.io/' target='_blank'>Kai Zhang<sup>1</sup></a>,&emsp;
+ <a href='https://jessezhang92.github.io/' target='_blank'>Zhenyu Zhang<sup>1</sup></a>,&emsp;
+ <a href='https://scholar.google.com.hk/citations?user=6CIDtZQAAAAJ&hl=zh-CN' target='_blank'>Jian Yang<sup>1</sup></a>,&emsp;
+ <a href='https://tyshiwo.github.io/index.html' target='_blank'>Ying Tai<sup>1&#8224;</sup></a>
+ </div>
+ <div>
+ <sup>1</sup>Nanjing University,&emsp;<sup>2</sup>ByteDance,&emsp;<sup>3</sup>Southwest University
+ </div>
+ <div>
+ <h4 align="center">
+ <a href="https://nju-pcalab.github.io/projects/STAR" target='_blank'>
+ <img src="https://img.shields.io/badge/🌟-Project%20Page-blue">
+ </a>
+ <a href="https://arxiv.org/abs/2407.07667" target='_blank'>
+ <img src="https://img.shields.io/badge/arXiv-2312.06640-b31b1b.svg">
+ </a>
+ <a href="https://youtu.be/hx0zrql-SrU" target='_blank'>
+ <img src="https://img.shields.io/badge/Demo%20Video-%23FF0000.svg?logo=YouTube&logoColor=white">
+ </a>
+ </h4>
+ </div>
+ </div>
+
+
+ ### 🔆 Updates
+ - **2024.12.01** The pretrained STAR model (I2VGen-XL version) and inference code have been released.
+
+
+ ## 🔎 Method Overview
+ ![STAR](assets/overview.png)
+
+
+ ## 📷 Results Display
+ ![STAR](assets/teaser.png)
+ ![STAR](assets/real_world.png)
+ 👀 More visual results can be found on our [Project Page](https://nju-pcalab.github.io/projects/STAR) and in the [Video Demo](https://youtu.be/hx0zrql-SrU).
+
+
+ ## ⚙️ Dependencies and Installation
+ ```
+ ## git clone this repository
+ git clone https://github.com/NJU-PCALab/STAR.git
+ cd STAR
+
+ ## create an environment
+ conda create -n star python=3.10
+ conda activate star
+ pip install -r requirements.txt
+ sudo apt-get update && sudo apt-get install ffmpeg libsm6 libxext6 -y
+ ```
+
+ ## 🚀 Inference
+ #### Step 1: Download the pretrained STAR model from [HuggingFace](https://huggingface.co/SherryX/STAR).
+ We provide two versions: `heavy_deg.pt` for heavily degraded videos and `light_deg.pt` for lightly degraded videos (e.g., low-resolution videos downloaded from video websites).
+
+ Put the downloaded weights into `pretrained_weight/`.
+
+
+ #### Step 2: Prepare the testing data
+ Put the testing videos in `input/video/`.
+
+ As for the prompt, there are three options: 1. No prompt. 2. Automatically generate a prompt [using Pllava](https://github.com/hpcaitech/Open-Sora/tree/main/tools/caption#pllava-captioning). 3. Manually write the prompt. Put the txt file in `input/text/`, as in the example below.
+
+
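Note that `inference_sr.sh` pairs the i-th non-empty line of the prompt file with the i-th `.mp4` it finds, so the number of lines must match the number of videos. A minimal `input/text/prompt.txt` (the captions below are placeholders) could look like:

```
a man walking his dog along a beach at sunset
an aerial view of a city street with heavy traffic
```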
+ #### Step 3: Change the paths
+ You need to change the paths in `video_super_resolution/scripts/inference_sr.sh` to your local paths, including `video_folder_path`, `txt_file_path`, `model_path`, and `save_dir` (see the sketch below).
+
+
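For reference, a hedged sketch of the values to point at your local setup (all paths below are placeholders; in the released script `model_path` and `save_dir` appear as the `--model_path` and `--save_dir` arguments of the `python` call rather than as shell variables):

```
video_folder_path='./input/video'               # folder with the .mp4 inputs
txt_file_path='./input/text/prompt.txt'         # one prompt per video
--model_path ./pretrained_weight/light_deg.pt   # or heavy_deg.pt, in the python call
--save_dir ./results                            # also in the python call
```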
+ #### Step 4: Run the inference command
+ ```
+ bash video_super_resolution/scripts/inference_sr.sh
+ ```
+
+
+ ## ❤️ Acknowledgments
+ This project is based on [I2VGen-XL](https://github.com/ali-vilab/VGen), [VEnhancer](https://github.com/Vchitect/VEnhancer) and [CogVideoX](https://github.com/THUDM/CogVideo). Thanks for their awesome work.
+
+
+ ## 🎓 Citations
+ If our project helps your research or work, please consider citing our paper:
+
+ ```
+ @misc{xie2024addsr,
+   title={AddSR: Accelerating Diffusion-based Blind Super-Resolution with Adversarial Diffusion Distillation},
+   author={Rui Xie and Ying Tai and Kai Zhang and Zhenyu Zhang and Jun Zhou and Jian Yang},
+   year={2024},
+   eprint={2404.01717},
+   archivePrefix={arXiv},
+   primaryClass={cs.CV}
+ }
+ ```
+
+
+ ## 📧 Contact
+ If you have any inquiries, please don't hesitate to reach out via email at `ruixie0097@gmail.com`.
__pycache__/inference_utils.cpython-39.pyc ADDED
Binary file (5.07 kB).
 
inference_utils.py ADDED
@@ -0,0 +1,148 @@
+ import os
+ import subprocess
+ import tempfile
+ import cv2
+ import torch
+ from PIL import Image
+ from typing import Mapping
+ from einops import rearrange
+ import numpy as np
+ import torchvision.transforms.functional as transforms_F
+ from video_to_video.utils.logger import get_logger
+
+ logger = get_logger()
+
+
+ def tensor2vid(video, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
+     mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
+     std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
+     video = video.mul_(std).add_(mean)
+     video.clamp_(0, 1)
+     video = video * 255.0
+     images = rearrange(video, 'b c f h w -> b f h w c')[0]
+     return images
+
+
+ def preprocess(input_frames):
+     out_frame_list = []
+     for pointer in range(len(input_frames)):
+         frame = input_frames[pointer]
+         frame = frame[:, :, ::-1]
+         frame = Image.fromarray(frame.astype('uint8')).convert('RGB')
+         frame = transforms_F.to_tensor(frame)
+         out_frame_list.append(frame)
+     out_frames = torch.stack(out_frame_list, dim=0)
+     out_frames.clamp_(0, 1)
+     mean = out_frames.new_tensor([0.5, 0.5, 0.5]).view(-1)
+     std = out_frames.new_tensor([0.5, 0.5, 0.5]).view(-1)
+     out_frames.sub_(mean.view(1, -1, 1, 1)).div_(std.view(1, -1, 1, 1))
+     return out_frames
+
+
+ def adjust_resolution(h, w, up_scale):
+     if h*up_scale < 720:
+         up_s = 720/h
+         target_h = int(up_s*h//2*2)
+         target_w = int(up_s*w//2*2)
+     elif h*w*up_scale*up_scale > 1280*2048:
+         up_s = np.sqrt(1280*2048/(h*w))
+         target_h = int(up_s*h//2*2)
+         target_w = int(up_s*w//2*2)
+     else:
+         target_h = int(up_scale*h//2*2)
+         target_w = int(up_scale*w//2*2)
+     return (target_h, target_w)
+
+
+ def make_mask_cond(in_f_num, interp_f_num):
+     mask_cond = []
+     interp_cond = [-1 for _ in range(interp_f_num)]
+     for i in range(in_f_num):
+         mask_cond.append(i)
+         if i != in_f_num - 1:
+             mask_cond += interp_cond
+     return mask_cond
+
+
+ def load_video(vid_path):
+     capture = cv2.VideoCapture(vid_path)
+     _fps = capture.get(cv2.CAP_PROP_FPS)
+     _total_frame_num = capture.get(cv2.CAP_PROP_FRAME_COUNT)
+     pointer = 0
+     frame_list = []
+     stride = 1
+     while len(frame_list) < _total_frame_num:
+         ret, frame = capture.read()
+         pointer += 1
+         if (not ret) or (frame is None):
+             break
+         if pointer >= _total_frame_num + 1:
+             break
+         if pointer % stride == 0:
+             frame_list.append(frame)
+     capture.release()
+     return frame_list, _fps
+
+
+ def save_video(video, save_dir, file_name, fps=16.0):
+     output_path = os.path.join(save_dir, file_name)
+     images = [(img.numpy()).astype('uint8') for img in video]
+     temp_dir = tempfile.mkdtemp()
+
+     for fid, frame in enumerate(images):
+         tpth = os.path.join(temp_dir, '%06d.png' % (fid + 1))
+         cv2.imwrite(tpth, frame[:, :, ::-1])
+
+     tmp_path = os.path.join(save_dir, 'tmp.mp4')
+     cmd = f'ffmpeg -y -f image2 -framerate {fps} -i {temp_dir}/%06d.png \
+         -vcodec libx264 -preset ultrafast -crf 0 -pix_fmt yuv420p {tmp_path}'
+
+     status, output = subprocess.getstatusoutput(cmd)
+     if status != 0:
+         logger.error('Save Video Error with {}'.format(output))
+
+     os.system(f'rm -rf {temp_dir}')
+     os.rename(tmp_path, output_path)
+
+
+ def collate_fn(data, device):
+     """Prepare the input just before the forward function.
+     This method will move the tensors to the right device.
+     Usually this method does not need to be overridden.
+
+     Args:
+         data: The data out of the dataloader.
+         device: The device to move data to.
+
+     Returns: The processed data.
+     """
+     from torch.utils.data.dataloader import default_collate
+
+     def get_class_name(obj):
+         return obj.__class__.__name__
+
+     if isinstance(data, dict) or isinstance(data, Mapping):
+         return type(data)({
+             k: collate_fn(v, device) if k != 'img_metas' else v
+             for k, v in data.items()
+         })
+     elif isinstance(data, (tuple, list)):
+         if 0 == len(data):
+             return torch.Tensor([])
+         if isinstance(data[0], (int, float)):
+             return default_collate(data).to(device)
+         else:
+             return type(data)(collate_fn(v, device) for v in data)
+     elif isinstance(data, np.ndarray):
+         if data.dtype.type is np.str_:
+             return data
+         else:
+             return collate_fn(torch.from_numpy(data), device)
+     elif isinstance(data, torch.Tensor):
+         return data.to(device)
+     elif isinstance(data, (bytes, str, int, float, bool, type(None))):
+         return data
+     else:
+         raise ValueError(f'Unsupported data type {type(data)}')
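These helpers form the I/O path used by `video_super_resolution/scripts/inference_sr.py`: frames are loaded with `load_video`, normalized with `preprocess`, and the result is written back out via `tensor2vid` and `save_video`. A minimal sketch of that flow (the input path and the 4x scale are placeholders):

```
# Sketch only: mirrors how inference_sr.py wires these helpers together.
from inference_utils import load_video, preprocess, adjust_resolution

frames, fps = load_video('input/video/000.mp4')   # placeholder input path
video_data = preprocess(frames)                   # (T, C, H, W), normalized to [-1, 1]
_, _, h, w = video_data.shape
target_h, target_w = adjust_resolution(h, w, up_scale=4)
print(f'{len(frames)} frames @ {fps} fps -> target {target_h}x{target_w}')
```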
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ torch==2.0.1
+ torchvision==0.15.2
+ torchaudio==2.0.2
+ opencv-python==4.10.0.84
+ easydict==1.13
+ einops==0.8.0
+ open-clip-torch==2.20.0
+ xformers==0.0.21
+ fairscale==0.4.13
+ torchsde==0.2.6
+ pytorch-lightning==2.0.1
+ diffusers==0.30.0
+ huggingface_hub==0.23.3
+ gradio==4.41.0
+ numpy==1.24
video_super_resolution/__pycache__/color_fix.cpython-39.pyc ADDED
Binary file (4.01 kB).
 
video_super_resolution/color_fix.py ADDED
@@ -0,0 +1,122 @@
+ '''
+ # --------------------------------------------------------------------------------
+ # Color fix script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
+ # --------------------------------------------------------------------------------
+ '''
+
+ import torch
+ from PIL import Image
+ from torch import Tensor
+ from torch.nn import functional as F
+
+ from torchvision.transforms import ToTensor, ToPILImage
+ from einops import rearrange
+
+ def adain_color_fix(target: Tensor, source: Tensor):
+     # Rearrange the target video to T C H W and normalize both inputs to [0, 1]
+     target = rearrange(target, 'T H W C -> T C H W') / 255
+     source = (source + 1) / 2
+
+     # Apply adaptive instance normalization frame by frame
+     result_tensor_list = []
+     for i in range(0, target.shape[0]):
+         result_tensor_list.append(adaptive_instance_normalization(target[i].unsqueeze(0), source[i].unsqueeze(0)))
+
+     # Convert the tensor back to a T H W C video in [0, 255]
+     result_tensor = torch.cat(result_tensor_list, dim=0).clamp_(0.0, 1.0)
+     result_video = rearrange(result_tensor, "T C H W -> T H W C") * 255
+
+     return result_video
+
+ def wavelet_color_fix(target: Tensor, source: Tensor):
+     # Rearrange the target video to T C H W and normalize both inputs to [0, 1]
+     target = rearrange(target, 'T H W C -> T C H W') / 255
+     source = (source + 1) / 2
+
+     # Apply wavelet reconstruction frame by frame
+     result_tensor_list = []
+     for i in range(0, target.shape[0]):
+         result_tensor_list.append(wavelet_reconstruction(target[i].unsqueeze(0), source[i].unsqueeze(0)))
+
+     # Convert the tensor back to a T H W C video in [0, 255]
+     result_tensor = torch.cat(result_tensor_list, dim=0).clamp_(0.0, 1.0)
+     result_video = rearrange(result_tensor, "T C H W -> T H W C") * 255
+
+     return result_video
+
+ def calc_mean_std(feat: Tensor, eps=1e-5):
+     """Calculate mean and std for adaptive_instance_normalization.
+     Args:
+         feat (Tensor): 4D tensor.
+         eps (float): A small value added to the variance to avoid
+             divide-by-zero. Default: 1e-5.
+     """
+     size = feat.size()
+     assert len(size) == 4, 'The input feature should be 4D tensor.'
+     b, c = size[:2]
+     feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
+     feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
+     feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
+     return feat_mean, feat_std
+
+ def adaptive_instance_normalization(content_feat: Tensor, style_feat: Tensor):
+     """Adaptive instance normalization.
+     Adjust the reference features to have a similar color and illumination
+     to those of the degraded features.
+     Args:
+         content_feat (Tensor): The reference feature.
+         style_feat (Tensor): The degraded features.
+     """
+     size = content_feat.size()
+     style_mean, style_std = calc_mean_std(style_feat)
+     content_mean, content_std = calc_mean_std(content_feat)
+     normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+     return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+ def wavelet_blur(image: Tensor, radius: int):
+     """
+     Apply wavelet blur to the input tensor.
+     """
+     # input shape: (1, 3, H, W)
+     # convolution kernel
+     kernel_vals = [
+         [0.0625, 0.125, 0.0625],
+         [0.125, 0.25, 0.125],
+         [0.0625, 0.125, 0.0625],
+     ]
+     kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
+     # add channel dimensions to the kernel to make it a 4D tensor
+     kernel = kernel[None, None]
+     # repeat the kernel across all input channels
+     kernel = kernel.repeat(3, 1, 1, 1)
+     image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
+     # apply convolution
+     output = F.conv2d(image, kernel, groups=3, dilation=radius)
+     return output
+
+ def wavelet_decomposition(image: Tensor, levels=5):
+     """
+     Apply wavelet decomposition to the input tensor.
+     This function only returns the low frequency & the high frequency.
+     """
+     high_freq = torch.zeros_like(image)
+     for i in range(levels):
+         radius = 2 ** i
+         low_freq = wavelet_blur(image, radius)
+         high_freq += (image - low_freq)
+         image = low_freq
+
+     return high_freq, low_freq
+
+ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor):
+     """
+     Apply wavelet decomposition, so that the content will have the same color as the style.
+     """
+     # calculate the wavelet decomposition of the content feature
+     content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
+     del content_low_freq
+     # calculate the wavelet decomposition of the style feature
+     style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
+     del style_high_freq
+     # reconstruct the content feature with the style's high frequency
+     return content_high_freq + style_low_freq
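`adain_color_fix` only matches per-channel mean/std, so the upscaled output and the low-resolution source may differ in spatial size, which is how `inference_sr.py` calls it. `wavelet_reconstruction`, by contrast, adds frequency bands element-wise, so a wavelet-based fix needs the source resized to the target resolution first. A hedged sketch (the bicubic upsampling step is an assumption, not something this commit does):

```
import torch.nn.functional as F
from video_super_resolution.color_fix import wavelet_color_fix

# output: (T, H, W, C) in [0, 255]; video_data: (T, C, H, W) in [-1, 1] at the original resolution
src_up = F.interpolate(video_data, size=output.shape[1:3], mode='bicubic', align_corners=False)
output = wavelet_color_fix(output, src_up)
```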
video_super_resolution/dataset.py ADDED
@@ -0,0 +1,113 @@
+ import os
+ import random
+ import glob
+ import torchvision
+ from einops import rearrange
+ from torch.utils import data as data
+ import torch.nn.functional as F
+ from torchvision import transforms
+ from PIL import Image
+
+ class PairedCaptionVideoDataset(data.Dataset):
+     def __init__(
+         self,
+         root_folders=None,
+         null_text_ratio=0.5,
+         num_frames=16
+     ):
+         super(PairedCaptionVideoDataset, self).__init__()
+
+         self.null_text_ratio = null_text_ratio
+         self.lr_list = []
+         self.gt_list = []
+         self.tag_path_list = []
+         self.num_frames = num_frames
+
+         # root_folders = root_folders.split(',')
+         for root_folder in root_folders:
+             lr_path = root_folder + '/lq'
+             tag_path = root_folder + '/text'
+             gt_path = root_folder + '/gt'
+
+             self.lr_list += glob.glob(os.path.join(lr_path, '*.mp4'))
+             self.gt_list += glob.glob(os.path.join(gt_path, '*.mp4'))
+             self.tag_path_list += glob.glob(os.path.join(tag_path, '*.txt'))
+
+         assert len(self.lr_list) == len(self.gt_list)
+         assert len(self.lr_list) == len(self.tag_path_list)
+
+     def __getitem__(self, index):
+
+         gt_path = self.gt_list[index]
+         vframes_gt, _, info = torchvision.io.read_video(filename=gt_path, pts_unit="sec", output_format="TCHW")
+         fps = info['video_fps']
+         vframes_gt = (rearrange(vframes_gt, "T C H W -> C T H W") / 255) * 2 - 1
+         # gt = self.trandform(vframes_gt)
+
+         lq_path = self.lr_list[index]
+         vframes_lq, _, _ = torchvision.io.read_video(filename=lq_path, pts_unit="sec", output_format="TCHW")
+         vframes_lq = (rearrange(vframes_lq, "T C H W -> C T H W") / 255) * 2 - 1
+         # lq = self.trandform(vframes_lq)
+
+         if random.random() < self.null_text_ratio:
+             tag = ''
+         else:
+             tag_path = self.tag_path_list[index]
+             with open(tag_path, 'r', encoding='utf-8') as file:
+                 tag = file.read()
+
+         return {"gt": vframes_gt[:, :self.num_frames, :, :], "lq": vframes_lq[:, :self.num_frames, :, :], "text": tag, 'fps': fps}
+
+     def __len__(self):
+         return len(self.gt_list)
+
+
+ class PairedCaptionImageDataset(data.Dataset):
+     def __init__(
+         self,
+         root_folder=None,
+     ):
+         super(PairedCaptionImageDataset, self).__init__()
+
+         self.lr_list = []
+         self.gt_list = []
+         self.tag_path_list = []
+
+         lr_path = root_folder + '/sr_bicubic'
+         gt_path = root_folder + '/gt'
+
+         self.lr_list += glob.glob(os.path.join(lr_path, '*.png'))
+         self.gt_list += glob.glob(os.path.join(gt_path, '*.png'))
+
+         assert len(self.lr_list) == len(self.gt_list)
+
+         self.img_preproc = transforms.Compose([
+             transforms.ToTensor(),
+         ])
+
+         # Define the crop size (720x1280)
+         crop_size = (720, 1280)
+
+         # CenterCrop transform
+         self.center_crop = transforms.CenterCrop(crop_size)
+
+     def __getitem__(self, index):
+
+         gt_path = self.gt_list[index]
+         gt_img = Image.open(gt_path).convert('RGB')
+         gt_img = self.center_crop(self.img_preproc(gt_img))
+
+         lq_path = self.lr_list[index]
+         lq_img = Image.open(lq_path).convert('RGB')
+         lq_img = self.center_crop(self.img_preproc(lq_img))
+
+         example = dict()
+
+         example["lq"] = (lq_img.squeeze(0) * 2.0 - 1.0).unsqueeze(1)
+         example["gt"] = (gt_img.squeeze(0) * 2.0 - 1.0).unsqueeze(1)
+         example["text"] = ""
+
+         return example
+
+     def __len__(self):
+         return len(self.gt_list)
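A minimal sketch of how `PairedCaptionVideoDataset` could be consumed during training; the `gt/`, `lq/`, and `text/` subfolder layout is implied by the class, but the root path itself is a placeholder:

```
from torch.utils.data import DataLoader
from video_super_resolution.dataset import PairedCaptionVideoDataset

dataset = PairedCaptionVideoDataset(root_folders=['./data/train_set'], num_frames=16)  # placeholder root
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)

batch = next(iter(loader))
# batch['gt'] and batch['lq'] are (B, C, T, H, W) clips in [-1, 1]; batch['text'] holds the captions
print(batch['gt'].shape, batch['lq'].shape, batch['text'])
```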
video_super_resolution/scripts/inference_sr.py ADDED
@@ -0,0 +1,140 @@
+ import os
+ import torch
+ from argparse import ArgumentParser, Namespace
+ import json
+ from typing import Any, Dict, List, Mapping, Tuple
+ from easydict import EasyDict
+
+ from video_to_video.video_to_video_model import VideoToVideo_sr
+ from video_to_video.utils.seed import setup_seed
+ from video_to_video.utils.logger import get_logger
+ from video_super_resolution.color_fix import adain_color_fix
+
+ from inference_utils import *
+
+ logger = get_logger()
+
+
+ class VEnhancer_sr():
+     def __init__(self,
+                  result_dir='./results/',
+                  file_name='000_video.mp4',
+                  model_path='',
+                  solver_mode='fast',
+                  steps=15,
+                  guide_scale=7.5,
+                  upscale=4,
+                  max_chunk_len=32,
+                  variant_info=None,
+                  ):
+         self.model_path = model_path
+         logger.info('checkpoint_path: {}'.format(self.model_path))
+
+         self.result_dir = result_dir
+         self.file_name = file_name
+         os.makedirs(self.result_dir, exist_ok=True)
+
+         model_cfg = EasyDict(__name__='model_cfg')
+         model_cfg.model_path = self.model_path
+         self.model = VideoToVideo_sr(model_cfg)
+
+         steps = 15 if solver_mode == 'fast' else steps
+         self.solver_mode = solver_mode
+         self.steps = steps
+         self.guide_scale = guide_scale
+         self.upscale = upscale
+         self.max_chunk_len = max_chunk_len
+         self.variant_info = variant_info
+
+     def enhance_a_video(self, video_path, prompt):
+         logger.info('input video path: {}'.format(video_path))
+         text = prompt
+         logger.info('text: {}'.format(text))
+         caption = text + self.model.positive_prompt
+
+         input_frames, input_fps = load_video(video_path)
+         in_f_num = len(input_frames)
+         logger.info('input frames length: {}'.format(in_f_num))
+         logger.info('input fps: {}'.format(input_fps))
+
+         video_data = preprocess(input_frames)
+         _, _, h, w = video_data.shape
+         logger.info('input resolution: {}'.format((h, w)))
+         target_h, target_w = h * self.upscale, w * self.upscale  # adjust_resolution(h, w, up_scale=4)
+         logger.info('target resolution: {}'.format((target_h, target_w)))
+
+         pre_data = {'video_data': video_data, 'y': caption}
+         pre_data['target_res'] = (target_h, target_w)
+
+         total_noise_levels = 900
+         setup_seed(666)
+
+         with torch.no_grad():
+             data_tensor = collate_fn(pre_data, 'cuda:0')
+             output = self.model.test(data_tensor, total_noise_levels, steps=self.steps,
+                                      solver_mode=self.solver_mode, guide_scale=self.guide_scale,
+                                      max_chunk_len=self.max_chunk_len
+                                      )
+
+         output = tensor2vid(output)
+
+         # Using color fix
+         output = adain_color_fix(output, video_data)
+
+         save_video(output, self.result_dir, self.file_name, fps=input_fps)
+         return os.path.join(self.result_dir, self.file_name)
+
+
+ def parse_args():
+     parser = ArgumentParser()
+
+     parser.add_argument("--input_path", required=True, type=str, help="input video path")
+     parser.add_argument("--save_dir", type=str, default='results', help="save directory")
+     parser.add_argument("--file_name", type=str, help="file name")
+     parser.add_argument("--model_path", type=str, default='./pretrained_weight/model.pt', help="model path")
+     parser.add_argument("--prompt", type=str, default='a good video', help="prompt")
+     parser.add_argument("--upscale", type=int, default=4, help='up-scale')
+     parser.add_argument("--max_chunk_len", type=int, default=32, help='max_chunk_len')
+     parser.add_argument("--variant_info", type=str, default=None, help='information of inference strategy')
+
+     parser.add_argument("--cfg", type=float, default=7.5)
+     parser.add_argument("--solver_mode", type=str, default='fast', help='fast | normal')
+     parser.add_argument("--steps", type=int, default=15)
+
+     return parser.parse_args()
+
+
+ def main():
+
+     args = parse_args()
+
+     input_path = args.input_path
+     prompt = args.prompt
+     model_path = args.model_path
+     save_dir = args.save_dir
+     file_name = args.file_name
+     upscale = args.upscale
+     max_chunk_len = args.max_chunk_len
+
+     steps = args.steps
+     solver_mode = args.solver_mode
+     guide_scale = args.cfg
+
+     assert solver_mode in ('fast', 'normal')
+
+     venhancer_sr = VEnhancer_sr(
+         result_dir=save_dir,
+         file_name=file_name,  # new added
+         model_path=model_path,
+         solver_mode=solver_mode,
+         steps=steps,
+         guide_scale=guide_scale,
+         upscale=upscale,
+         max_chunk_len=max_chunk_len,
+         variant_info=None,
+     )
+
+     venhancer_sr.enhance_a_video(input_path, prompt)
+
+
+ if __name__ == '__main__':
+     main()
video_super_resolution/scripts/inference_sr.sh ADDED
@@ -0,0 +1,56 @@
+ #!/bin/bash
+
+ # Folder paths
+ video_folder_path='./input/video'
+ txt_file_path='./input/text/prompt.txt'
+
+ # Get all .mp4 files in the folder using find to handle special characters
+ mapfile -t mp4_files < <(find "$video_folder_path" -type f -name "*.mp4")
+
+ # Print the list of MP4 files
+ echo "MP4 files to be processed:"
+ for mp4_file in "${mp4_files[@]}"; do
+     echo "$mp4_file"
+ done
+
+ # Read lines from the text file, skipping empty lines
+ mapfile -t lines < <(grep -v '^\s*$' "$txt_file_path")
+
+ # Number of frames per chunk
+ frame_length=32
+
+ # Debugging output
+ echo "Number of MP4 files: ${#mp4_files[@]}"
+ echo "Number of lines in the text file: ${#lines[@]}"
+
+ # Ensure the number of video files matches the number of lines
+ if [ ${#mp4_files[@]} -ne ${#lines[@]} ]; then
+     echo "Number of MP4 files and lines in the text file do not match."
+     exit 1
+ fi
+
+ # Loop through video files and corresponding lines
+ for i in "${!mp4_files[@]}"; do
+     mp4_file="${mp4_files[$i]}"
+     line="${lines[$i]}"
+
+     # Extract the filename without the extension
+     file_name=$(basename "$mp4_file" .mp4)
+
+     echo "Processing video file: $mp4_file with prompt: $line"
+
+     # Run Python script with parameters
+     python \
+         ./video_super_resolution/scripts/inference_sr.py \
+         --solver_mode 'fast' \
+         --steps 15 \
+         --input_path "${mp4_file}" \
+         --model_path /mnt/bn/videodataset/VSR/pretrained_models/STAR/model.pt \
+         --prompt "${line}" \
+         --upscale 4 \
+         --max_chunk_len ${frame_length} \
+         --file_name "${file_name}.mp4" \
+         --save_dir ./results
+ done
+
+ echo "All videos processed successfully."
video_to_video/__init__.py ADDED
File without changes
video_to_video/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (152 Bytes).
 
video_to_video/__pycache__/video_to_video_model.cpython-39.pyc ADDED
Binary file (6.11 kB).
 
video_to_video/diffusion/__init__.py ADDED
File without changes
video_to_video/diffusion/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (162 Bytes).
 
video_to_video/diffusion/__pycache__/diffusion_sdedit.cpython-39.pyc ADDED
Binary file (10.4 kB).
 
video_to_video/diffusion/__pycache__/schedules_sdedit.cpython-39.pyc ADDED
Binary file (2.68 kB).
 
video_to_video/diffusion/__pycache__/solvers_sdedit.cpython-39.pyc ADDED
Binary file (6.18 kB).
 
video_to_video/diffusion/diffusion_sdedit.py ADDED
@@ -0,0 +1,443 @@
1
+ import random
2
+
3
+ import torch
4
+
5
+ from .schedules_sdedit import karras_schedule
6
+ from .solvers_sdedit import sample_dpmpp_2m_sde, sample_heun
7
+
8
+ from video_to_video.utils.logger import get_logger
9
+
10
+ logger = get_logger()
11
+
12
+ __all__ = ['GaussianDiffusion']
13
+
14
+
15
+ def _i(tensor, t, x):
16
+ shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
17
+ return tensor[t.to(tensor.device)].view(shape).to(x.device)
18
+
19
+ class GaussianDiffusion(object):
20
+
21
+ def __init__(self, sigmas):
22
+ self.sigmas = sigmas
23
+ self.alphas = torch.sqrt(1 - sigmas**2)
24
+ self.num_timesteps = len(sigmas)
25
+
26
+ def diffuse(self, x0, t, noise=None):
27
+ noise = torch.randn_like(x0) if noise is None else noise
28
+ xt = _i(self.alphas, t, x0) * x0 + _i(self.sigmas, t, x0) * noise
29
+
30
+ return xt
31
+
32
+ def get_velocity(self, x0, xt, t):
33
+ sigmas = _i(self.sigmas, t, xt)
34
+ alphas = _i(self.alphas, t, xt)
35
+ velocity = (alphas * xt - x0) / sigmas
36
+ return velocity
37
+
38
+ def get_x0(self, v, xt, t):
39
+ sigmas = _i(self.sigmas, t, xt)
40
+ alphas = _i(self.alphas, t, xt)
41
+ x0 = alphas * xt - sigmas * v
42
+ return x0
43
+
44
+ def denoise(self,
45
+ xt,
46
+ t,
47
+ s,
48
+ model,
49
+ model_kwargs={},
50
+ guide_scale=None,
51
+ guide_rescale=None,
52
+ clamp=None,
53
+ percentile=None,
54
+ variant_info=None,):
55
+ s = t - 1 if s is None else s
56
+
57
+ # hyperparams
58
+ sigmas = _i(self.sigmas, t, xt)
59
+ alphas = _i(self.alphas, t, xt)
60
+ alphas_s = _i(self.alphas, s.clamp(0), xt)
61
+ alphas_s[s < 0] = 1.
62
+ sigmas_s = torch.sqrt(1 - alphas_s**2)
63
+
64
+ # precompute variables
65
+ betas = 1 - (alphas / alphas_s)**2
66
+ coef1 = betas * alphas_s / sigmas**2
67
+ coef2 = (alphas * sigmas_s**2) / (alphas_s * sigmas**2)
68
+ var = betas * (sigmas_s / sigmas)**2
69
+ log_var = torch.log(var).clamp_(-20, 20)
70
+
71
+ # prediction
72
+ if guide_scale is None:
73
+ assert isinstance(model_kwargs, dict)
74
+ out = model(xt, t=t, **model_kwargs)
75
+ else:
76
+ # classifier-free guidance
77
+ assert isinstance(model_kwargs, list)
78
+ if len(model_kwargs) > 3:
79
+ y_out = model(xt, t=t, **model_kwargs[0], **model_kwargs[2], **model_kwargs[3], **model_kwargs[4], **model_kwargs[5])
80
+ else:
81
+ y_out = model(xt, t=t, **model_kwargs[0], **model_kwargs[2], variant_info=variant_info)
82
+ if guide_scale == 1.:
83
+ out = y_out
84
+ else:
85
+ if len(model_kwargs) > 3:
86
+ u_out = model(xt, t=t, **model_kwargs[1], **model_kwargs[2], **model_kwargs[3], **model_kwargs[4], **model_kwargs[5])
87
+ else:
88
+ u_out = model(xt, t=t, **model_kwargs[1], **model_kwargs[2], variant_info=variant_info)
89
+ out = u_out + guide_scale * (y_out - u_out)
90
+
91
+ if guide_rescale is not None:
92
+ assert guide_rescale >= 0 and guide_rescale <= 1
93
+ ratio = (
94
+ y_out.flatten(1).std(dim=1) / # noqa
95
+ (out.flatten(1).std(dim=1) + 1e-12)
96
+ ).view((-1, ) + (1, ) * (y_out.ndim - 1))
97
+ out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0
98
+
99
+ x0 = alphas * xt - sigmas * out
100
+
101
+ # restrict the range of x0
102
+ if percentile is not None:
103
+ assert percentile > 0 and percentile <= 1
104
+ s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1)
105
+ s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1))
106
+ x0 = torch.min(s, torch.max(-s, x0)) / s
107
+ elif clamp is not None:
108
+ x0 = x0.clamp(-clamp, clamp)
109
+
110
+ # recompute eps using the restricted x0
111
+ eps = (xt - alphas * x0) / sigmas
112
+
113
+ # compute mu (mean of posterior distribution) using the restricted x0
114
+ mu = coef1 * x0 + coef2 * xt
115
+ return mu, var, log_var, x0, eps
116
+
117
+
118
+ @torch.no_grad()
119
+ def sample(self,
120
+ noise,
121
+ model,
122
+ model_kwargs={},
123
+ condition_fn=None,
124
+ guide_scale=None,
125
+ guide_rescale=None,
126
+ clamp=None,
127
+ percentile=None,
128
+ solver='euler_a',
129
+ solver_mode='fast',
130
+ steps=20,
131
+ t_max=None,
132
+ t_min=None,
133
+ discretization=None,
134
+ discard_penultimate_step=None,
135
+ return_intermediate=None,
136
+ show_progress=False,
137
+ seed=-1,
138
+ chunk_inds=None,
139
+ **kwargs):
140
+ # sanity check
141
+ assert isinstance(steps, (int, torch.LongTensor))
142
+ assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1)
143
+ assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1)
144
+ assert discretization in (None, 'leading', 'linspace', 'trailing')
145
+ assert discard_penultimate_step in (None, True, False)
146
+ assert return_intermediate in (None, 'x0', 'xt')
147
+
148
+ # function of diffusion solver
149
+ solver_fn = {
150
+ 'heun': sample_heun,
151
+ 'dpmpp_2m_sde': sample_dpmpp_2m_sde
152
+ }[solver]
153
+
154
+ # options
155
+ schedule = 'karras' if 'karras' in solver else None
156
+ discretization = discretization or 'linspace'
157
+ seed = seed if seed >= 0 else random.randint(0, 2**31)
158
+ if isinstance(steps, torch.LongTensor):
159
+ discard_penultimate_step = False
160
+ if discard_penultimate_step is None:
161
+ discard_penultimate_step = True if solver in (
162
+ 'dpm2', 'dpm2_ancestral', 'dpmpp_2m_sde', 'dpm2_karras',
163
+ 'dpm2_ancestral_karras', 'dpmpp_2m_sde_karras') else False
164
+
165
+ # function for denoising xt to get x0
166
+ intermediates = []
167
+
168
+ def model_fn(xt, sigma):
169
+ # denoising
170
+ t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
171
+ x0 = self.denoise(xt, t, None, model, model_kwargs, guide_scale,
172
+ guide_rescale, clamp, percentile)[-2]
173
+
174
+ # collect intermediate outputs
175
+ if return_intermediate == 'xt':
176
+ intermediates.append(xt)
177
+ elif return_intermediate == 'x0':
178
+ intermediates.append(x0)
179
+ return x0
180
+
181
+ mask_cond = model_kwargs[3]['mask_cond']
182
+ def model_chunk_fn(xt, sigma):
183
+ # denoising
184
+ t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
185
+ O_LEN = chunk_inds[0][-1]-chunk_inds[1][0]
186
+ cut_f_ind = O_LEN//2
187
+
188
+ results_list = []
189
+ for i in range(len(chunk_inds)):
190
+ ind_start, ind_end = chunk_inds[i]
191
+ xt_chunk = xt[:,:,ind_start:ind_end].clone()
192
+ cur_f = xt_chunk.size(2)
193
+ model_kwargs[3]['mask_cond'] = mask_cond[:,ind_start:ind_end].clone()
194
+ x0_chunk = self.denoise(xt_chunk, t, None, model, model_kwargs, guide_scale,
195
+ guide_rescale, clamp, percentile)[-2]
196
+ if i == 0:
197
+ results_list.append(x0_chunk[:,:,:cur_f+cut_f_ind-O_LEN])
198
+ elif i == len(chunk_inds)-1:
199
+ results_list.append(x0_chunk[:,:,cut_f_ind:])
200
+ else:
201
+ results_list.append(x0_chunk[:,:,cut_f_ind:cur_f+cut_f_ind-O_LEN])
202
+ x0 = torch.concat(results_list, dim=2)
203
+ torch.cuda.empty_cache()
204
+ return x0
205
+
206
+ # get timesteps
207
+ if isinstance(steps, int):
208
+ steps += 1 if discard_penultimate_step else 0
209
+ t_max = self.num_timesteps - 1 if t_max is None else t_max
210
+ t_min = 0 if t_min is None else t_min
211
+
212
+ # discretize timesteps
213
+ if discretization == 'leading':
214
+ steps = torch.arange(t_min, t_max + 1,
215
+ (t_max - t_min + 1) / steps).flip(0)
216
+ elif discretization == 'linspace':
217
+ steps = torch.linspace(t_max, t_min, steps)
218
+ elif discretization == 'trailing':
219
+ steps = torch.arange(t_max, t_min - 1,
220
+ -((t_max - t_min + 1) / steps))
221
+ if solver_mode == 'fast':
222
+ t_mid = 500
223
+ steps1 = torch.arange(t_max, t_mid - 1,
224
+ -((t_max - t_mid + 1) / 4))
225
+ steps2 = torch.arange(t_mid, t_min - 1,
226
+ -((t_mid - t_min + 1) / 11))
227
+ steps = torch.concat([steps1, steps2])
228
+ else:
229
+ raise NotImplementedError(
230
+ f'{discretization} discretization not implemented')
231
+ steps = steps.clamp_(t_min, t_max)
232
+ steps = torch.as_tensor(
233
+ steps, dtype=torch.float32, device=noise.device)
234
+
235
+ # get sigmas
236
+ sigmas = self._t_to_sigma(steps)
237
+ sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
238
+ if schedule == 'karras':
239
+ if sigmas[0] == float('inf'):
240
+ sigmas = karras_schedule(
241
+ n=len(steps) - 1,
242
+ sigma_min=sigmas[sigmas > 0].min().item(),
243
+ sigma_max=sigmas[sigmas < float('inf')].max().item(),
244
+ rho=7.).to(sigmas)
245
+ sigmas = torch.cat([
246
+ sigmas.new_tensor([float('inf')]), sigmas,
247
+ sigmas.new_zeros([1])
248
+ ])
249
+ else:
250
+ sigmas = karras_schedule(
251
+ n=len(steps),
252
+ sigma_min=sigmas[sigmas > 0].min().item(),
253
+ sigma_max=sigmas.max().item(),
254
+ rho=7.).to(sigmas)
255
+ sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
256
+ if discard_penultimate_step:
257
+ sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
258
+
259
+ fn = model_chunk_fn if chunk_inds is not None else model_fn
260
+ x0 = solver_fn(
261
+ noise, fn, sigmas, show_progress=show_progress, **kwargs)
262
+ return (x0, intermediates) if return_intermediate is not None else x0
263
+
264
+ @torch.no_grad()
265
+ def sample_sr(self,
266
+ noise,
267
+ model,
268
+ model_kwargs={},
269
+ condition_fn=None,
270
+ guide_scale=None,
271
+ guide_rescale=None,
272
+ clamp=None,
273
+ percentile=None,
274
+ solver='euler_a',
275
+ solver_mode='fast',
276
+ steps=20,
277
+ t_max=None,
278
+ t_min=None,
279
+ discretization=None,
280
+ discard_penultimate_step=None,
281
+ return_intermediate=None,
282
+ show_progress=False,
283
+ seed=-1,
284
+ chunk_inds=None,
285
+ variant_info=None,
286
+ **kwargs):
287
+ # sanity check
288
+ assert isinstance(steps, (int, torch.LongTensor))
289
+ assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1)
290
+ assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1)
291
+ assert discretization in (None, 'leading', 'linspace', 'trailing')
292
+ assert discard_penultimate_step in (None, True, False)
293
+ assert return_intermediate in (None, 'x0', 'xt')
294
+
295
+ # function of diffusion solver
296
+ solver_fn = {
297
+ 'heun': sample_heun,
298
+ 'dpmpp_2m_sde': sample_dpmpp_2m_sde
299
+ }[solver]
300
+
301
+ # options
302
+ schedule = 'karras' if 'karras' in solver else None
303
+ discretization = discretization or 'linspace'
304
+ seed = seed if seed >= 0 else random.randint(0, 2**31)
305
+ if isinstance(steps, torch.LongTensor):
306
+ discard_penultimate_step = False
307
+ if discard_penultimate_step is None:
308
+ discard_penultimate_step = True if solver in (
309
+ 'dpm2', 'dpm2_ancestral', 'dpmpp_2m_sde', 'dpm2_karras',
310
+ 'dpm2_ancestral_karras', 'dpmpp_2m_sde_karras') else False
311
+
312
+ # function for denoising xt to get x0
313
+ intermediates = []
314
+
315
+ def model_fn(xt, sigma, variant_info=None):
316
+ # denoising
317
+ t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
318
+ x0 = self.denoise(xt, t, None, model, model_kwargs, guide_scale,
319
+ guide_rescale, clamp, percentile, variant_info=variant_info)[-2]
320
+
321
+ # collect intermediate outputs
322
+ if return_intermediate == 'xt':
323
+ intermediates.append(xt)
324
+ elif return_intermediate == 'x0':
325
+ print('add intermediate outputs x0')
326
+ intermediates.append(x0)
327
+ return x0
328
+
329
+ # mask_cond = model_kwargs[3]['mask_cond']
330
+ def model_chunk_fn(xt, sigma, variant_info=None):
331
+ # denoising
332
+ t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
333
+ O_LEN = chunk_inds[0][-1]-chunk_inds[1][0]
334
+ cut_f_ind = O_LEN//2
335
+
336
+ results_list = []
337
+ for i in range(len(chunk_inds)):
338
+ ind_start, ind_end = chunk_inds[i]
339
+ xt_chunk = xt[:,:,ind_start:ind_end].clone()
340
+ model_kwargs[2]['hint_chunk'] = model_kwargs[2]['hint'][:,:,ind_start:ind_end].clone() # new added
341
+ cur_f = xt_chunk.size(2)
342
+ # model_kwargs[3]['mask_cond'] = mask_cond[:,ind_start:ind_end].clone()
343
+ x0_chunk = self.denoise(xt_chunk, t, None, model, model_kwargs, guide_scale,
344
+ guide_rescale, clamp, percentile, variant_info=variant_info)[-2]
345
+ if i == 0:
346
+ results_list.append(x0_chunk[:,:,:cur_f+cut_f_ind-O_LEN])
347
+ elif i == len(chunk_inds)-1:
348
+ results_list.append(x0_chunk[:,:,cut_f_ind:])
349
+ else:
350
+ results_list.append(x0_chunk[:,:,cut_f_ind:cur_f+cut_f_ind-O_LEN])
351
+ x0 = torch.concat(results_list, dim=2)
352
+ torch.cuda.empty_cache()
353
+ return x0
354
+
355
+ # get timesteps
356
+ if isinstance(steps, int):
357
+ steps += 1 if discard_penultimate_step else 0
358
+ t_max = self.num_timesteps - 1 if t_max is None else t_max
359
+ t_min = 0 if t_min is None else t_min
360
+
361
+ # discretize timesteps
362
+ if discretization == 'leading':
363
+ steps = torch.arange(t_min, t_max + 1,
364
+ (t_max - t_min + 1) / steps).flip(0)
365
+ elif discretization == 'linspace':
366
+ steps = torch.linspace(t_max, t_min, steps)
367
+ elif discretization == 'trailing':
368
+ steps = torch.arange(t_max, t_min - 1,
369
+ -((t_max - t_min + 1) / steps))
370
+ if solver_mode == 'fast':
371
+ t_mid = 500
372
+ steps1 = torch.arange(t_max, t_mid - 1,
373
+ -((t_max - t_mid + 1) / 4))
374
+ steps2 = torch.arange(t_mid, t_min - 1,
375
+ -((t_mid - t_min + 1) / 11))
376
+ steps = torch.concat([steps1, steps2])
377
+ else:
378
+ raise NotImplementedError(
379
+ f'{discretization} discretization not implemented')
380
+ steps = steps.clamp_(t_min, t_max)
381
+ steps = torch.as_tensor(
382
+ steps, dtype=torch.float32, device=noise.device)
383
+
384
+ # get sigmas
385
+ sigmas = self._t_to_sigma(steps)
386
+ sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
387
+ if schedule == 'karras':
388
+ if sigmas[0] == float('inf'):
389
+ sigmas = karras_schedule(
390
+ n=len(steps) - 1,
391
+ sigma_min=sigmas[sigmas > 0].min().item(),
392
+ sigma_max=sigmas[sigmas < float('inf')].max().item(),
393
+ rho=7.).to(sigmas)
394
+ sigmas = torch.cat([
395
+ sigmas.new_tensor([float('inf')]), sigmas,
396
+ sigmas.new_zeros([1])
397
+ ])
398
+ else:
399
+ sigmas = karras_schedule(
400
+ n=len(steps),
401
+ sigma_min=sigmas[sigmas > 0].min().item(),
402
+ sigma_max=sigmas.max().item(),
403
+ rho=7.).to(sigmas)
404
+ sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
405
+ if discard_penultimate_step:
406
+ sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
407
+
408
+
409
+ fn = model_chunk_fn if chunk_inds is not None else model_fn
410
+ x0 = solver_fn(
411
+ noise, fn, sigmas, variant_info=variant_info, show_progress=show_progress, **kwargs)
412
+ return (x0, intermediates) if return_intermediate is not None else x0
413
+
414
+
415
+ def _sigma_to_t(self, sigma):
416
+ if sigma == float('inf'):
417
+ t = torch.full_like(sigma, len(self.sigmas) - 1)
418
+ else:
419
+ log_sigmas = torch.sqrt(self.sigmas**2 / # noqa
420
+ (1 - self.sigmas**2)).log().to(sigma)
421
+ log_sigma = sigma.log()
422
+ dists = log_sigma - log_sigmas[:, None]
423
+ low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(
424
+ max=log_sigmas.shape[0] - 2)
425
+ high_idx = low_idx + 1
426
+ low, high = log_sigmas[low_idx], log_sigmas[high_idx]
427
+ w = (low - log_sigma) / (low - high)
428
+ w = w.clamp(0, 1)
429
+ t = (1 - w) * low_idx + w * high_idx
430
+ t = t.view(sigma.shape)
431
+ if t.ndim == 0:
432
+ t = t.unsqueeze(0)
433
+ return t
434
+
435
+ def _t_to_sigma(self, t):
436
+ t = t.float()
437
+ low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
438
+ log_sigmas = torch.sqrt(self.sigmas**2 / # noqa
439
+ (1 - self.sigmas**2)).log().to(t)
440
+ log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]
441
+ log_sigma[torch.isnan(log_sigma)
442
+ | torch.isinf(log_sigma)] = float('inf')
443
+ return log_sigma.exp()
video_to_video/diffusion/schedules_sdedit.py ADDED
@@ -0,0 +1,85 @@
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+
+ import math
+
+ import torch
+
+
+ def betas_to_sigmas(betas):
+     return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0))
+
+
+ def sigmas_to_betas(sigmas):
+     square_alphas = 1 - sigmas**2
+     betas = 1 - torch.cat(
+         [square_alphas[:1], square_alphas[1:] / square_alphas[:-1]])
+     return betas
+
+
+ def logsnrs_to_sigmas(logsnrs):
+     return torch.sqrt(torch.sigmoid(-logsnrs))
+
+
+ def sigmas_to_logsnrs(sigmas):
+     square_sigmas = sigmas**2
+     return torch.log(square_sigmas / (1 - square_sigmas))
+
+
+ def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15):
+     t_min = math.atan(math.exp(-0.5 * logsnr_min))
+     t_max = math.atan(math.exp(-0.5 * logsnr_max))
+     t = torch.linspace(1, 0, n)
+     logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))
+     return logsnrs
+
+
+ def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2):
+     logsnrs = _logsnr_cosine(n, logsnr_min, logsnr_max)
+     logsnrs += 2 * math.log(1 / scale)
+     return logsnrs
+
+
+ def _logsnr_cosine_interp(n,
+                           logsnr_min=-15,
+                           logsnr_max=15,
+                           scale_min=2,
+                           scale_max=4):
+     t = torch.linspace(1, 0, n)
+     logsnrs_min = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min)
+     logsnrs_max = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max)
+     logsnrs = t * logsnrs_min + (1 - t) * logsnrs_max
+     return logsnrs
+
+
+ def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
+     ramp = torch.linspace(1, 0, n)
+     min_inv_rho = sigma_min**(1 / rho)
+     max_inv_rho = sigma_max**(1 / rho)
+     sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho
+     sigmas = torch.sqrt(sigmas**2 / (1 + sigmas**2))
+     return sigmas
+
+
+ def logsnr_cosine_interp_schedule(n,
+                                   logsnr_min=-15,
+                                   logsnr_max=15,
+                                   scale_min=2,
+                                   scale_max=4):
+     return logsnrs_to_sigmas(
+         _logsnr_cosine_interp(n, logsnr_min, logsnr_max, scale_min, scale_max))
+
+
+ def noise_schedule(schedule='logsnr_cosine_interp',
+                    n=1000,
+                    zero_terminal_snr=False,
+                    **kwargs):
+     # compute sigmas
+     sigmas = {
+         'logsnr_cosine_interp': logsnr_cosine_interp_schedule
+     }[schedule](n, **kwargs)
+
+     # post-processing
+     if zero_terminal_snr and sigmas.max() != 1.0:
+         scale = (1.0 - sigmas.min()) / (sigmas.max() - sigmas.min())
+         sigmas = sigmas.min() + scale * (sigmas - sigmas.min())
+     return sigmas
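These schedules produce the `sigmas` that parameterize `GaussianDiffusion` in `diffusion_sdedit.py`. A quick sketch of inspecting them; the wiring below follows the two constructors shown in this commit rather than a specific call site, and the parameter values are placeholders:

```
from video_to_video.diffusion.schedules_sdedit import noise_schedule, karras_schedule
from video_to_video.diffusion.diffusion_sdedit import GaussianDiffusion

sigmas = noise_schedule('logsnr_cosine_interp', n=1000, zero_terminal_snr=True)
diffusion = GaussianDiffusion(sigmas)      # alphas = sqrt(1 - sigmas^2) inside
print(sigmas[0].item(), sigmas[-1].item()) # noise level at t=0 and t=999

karras = karras_schedule(n=15, sigma_min=0.002, sigma_max=80.0, rho=7.0)
print(karras[:3])                          # a few of the 15 Karras-style sigmas
```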
video_to_video/diffusion/solvers_sdedit.py ADDED
@@ -0,0 +1,204 @@
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+
+ import torch
+ import torchsde
+ from tqdm.auto import trange
+
+ from video_to_video.utils.logger import get_logger
+
+ logger = get_logger()
+
+ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
+     """
+     Calculates the noise level (sigma_down) to step down to and the amount
+     of noise to add (sigma_up) when doing an ancestral sampling step.
+     """
+     if not eta:
+         return sigma_to, 0.
+     sigma_up = min(
+         sigma_to,
+         eta * (
+             sigma_to**2 *  # noqa
+             (sigma_from**2 - sigma_to**2) / sigma_from**2)**0.5)
+     sigma_down = (sigma_to**2 - sigma_up**2)**0.5
+     return sigma_down, sigma_up
+
+
+ def get_scalings(sigma):
+     c_out = -sigma
+     c_in = 1 / (sigma**2 + 1.**2)**0.5
+     return c_out, c_in
+
+
+ @torch.no_grad()
+ def sample_heun(noise,
+                 model,
+                 sigmas,
+                 s_churn=0.,
+                 s_tmin=0.,
+                 s_tmax=float('inf'),
+                 s_noise=1.,
+                 show_progress=True):
+     """
+     Implements Algorithm 2 (Heun steps) from Karras et al. (2022).
+     """
+     x = noise * sigmas[0]
+     for i in trange(len(sigmas) - 1, disable=not show_progress):
+         gamma = 0.
+         if s_tmin <= sigmas[i] <= s_tmax and sigmas[i] < float('inf'):
+             gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
+         eps = torch.randn_like(x) * s_noise
+         sigma_hat = sigmas[i] * (gamma + 1)
+         if gamma > 0:
+             x = x + eps * (sigma_hat**2 - sigmas[i]**2)**0.5
+         if sigmas[i] == float('inf'):
+             # Euler method
+             denoised = model(noise, sigma_hat)
+             x = denoised + sigmas[i + 1] * (gamma + 1) * noise
+         else:
+             _, c_in = get_scalings(sigma_hat)
+             denoised = model(x * c_in, sigma_hat)
+             d = (x - denoised) / sigma_hat
+             dt = sigmas[i + 1] - sigma_hat
+             if sigmas[i + 1] == 0:
+                 # Euler method
+                 x = x + d * dt
+             else:
+                 # Heun's method
+                 x_2 = x + d * dt
+                 _, c_in = get_scalings(sigmas[i + 1])
+                 denoised_2 = model(x_2 * c_in, sigmas[i + 1])
+                 d_2 = (x_2 - denoised_2) / sigmas[i + 1]
+                 d_prime = (d + d_2) / 2
+                 x = x + d_prime * dt
+     return x
+
+
+ class BatchedBrownianTree:
+     """
+     A wrapper around torchsde.BrownianTree that enables batches of entropy.
+     """
+
+     def __init__(self, x, t0, t1, seed=None, **kwargs):
+         t0, t1, self.sign = self.sort(t0, t1)
+         w0 = kwargs.get('w0', torch.zeros_like(x))
+         if seed is None:
+             seed = torch.randint(0, 2**63 - 1, []).item()
+         self.batched = True
+         try:
+             assert len(seed) == x.shape[0]
+             w0 = w0[0]
+         except TypeError:
+             seed = [seed]
+             self.batched = False
+         self.trees = [
+             torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs)
+             for s in seed
+         ]
+
+     @staticmethod
+     def sort(a, b):
+         return (a, b, 1) if a < b else (b, a, -1)
+
+     def __call__(self, t0, t1):
+         t0, t1, sign = self.sort(t0, t1)
+         w = torch.stack([tree(t0, t1) for tree in self.trees]) * (
+             self.sign * sign)
+         return w if self.batched else w[0]
+
+
+ class BrownianTreeNoiseSampler:
+     """
+     A noise sampler backed by a torchsde.BrownianTree.
+
+     Args:
+         x (Tensor): The tensor whose shape, device and dtype to use to generate
+             random samples.
+         sigma_min (float): The low end of the valid interval.
+         sigma_max (float): The high end of the valid interval.
+         seed (int or List[int]): The random seed. If a list of seeds is
+             supplied instead of a single integer, then the noise sampler will
+             use one BrownianTree per batch item, each with its own seed.
+         transform (callable): A function that maps sigma to the sampler's
+             internal timestep.
+     """
+
+     def __init__(self,
+                  x,
+                  sigma_min,
+                  sigma_max,
+                  seed=None,
+                  transform=lambda x: x):
+         self.transform = transform
+         t0 = self.transform(torch.as_tensor(sigma_min))
+         t1 = self.transform(torch.as_tensor(sigma_max))
+         self.tree = BatchedBrownianTree(x, t0, t1, seed)
+
+     def __call__(self, sigma, sigma_next):
+         t0 = self.transform(torch.as_tensor(sigma))
+         t1 = self.transform(torch.as_tensor(sigma_next))
+         return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
+
+
+ @torch.no_grad()
+ def sample_dpmpp_2m_sde(noise,
+                         model,
+                         sigmas,
+                         eta=1.,
+                         s_noise=1.,
+                         solver_type='midpoint',
+                         show_progress=True,
+                         variant_info=None):
+     """
+     DPM-Solver++ (2M) SDE.
+     """
+     assert solver_type in {'heun', 'midpoint'}
+
+     x = noise * sigmas[0]
+     sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas[
+         sigmas < float('inf')].max()
+     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max)
+     old_denoised = None
+     h_last = None
+
+     for i in trange(len(sigmas) - 1, disable=not show_progress):
+         logger.info(f'step: {i}')
+         if sigmas[i] == float('inf'):
+             # Euler method
+             denoised = model(noise, sigmas[i], variant_info=variant_info)
+             x = denoised + sigmas[i + 1] * noise
+         else:
+             _, c_in = get_scalings(sigmas[i])
+             denoised = model(x * c_in, sigmas[i], variant_info=variant_info)
+             if sigmas[i + 1] == 0:
+                 # Denoising step
+                 x = denoised
+             else:
+                 # DPM-Solver++(2M) SDE
+                 t, s = -sigmas[i].log(), -sigmas[i + 1].log()
+                 h = s - t
+                 eta_h = eta * h
+
+                 x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + \
+                     (-h - eta_h).expm1().neg() * denoised
+
+                 if old_denoised is not None:
+                     r = h_last / h
+                     if solver_type == 'heun':
+                         x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * \
+                             (1 / r) * (denoised - old_denoised)
+                     elif solver_type == 'midpoint':
+                         x = x + 0.5 * (-h - eta_h).expm1().neg() * \
+                             (1 / r) * (denoised - old_denoised)
+
+                 x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[
+                     i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
+
+             old_denoised = denoised
+             h_last = h
+
+     if variant_info is not None and variant_info.get('type') == 'variant1':
+         x_long, x_short = x.chunk(2, dim=0)
+         x = x_long * (1-variant_info['alpha']) + x_short * variant_info['alpha']
+
+     return x
video_to_video/modules/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .embedder import *
+ from .unet_v2v import *
+ # from .unet_v2v_deform import *
video_to_video/modules/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (205 Bytes).
 
video_to_video/modules/__pycache__/embedder.cpython-39.pyc ADDED
Binary file (2.58 kB).
 
video_to_video/modules/__pycache__/t5.cpython-39.pyc ADDED
Binary file (7.07 kB).
 
video_to_video/modules/__pycache__/unet_v2v.cpython-39.pyc ADDED
Binary file (47.6 kB). View file
 
video_to_video/modules/__pycache__/unet_v2v_LocalConv.cpython-39.pyc ADDED
Binary file (47.8 kB). View file
 
video_to_video/modules/__pycache__/unet_v2v_deform.cpython-39.pyc ADDED
Binary file (48.2 kB). View file
 
video_to_video/modules/embedder.py ADDED
@@ -0,0 +1,75 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ import os
4
+
5
+ import numpy as np
6
+ import open_clip
7
+ import torch
8
+ import torch.nn as nn
9
+ import torchvision.transforms as T
+ from torch.utils.checkpoint import checkpoint
10
+
11
+
12
+ class FrozenOpenCLIPEmbedder(nn.Module):
13
+ """
14
+ Uses the OpenCLIP transformer encoder for text
15
+ """
16
+ LAYERS = ['last', 'penultimate']
17
+
18
+ def __init__(self,
19
+ pretrained='laion2b_s32b_b79k',
20
+ arch='ViT-H-14',
21
+ device='cuda',
22
+ max_length=77,
23
+ freeze=True,
24
+ layer='penultimate'):
25
+ super().__init__()
26
+ assert layer in self.LAYERS
27
+ model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=pretrained)
28
+
29
+ del model.visual
30
+ self.model = model
31
+ self.device = device
32
+ self.max_length = max_length
33
+
34
+ if freeze:
35
+ self.freeze()
36
+ self.layer = layer
37
+ if self.layer == 'last':
38
+ self.layer_idx = 0
39
+ elif self.layer == 'penultimate':
40
+ self.layer_idx = 1
41
+ else:
42
+ raise NotImplementedError()
43
+
44
+ def freeze(self):
45
+ self.model = self.model.eval()
46
+ for param in self.parameters():
47
+ param.requires_grad = False
48
+
49
+ def forward(self, text):
50
+ tokens = open_clip.tokenize(text)
51
+ z = self.encode_with_transformer(tokens.to(self.device))
52
+ return z
53
+
54
+ def encode_with_transformer(self, text):
55
+ x = self.model.token_embedding(text)
56
+ x = x + self.model.positional_embedding
57
+ x = x.permute(1, 0, 2)
58
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
59
+ x = x.permute(1, 0, 2)
60
+ x = self.model.ln_final(x)
61
+ return x
62
+
63
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
64
+ for i, r in enumerate(self.model.transformer.resblocks):
65
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
66
+ break
67
+ if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
68
+ ):
69
+ x = checkpoint(r, x, attn_mask)
70
+ else:
71
+ x = r(x, attn_mask=attn_mask)
72
+ return x
73
+
74
+ def encode(self, text):
75
+ return self(text)
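A small usage sketch for the frozen OpenCLIP text encoder above (assumptions: `open_clip` can fetch the `laion2b_s32b_b79k` weights for ViT-H-14, and a CUDA device is available, since tokenization runs on CPU while encoding runs on `self.device`):

```python
import torch
from video_to_video.modules.embedder import FrozenOpenCLIPEmbedder

encoder = FrozenOpenCLIPEmbedder(device='cuda').to('cuda')
with torch.no_grad():
    z = encoder.encode(['a high-resolution photo of a city street at night'])
print(z.shape)  # expected [1, 77, 1024]: 77 token positions, 1024-dim features for ViT-H-14
```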
video_to_video/modules/t5.py ADDED
@@ -0,0 +1,335 @@
1
+ # Adapted from PixArt
2
+ #
3
+ # Copyright (C) 2023 PixArt-alpha/PixArt-alpha
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU Affero General Public License as published
7
+ # by the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # This program is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU Affero General Public License for more details.
14
+ #
15
+ #
16
+ # This source code is licensed under the license found in the
17
+ # LICENSE file in the root directory of this source tree.
18
+ # --------------------------------------------------------
19
+ # References:
20
+ # PixArt: https://github.com/PixArt-alpha/PixArt-alpha
21
+ # T5: https://github.com/google-research/text-to-text-transfer-transformer
22
+ # --------------------------------------------------------
23
+
24
+ import html
25
+ import re
26
+
27
+ import ftfy
28
+ import torch
29
+ from transformers import AutoTokenizer, T5EncoderModel
30
+
31
+ # from opensora.registry import MODELS
32
+
33
+
34
+ class T5Embedder:
35
+ def __init__(
36
+ self,
37
+ device,
38
+ from_pretrained=None,
39
+ *,
40
+ cache_dir=None,
41
+ hf_token=None,
42
+ use_text_preprocessing=True,
43
+ t5_model_kwargs=None,
44
+ torch_dtype=None,
45
+ use_offload_folder=None,
46
+ model_max_length=120,
47
+ local_files_only=False,
48
+ ):
49
+ self.device = torch.device(device)
50
+ self.torch_dtype = torch_dtype or torch.bfloat16
51
+ self.cache_dir = cache_dir
52
+
53
+ if t5_model_kwargs is None:
54
+ t5_model_kwargs = {
55
+ "low_cpu_mem_usage": True,
56
+ "torch_dtype": self.torch_dtype,
57
+ }
58
+
59
+ if use_offload_folder is not None:
60
+ t5_model_kwargs["offload_folder"] = use_offload_folder
61
+ t5_model_kwargs["device_map"] = {
62
+ "shared": self.device,
63
+ "encoder.embed_tokens": self.device,
64
+ "encoder.block.0": self.device,
65
+ "encoder.block.1": self.device,
66
+ "encoder.block.2": self.device,
67
+ "encoder.block.3": self.device,
68
+ "encoder.block.4": self.device,
69
+ "encoder.block.5": self.device,
70
+ "encoder.block.6": self.device,
71
+ "encoder.block.7": self.device,
72
+ "encoder.block.8": self.device,
73
+ "encoder.block.9": self.device,
74
+ "encoder.block.10": self.device,
75
+ "encoder.block.11": self.device,
76
+ "encoder.block.12": "disk",
77
+ "encoder.block.13": "disk",
78
+ "encoder.block.14": "disk",
79
+ "encoder.block.15": "disk",
80
+ "encoder.block.16": "disk",
81
+ "encoder.block.17": "disk",
82
+ "encoder.block.18": "disk",
83
+ "encoder.block.19": "disk",
84
+ "encoder.block.20": "disk",
85
+ "encoder.block.21": "disk",
86
+ "encoder.block.22": "disk",
87
+ "encoder.block.23": "disk",
88
+ "encoder.final_layer_norm": "disk",
89
+ "encoder.dropout": "disk",
90
+ }
91
+ else:
92
+ t5_model_kwargs["device_map"] = {
93
+ "shared": self.device,
94
+ "encoder": self.device,
95
+ }
96
+
97
+ self.use_text_preprocessing = use_text_preprocessing
98
+ self.hf_token = hf_token
99
+
100
+ self.tokenizer = AutoTokenizer.from_pretrained(
101
+ from_pretrained,
102
+ cache_dir=cache_dir,
103
+ local_files_only=local_files_only,
104
+ )
105
+ self.model = T5EncoderModel.from_pretrained(
106
+ from_pretrained,
107
+ cache_dir=cache_dir,
108
+ local_files_only=local_files_only,
109
+ **t5_model_kwargs,
110
+ ).eval()
111
+ self.model_max_length = model_max_length
112
+
113
+ def get_text_embeddings(self, texts):
114
+ text_tokens_and_mask = self.tokenizer(
115
+ texts,
116
+ max_length=self.model_max_length,
117
+ padding="max_length",
118
+ truncation=True,
119
+ return_attention_mask=True,
120
+ add_special_tokens=True,
121
+ return_tensors="pt",
122
+ )
123
+
124
+ input_ids = text_tokens_and_mask["input_ids"].to(self.device)
125
+ attention_mask = text_tokens_and_mask["attention_mask"].to(self.device)
126
+ with torch.no_grad():
127
+ text_encoder_embs = self.model(
128
+ input_ids=input_ids,
129
+ attention_mask=attention_mask,
130
+ )["last_hidden_state"].detach()
131
+ return text_encoder_embs, attention_mask
132
+
133
+
134
+ # @MODELS.register_module("t5")
135
+ class T5Encoder:
136
+ def __init__(
137
+ self,
138
+ from_pretrained=None,
139
+ model_max_length=120,
140
+ device="cuda",
141
+ dtype=torch.float,
142
+ cache_dir=None,
143
+ shardformer=False,
144
+ local_files_only=False,
145
+ ):
146
+ assert from_pretrained is not None, "Please specify the path to the T5 model"
147
+
148
+ self.t5 = T5Embedder(
149
+ device=device,
150
+ torch_dtype=dtype,
151
+ from_pretrained=from_pretrained,
152
+ cache_dir=cache_dir,
153
+ model_max_length=model_max_length,
154
+ local_files_only=local_files_only,
155
+ )
156
+ self.t5.model.to(dtype=dtype)
157
+ self.y_embedder = None
158
+
159
+ self.model_max_length = model_max_length
160
+ self.output_dim = self.t5.model.config.d_model
161
+ self.dtype = dtype
162
+
163
+ if shardformer:
164
+ self.shardformer_t5()
165
+
166
+ def shardformer_t5(self):
167
+ from colossalai.shardformer import ShardConfig, ShardFormer
168
+
169
+ from opensora.acceleration.shardformer.policy.t5_encoder import T5EncoderPolicy
170
+ from opensora.utils.misc import requires_grad
171
+
172
+ shard_config = ShardConfig(
173
+ tensor_parallel_process_group=None,
174
+ pipeline_stage_manager=None,
175
+ enable_tensor_parallelism=False,
176
+ enable_fused_normalization=False,
177
+ enable_flash_attention=False,
178
+ enable_jit_fused=True,
179
+ enable_sequence_parallelism=False,
180
+ enable_sequence_overlap=False,
181
+ )
182
+ shard_former = ShardFormer(shard_config=shard_config)
183
+ optim_model, _ = shard_former.optimize(self.t5.model, policy=T5EncoderPolicy())
184
+ self.t5.model = optim_model.to(self.dtype)
185
+
186
+ # ensure the weights are frozen
187
+ requires_grad(self.t5.model, False)
188
+
189
+ def encode(self, text):
190
+ caption_embs, emb_masks = self.t5.get_text_embeddings(text)
191
+ caption_embs = caption_embs[:, None]
192
+ return dict(y=caption_embs, mask=emb_masks)
193
+
194
+ def null(self, n):
195
+ null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None]
196
+ return null_y
197
+
198
+
199
+ def basic_clean(text):
200
+ text = ftfy.fix_text(text)
201
+ text = html.unescape(html.unescape(text))
202
+ return text.strip()
203
+
204
+
205
+ BAD_PUNCT_REGEX = re.compile(
206
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
207
+ ) # noqa
208
+
209
+
210
+ def clean_caption(caption):
211
+ import urllib.parse as ul
212
+
213
+ from bs4 import BeautifulSoup
214
+
215
+ caption = str(caption)
216
+ caption = ul.unquote_plus(caption)
217
+ caption = caption.strip().lower()
218
+ caption = re.sub("<person>", "person", caption)
219
+ # urls:
220
+ caption = re.sub(
221
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
222
+ "",
223
+ caption,
224
+ ) # regex for urls
225
+ caption = re.sub(
226
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
227
+ "",
228
+ caption,
229
+ ) # regex for urls
230
+ # html:
231
+ caption = BeautifulSoup(caption, features="html.parser").text
232
+
233
+ # @<nickname>
234
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
235
+
236
+ # 31C0—31EF CJK Strokes
237
+ # 31F0—31FF Katakana Phonetic Extensions
238
+ # 3200—32FF Enclosed CJK Letters and Months
239
+ # 3300—33FF CJK Compatibility
240
+ # 3400—4DBF CJK Unified Ideographs Extension A
241
+ # 4DC0—4DFF Yijing Hexagram Symbols
242
+ # 4E00—9FFF CJK Unified Ideographs
243
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
244
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
245
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
246
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
247
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
248
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
249
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
250
+ #######################################################
251
+
252
+ # все виды тире / all types of dash --> "-"
253
+ caption = re.sub(
254
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
255
+ "-",
256
+ caption,
257
+ )
258
+
259
+ # кавычки к одному стандарту
260
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
261
+ caption = re.sub(r"[‘’]", "'", caption)
262
+
263
+ # &quot;
264
+ caption = re.sub(r"&quot;?", "", caption)
265
+ # &amp
266
+ caption = re.sub(r"&amp", "", caption)
267
+
268
+ # ip adresses:
269
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
270
+
271
+ # article ids:
272
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
273
+
274
+ # \n
275
+ caption = re.sub(r"\\n", " ", caption)
276
+
277
+ # "#123"
278
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
279
+ # "#12345.."
280
+ caption = re.sub(r"#\d{5,}\b", "", caption)
281
+ # "123456.."
282
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
283
+ # filenames:
284
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
285
+
286
+ #
287
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
288
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
289
+
290
+ caption = re.sub(BAD_PUNCT_REGEX, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
291
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
292
+
293
+ # this-is-my-cute-cat / this_is_my_cute_cat
294
+ regex2 = re.compile(r"(?:\-|\_)")
295
+ if len(re.findall(regex2, caption)) > 3:
296
+ caption = re.sub(regex2, " ", caption)
297
+
298
+ caption = basic_clean(caption)
299
+
300
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
301
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
302
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
303
+
304
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
305
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
306
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
307
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
308
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
309
+
310
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
311
+
312
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
313
+
314
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
315
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
316
+ caption = re.sub(r"\s+", " ", caption)
317
+
318
+ caption = caption.strip()
319
+
320
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
321
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
322
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
323
+ caption = re.sub(r"^\.\S+$", "", caption)
324
+
325
+ return caption.strip()
326
+
327
+
328
+ def text_preprocessing(text, use_text_preprocessing: bool = True):
329
+ if use_text_preprocessing:
330
+ # apply the exact text cleaning used at training time (run twice, following PixArt)
331
+ text = clean_caption(text)
332
+ text = clean_caption(text)
333
+ return text
334
+ else:
335
+ return text.lower().strip()
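A usage sketch for the module above (assumptions: `model_dir` is a hypothetical local directory holding T5 encoder weights such as `t5-v1_1-xxl`, and a CUDA device is available). `text_preprocessing` cleans the prompt and `T5Encoder.encode` returns per-token embeddings plus the attention mask:

```python
import torch
from video_to_video.modules.t5 import T5Encoder, text_preprocessing

model_dir = './pretrained/t5-v1_1-xxl'   # hypothetical local checkpoint directory
text_encoder = T5Encoder(from_pretrained=model_dir, model_max_length=120,
                         device='cuda', dtype=torch.float16)

prompt = text_preprocessing('A street scene in heavy rain, captured by a surveillance camera.')
out = text_encoder.encode([prompt])
print(out['y'].shape, out['mask'].shape)  # e.g. [1, 1, 120, 4096] and [1, 120] for t5-v1_1-xxl
```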
video_to_video/modules/unet_v2v.py ADDED
@@ -0,0 +1,2332 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ import math
4
+ import os
5
+ from abc import abstractmethod
+ from typing import Any, Optional
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import xformers
11
+ import xformers.ops
12
+ from einops import rearrange, repeat
13
+ from fairscale.nn.checkpoint import checkpoint_wrapper
14
+ from timm.models.vision_transformer import Mlp
15
+
16
+
17
+ USE_TEMPORAL_TRANSFORMER = True
18
+
19
+
20
+ class CaptionEmbedder(nn.Module):
21
+ """
22
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
23
+ """
24
+
25
+ def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate="tanh"), token_num=120):
26
+ super().__init__()
27
+ self.y_proj = Mlp(
28
+ in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0
29
+ )
30
+ self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels**0.5))
31
+ self.uncond_prob = uncond_prob
32
+
33
+ def token_drop(self, caption, force_drop_ids=None):
34
+ """
35
+ Drops labels to enable classifier-free guidance.
36
+ """
37
+ if force_drop_ids is None:
38
+ drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
39
+ else:
40
+ drop_ids = force_drop_ids == 1
41
+ caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
42
+ return caption
43
+
44
+ def forward(self, caption, train, force_drop_ids=None):
45
+ if train:
46
+ assert caption.shape[2:] == self.y_embedding.shape
47
+ use_dropout = self.uncond_prob > 0
48
+ if (train and use_dropout) or (force_drop_ids is not None):
49
+ caption = self.token_drop(caption, force_drop_ids)
50
+ caption = self.y_proj(caption)
51
+ return caption
52
+
53
+
54
+ class DropPath(nn.Module):
55
+ r"""DropPath but without rescaling and supports optional all-zero and/or all-keep.
56
+ """
57
+
58
+ def __init__(self, p):
59
+ super(DropPath, self).__init__()
60
+ self.p = p
61
+
62
+ def forward(self, *args, zero=None, keep=None):
63
+ if not self.training:
64
+ return args[0] if len(args) == 1 else args
65
+
66
+ # params
67
+ x = args[0]
68
+ b = x.size(0)
69
+ n = (torch.rand(b) < self.p).sum()
70
+
71
+ # non-zero and non-keep mask
72
+ mask = x.new_ones(b, dtype=torch.bool)
73
+ if keep is not None:
74
+ mask[keep] = False
75
+ if zero is not None:
76
+ mask[zero] = False
77
+
78
+ # drop-path index
79
+ index = torch.where(mask)[0]
80
+ index = index[torch.randperm(len(index))[:n]]
81
+ if zero is not None:
82
+ index = torch.cat([index, torch.where(zero)[0]], dim=0)
83
+
84
+ # drop-path multiplier
85
+ multiplier = x.new_ones(b)
86
+ multiplier[index] = 0.0
87
+ output = tuple(u * self.broadcast(multiplier, u) for u in args)
88
+ return output[0] if len(args) == 1 else output
89
+
90
+ def broadcast(self, src, dst):
91
+ assert src.size(0) == dst.size(0)
92
+ shape = (dst.size(0), ) + (1, ) * (dst.ndim - 1)
93
+ return src.view(shape)
94
+
95
+
96
+ def sinusoidal_embedding(timesteps, dim):
97
+ # check input
98
+ half = dim // 2
99
+ timesteps = timesteps.float()
100
+
101
+ # compute sinusoidal embedding
102
+ sinusoid = torch.outer(
103
+ timesteps, torch.pow(10000,
104
+ -torch.arange(half).to(timesteps).div(half)))
105
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
106
+ if dim % 2 != 0:
107
+ x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
108
+ return x
109
+
110
+
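A quick sanity check on the helper above (a sketch, not part of the file): each timestep maps to a `dim`-dimensional vector whose first half holds cosines and second half sines of geometrically spaced frequencies.

```python
import torch

timesteps = torch.tensor([0, 10, 500, 999])
emb = sinusoidal_embedding(timesteps, dim=320)
print(emb.shape)                                                 # torch.Size([4, 320])
print(torch.allclose(emb[:, 0], torch.cos(timesteps.float())))   # True: column 0 is cos(t)
```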
111
+ def exists(x):
112
+ return x is not None
113
+
114
+
115
+ def default(val, d):
116
+ if exists(val):
117
+ return val
118
+ return d() if callable(d) else d
119
+
120
+
121
+ def prob_mask_like(shape, prob, device):
122
+ if prob == 1:
123
+ return torch.ones(shape, device=device, dtype=torch.bool)
124
+ elif prob == 0:
125
+ return torch.zeros(shape, device=device, dtype=torch.bool)
126
+ else:
127
+ mask = torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
128
+ # avoid masking every sample, which would trigger a find_unused_parameters error
129
+ if mask.all():
130
+ mask[0] = False
131
+ return mask
132
+
133
+
134
+ class MemoryEfficientCrossAttention(nn.Module):
135
+
136
+ def __init__(self,
137
+ query_dim,
138
+ context_dim=None,
139
+ heads=8,
140
+ dim_head=64,
141
+ max_bs=16384,
142
+ dropout=0.0):
143
+ super().__init__()
144
+ inner_dim = dim_head * heads
145
+ context_dim = default(context_dim, query_dim)
146
+
147
+ self.max_bs = max_bs
148
+ self.heads = heads
149
+ self.dim_head = dim_head
150
+
151
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
152
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
153
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
154
+ self.to_out = nn.Sequential(
155
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
156
+ self.attention_op: Optional[Any] = None
157
+
158
+ def forward(self, x, context=None, mask=None):
159
+ q = self.to_q(x)
160
+ context = default(context, x)
161
+ k = self.to_k(context)
162
+ v = self.to_v(context)
163
+
164
+ b, _, _ = q.shape
165
+ q, k, v = map(
166
+ lambda t: t.unsqueeze(3).reshape(b, t.shape[
167
+ 1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
168
+ b * self.heads, t.shape[1], self.dim_head).contiguous(),
169
+ (q, k, v),
170
+ )
171
+
172
+ # actually compute the attention, what we cannot get enough of.
173
+ if q.shape[0] > self.max_bs:
174
+ q_list = torch.chunk(q, q.shape[0] // self.max_bs, dim=0)
175
+ k_list = torch.chunk(k, k.shape[0] // self.max_bs, dim=0)
176
+ v_list = torch.chunk(v, v.shape[0] // self.max_bs, dim=0)
177
+ out_list = []
178
+ for q_1, k_1, v_1 in zip(q_list, k_list, v_list):
179
+ out = xformers.ops.memory_efficient_attention(
180
+ q_1, k_1, v_1, attn_bias=None, op=self.attention_op)
181
+ out_list.append(out)
182
+ out = torch.cat(out_list, dim=0)
183
+ else:
184
+ out = xformers.ops.memory_efficient_attention(
185
+ q, k, v, attn_bias=None, op=self.attention_op)
186
+
187
+ if exists(mask):
188
+ raise NotImplementedError
189
+ out = (
190
+ out.unsqueeze(0).reshape(
191
+ b, self.heads, out.shape[1],
192
+ self.dim_head).permute(0, 2, 1,
193
+ 3).reshape(b, out.shape[1],
194
+ self.heads * self.dim_head))
195
+ return self.to_out(out)
196
+
197
+
198
+ class RelativePositionBias(nn.Module):
199
+
200
+ def __init__(self, heads=8, num_buckets=32, max_distance=128):
201
+ super().__init__()
202
+ self.num_buckets = num_buckets
203
+ self.max_distance = max_distance
204
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
205
+
206
+ @staticmethod
207
+ def _relative_position_bucket(relative_position,
208
+ num_buckets=32,
209
+ max_distance=128):
210
+ ret = 0
211
+ n = -relative_position
212
+
213
+ num_buckets //= 2
214
+ ret += (n < 0).long() * num_buckets
215
+ n = torch.abs(n)
216
+
217
+ max_exact = num_buckets // 2
218
+ is_small = n < max_exact
219
+
220
+ val_if_large = max_exact + (
221
+ torch.log(n.float() / max_exact)
222
+ / math.log(max_distance / max_exact) * # noqa
223
+ (num_buckets - max_exact)).long()
224
+ val_if_large = torch.min(
225
+ val_if_large, torch.full_like(val_if_large, num_buckets - 1))
226
+
227
+ ret += torch.where(is_small, n, val_if_large)
228
+ return ret
229
+
230
+ def forward(self, n, device):
231
+ q_pos = torch.arange(n, dtype=torch.long, device=device)
232
+ k_pos = torch.arange(n, dtype=torch.long, device=device)
233
+ rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
234
+ rp_bucket = self._relative_position_bucket(
235
+ rel_pos,
236
+ num_buckets=self.num_buckets,
237
+ max_distance=self.max_distance)
238
+ values = self.relative_attention_bias(rp_bucket)
239
+ return rearrange(values, 'i j h -> h i j')
240
+
241
+
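A brief sketch (not part of the file) of how the relative position bias above is consumed: it yields one `[frames x frames]` bias map per head, which `TemporalAttentionBlock` later adds to its attention logits.

```python
import torch

rel_bias = RelativePositionBias(heads=8, num_buckets=32, max_distance=128)
pos_bias = rel_bias(16, torch.device('cpu'))   # bias for a 16-frame clip
print(pos_bias.shape)                          # torch.Size([8, 16, 16])
```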
242
+ class SpatialTransformer(nn.Module):
243
+ """
244
+ Transformer block for image-like data.
245
+ First, project the input (aka embedding)
246
+ and reshape to b, t, d.
247
+ Then apply standard transformer action.
248
+ Finally, reshape to image
249
+ NEW: use_linear for more efficiency instead of the 1x1 convs
250
+ """
251
+
252
+ def __init__(self,
253
+ in_channels,
254
+ n_heads,
255
+ d_head,
256
+ depth=1,
257
+ dropout=0.,
258
+ context_dim=None,
259
+ disable_self_attn=False,
260
+ use_linear=False,
261
+ use_checkpoint=True,
262
+ is_ctrl=False):
263
+ super().__init__()
264
+ if exists(context_dim) and not isinstance(context_dim, list):
265
+ context_dim = [context_dim]
266
+ self.in_channels = in_channels
267
+ inner_dim = n_heads * d_head
268
+ self.norm = torch.nn.GroupNorm(
269
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
270
+ if not use_linear:
271
+ self.proj_in = nn.Conv2d(
272
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
273
+ else:
274
+ self.proj_in = nn.Linear(in_channels, inner_dim)
275
+
276
+ self.transformer_blocks = nn.ModuleList([
277
+ BasicTransformerBlock(
278
+ inner_dim,
279
+ n_heads,
280
+ d_head,
281
+ dropout=dropout,
282
+ context_dim=context_dim[d],
283
+ disable_self_attn=disable_self_attn,
284
+ checkpoint=use_checkpoint,
285
+ local_type='space',
286
+ is_ctrl=is_ctrl) for d in range(depth)
287
+ ])
288
+ if not use_linear:
289
+ self.proj_out = zero_module(
290
+ nn.Conv2d(
291
+ inner_dim, in_channels, kernel_size=1, stride=1,
292
+ padding=0))
293
+ else:
294
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
295
+ self.use_linear = use_linear
296
+
297
+ def forward(self, x, context=None):
298
+ # note: if no context is given, cross-attention defaults to self-attention
299
+ if not isinstance(context, list):
300
+ context = [context]
301
+ _, _, h, w = x.shape
302
+ # print('x shape:', x.shape) # [64, 320, 90, 160]
303
+ x_in = x
304
+ x = self.norm(x)
305
+ if not self.use_linear:
306
+ x = self.proj_in(x)
307
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
308
+ if self.use_linear:
309
+ x = self.proj_in(x)
310
+ for i, block in enumerate(self.transformer_blocks):
311
+ x = block(x, context=context[i], h=h, w=w)
312
+ if self.use_linear:
313
+ x = self.proj_out(x)
314
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
315
+ if not self.use_linear:
316
+ x = self.proj_out(x)
317
+ return x + x_in
318
+
319
+
320
+ _ATTN_PRECISION = os.environ.get('ATTN_PRECISION', 'fp32')
321
+
322
+
323
+ class CrossAttention(nn.Module):
324
+
325
+ def __init__(self,
326
+ query_dim,
327
+ context_dim=None,
328
+ heads=8,
329
+ dim_head=64,
330
+ dropout=0.):
331
+ super().__init__()
332
+ inner_dim = dim_head * heads
333
+ context_dim = default(context_dim, query_dim)
334
+
335
+ self.scale = dim_head**-0.5
336
+ self.heads = heads
337
+
338
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
339
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
340
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
341
+
342
+ self.to_out = nn.Sequential(
343
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
344
+
345
+ def forward(self, x, context=None, mask=None):
346
+ h = self.heads
347
+
348
+ q = self.to_q(x)
349
+ context = default(context, x)
350
+ k = self.to_k(context)
351
+ v = self.to_v(context)
352
+
353
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
354
+ (q, k, v))
355
+
356
+ # force cast to fp32 to avoid overflowing
357
+ if _ATTN_PRECISION == 'fp32':
358
+ with torch.autocast(enabled=False, device_type='cuda'):
359
+ q, k = q.float(), k.float()
360
+ sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
361
+ else:
362
+ sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
363
+
364
+ del q, k
365
+
366
+ if exists(mask):
367
+ mask = rearrange(mask, 'b ... -> b (...)')
368
+ max_neg_value = -torch.finfo(sim.dtype).max
369
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
370
+ sim.masked_fill_(~mask, max_neg_value)
371
+
372
+ # attention, what we cannot get enough of
373
+ sim = sim.softmax(dim=-1)
374
+
375
+ out = torch.einsum('b i j, b j d -> b i d', sim, v)
376
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
377
+ return self.to_out(out)
378
+
379
+
380
+
381
+
382
+ class SpatialAttention(nn.Module):
383
+ def __init__(self):
384
+ super(SpatialAttention, self).__init__()
385
+ self.conv1 = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, padding=7 // 2, bias=False)
386
+ self.sigmoid = nn.Sigmoid()
387
+ def forward(self, x):
388
+
389
+ max_out, _ = torch.max(x, dim=1, keepdim=True)
390
+ avg_out = torch.mean(x, dim=1, keepdim=True)
391
+
392
+ weight = torch.cat([max_out, avg_out], dim=1)
393
+ weight = self.conv1(weight)
394
+
395
+ out = self.sigmoid(weight) * x
396
+ return out
397
+
398
+ class TemporalLocalAttention(nn.Module): # b c t h w
399
+ def __init__(self, dim, kernel_size=7):
400
+ super(TemporalLocalAttention, self).__init__()
401
+ self.conv1 = nn.Linear(in_features=2, out_features=1, bias=False)
402
+ self.sigmoid = nn.Sigmoid()
403
+
404
+ def forward(self, x):
405
+
406
+ max_out, _ = torch.max(x, dim=-1, keepdim=True)
407
+ avg_out = torch.mean(x, dim=-1, keepdim=True)
408
+
409
+ weight = torch.cat([max_out, avg_out], dim=-1)
410
+ weight = self.conv1(weight)
411
+
412
+ out = self.sigmoid(weight) * x
413
+ return out
414
+
415
+
416
+ class BasicTransformerBlock(nn.Module):
417
+
418
+ def __init__(self,
419
+ dim,
420
+ n_heads,
421
+ d_head,
422
+ dropout=0.,
423
+ context_dim=None,
424
+ gated_ff=True,
425
+ checkpoint=True,
426
+ disable_self_attn=False,
427
+ local_type=None,
428
+ is_ctrl=False):
429
+ super().__init__()
430
+ self.local_type = local_type
431
+ self.is_ctrl = is_ctrl
432
+ attn_cls = MemoryEfficientCrossAttention
433
+ self.disable_self_attn = disable_self_attn
434
+ self.attn1 = attn_cls( # self-attn
435
+ query_dim=dim,
436
+ heads=n_heads,
437
+ dim_head=d_head,
438
+ dropout=dropout,
439
+ context_dim=context_dim if self.disable_self_attn else None)
440
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
441
+
442
+ attn_cls2 = MemoryEfficientCrossAttention
443
+
444
+ self.attn2 = attn_cls2(
445
+ query_dim=dim,
446
+ context_dim=context_dim,
447
+ heads=n_heads,
448
+ dim_head=d_head,
449
+ dropout=dropout)
450
+ self.norm1 = nn.LayerNorm(dim)
451
+ self.norm2 = nn.LayerNorm(dim)
452
+ self.norm3 = nn.LayerNorm(dim)
453
+ self.checkpoint = checkpoint
454
+
455
+ if self.local_type == 'space' and self.is_ctrl:
456
+ self.local1 = SpatialAttention()
457
+
458
+ if self.local_type == 'temp' and self.is_ctrl:
459
+ self.local1 = TemporalLocalAttention(dim=dim)
460
+ self.local2 = TemporalLocalAttention(dim=dim)
461
+
462
+ def forward_(self, x, context=None):
463
+ return checkpoint(self._forward, (x, context), self.parameters(),
464
+ self.checkpoint)
465
+
466
+ def forward(self, x, context=None, h=None, w=None):
467
+
468
+ if self.local_type == 'space' and self.is_ctrl: # [b*t,(hw), c]
469
+
470
+ x_local = rearrange(x, 'b (h w) c -> b c h w', h=h)
471
+ x_local = self.local1(x_local)
472
+ x_local = rearrange(x_local, 'b c h w -> b (h w) c')
473
+
474
+ x = self.attn1(
475
+ self.norm1(x_local),
476
+ context=context if self.disable_self_attn else None) + x
477
+
478
+ x = self.attn2(self.norm2(x), context=context) + x # cross attention or self-attention
479
+ x = self.ff(self.norm3(x)) + x
480
+
481
+ if self.local_type == 'temp' and self.is_ctrl:
482
+
483
+ # x_local = rearrange(x, '(b h w) t c -> b c t h w', h=h, w=w)
484
+ x_local = self.local1(x)
485
+
486
+ x = self.attn1(
487
+ self.norm1(x_local),
488
+ context=context if self.disable_self_attn else None) + x
489
+
490
+ # x_local = rearrange(x, '(b h w) t c -> b c t h w', h=h, w=w)
491
+ x_local = self.local2(x)
492
+
493
+ x = self.attn2(self.norm2(x_local), context=context) + x
494
+ x = self.ff(self.norm3(x)) + x
495
+
496
+ # elif self.local_type == 'space' and self.is_ctrl:
497
+ # # print('*** use original attention ***')
498
+ # x = self.attn1(
499
+ # self.norm1(x),
500
+ # context=context if self.disable_self_attn else None) + x # self-attention
501
+
502
+ # x = self.attn2(self.norm2(x), context=context) + x # cross attention or self-attention
503
+ # x = self.ff(self.norm3(x)) + x
504
+
505
+ return x
506
+
507
+
508
+ # feedforward
509
+ class GEGLU(nn.Module):
510
+
511
+ def __init__(self, dim_in, dim_out):
512
+ super().__init__()
513
+ self.proj = nn.Linear(dim_in, dim_out * 2)
514
+
515
+ def forward(self, x):
516
+ x, gate = self.proj(x).chunk(2, dim=-1)
517
+ return x * F.gelu(gate)
518
+
519
+
520
+ def zero_module(module):
521
+ """
522
+ Zero out the parameters of a module and return it.
523
+ """
524
+ for p in module.parameters():
525
+ p.detach().zero_()
526
+ return module
527
+
528
+
529
+ class FeedForward(nn.Module):
530
+
531
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
532
+ super().__init__()
533
+ inner_dim = int(dim * mult)
534
+ dim_out = default(dim_out, dim)
535
+ project_in = nn.Sequential(nn.Linear(
536
+ dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
537
+
538
+ self.net = nn.Sequential(project_in, nn.Dropout(dropout),
539
+ nn.Linear(inner_dim, dim_out))
540
+
541
+ def forward(self, x):
542
+ return self.net(x)
543
+
544
+
545
+ class Upsample(nn.Module):
546
+ """
547
+ An upsampling layer with an optional convolution.
548
+ :param channels: channels in the inputs and outputs.
549
+ :param use_conv: a bool determining if a convolution is applied.
550
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
551
+ upsampling occurs in the inner-two dimensions.
552
+ """
553
+
554
+ def __init__(self,
555
+ channels,
556
+ use_conv,
557
+ dims=2,
558
+ out_channels=None,
559
+ padding=1):
560
+ super().__init__()
561
+ self.channels = channels
562
+ self.out_channels = out_channels or channels
563
+ self.use_conv = use_conv
564
+ self.dims = dims
565
+ if use_conv:
566
+ self.conv = nn.Conv2d(
567
+ self.channels, self.out_channels, 3, padding=padding)
568
+
569
+ def forward(self, x):
570
+ assert x.shape[1] == self.channels
571
+ if self.dims == 3:
572
+ x = F.interpolate(
573
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
574
+ mode='nearest')
575
+ else:
576
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
577
+ x = x[..., 1:-1, :]
578
+ if self.use_conv:
579
+ x = self.conv(x)
580
+ return x
581
+
582
+
583
+ class ResBlock(nn.Module):
584
+ """
585
+ A residual block that can optionally change the number of channels.
586
+ :param channels: the number of input channels.
587
+ :param emb_channels: the number of timestep embedding channels.
588
+ :param dropout: the rate of dropout.
589
+ :param out_channels: if specified, the number of out channels.
590
+ :param use_conv: if True and out_channels is specified, use a spatial
591
+ convolution instead of a smaller 1x1 convolution to change the
592
+ channels in the skip connection.
593
+ :param dims: determines if the signal is 1D, 2D, or 3D.
594
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
595
+ :param up: if True, use this block for upsampling.
596
+ :param down: if True, use this block for downsampling.
597
+ """
598
+
599
+ def __init__(
600
+ self,
601
+ channels,
602
+ emb_channels,
603
+ dropout,
604
+ out_channels=None,
605
+ use_conv=False,
606
+ use_scale_shift_norm=False,
607
+ dims=2,
608
+ up=False,
609
+ down=False,
610
+ use_temporal_conv=True,
611
+ use_image_dataset=False,
612
+ ):
613
+ super().__init__()
614
+ self.channels = channels
615
+ self.emb_channels = emb_channels
616
+ self.dropout = dropout
617
+ self.out_channels = out_channels or channels
618
+ self.use_conv = use_conv
619
+ self.use_scale_shift_norm = use_scale_shift_norm
620
+ self.use_temporal_conv = use_temporal_conv
621
+
622
+ self.in_layers = nn.Sequential(
623
+ nn.GroupNorm(32, channels),
624
+ nn.SiLU(),
625
+ nn.Conv2d(channels, self.out_channels, 3, padding=1),
626
+ )
627
+
628
+ self.updown = up or down
629
+
630
+ if up:
631
+ self.h_upd = Upsample(channels, False, dims)
632
+ self.x_upd = Upsample(channels, False, dims)
633
+ elif down:
634
+ self.h_upd = Downsample(channels, False, dims)
635
+ self.x_upd = Downsample(channels, False, dims)
636
+ else:
637
+ self.h_upd = self.x_upd = nn.Identity()
638
+
639
+ self.emb_layers = nn.Sequential(
640
+ nn.SiLU(),
641
+ nn.Linear(
642
+ emb_channels,
643
+ 2 * self.out_channels
644
+ if use_scale_shift_norm else self.out_channels,
645
+ ),
646
+ )
647
+ self.out_layers = nn.Sequential(
648
+ nn.GroupNorm(32, self.out_channels),
649
+ nn.SiLU(),
650
+ nn.Dropout(p=dropout),
651
+ zero_module(
652
+ nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
653
+ )
654
+
655
+ if self.out_channels == channels:
656
+ self.skip_connection = nn.Identity()
657
+ elif use_conv:
658
+ self.skip_connection = conv_nd(
659
+ dims, channels, self.out_channels, 3, padding=1)
660
+ else:
661
+ self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)
662
+
663
+ if self.use_temporal_conv:
664
+ self.temopral_conv = TemporalConvBlock_v2(
665
+ self.out_channels,
666
+ self.out_channels,
667
+ dropout=0.1,
668
+ use_image_dataset=use_image_dataset)
669
+
670
+ def forward(self, x, emb, batch_size, variant_info=None):
671
+ """
672
+ Apply the block to a Tensor, conditioned on a timestep embedding.
673
+ :param x: an [N x C x ...] Tensor of features.
674
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
675
+ :return: an [N x C x ...] Tensor of outputs.
676
+ """
677
+ return self._forward(x, emb, batch_size, variant_info)
678
+
679
+ def _forward(self, x, emb, batch_size, variant_info):
680
+ if self.updown:
681
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
682
+ h = in_rest(x)
683
+ h = self.h_upd(h)
684
+ x = self.x_upd(x)
685
+ h = in_conv(h)
686
+ else:
687
+ h = self.in_layers(x)
688
+ emb_out = self.emb_layers(emb).type(h.dtype)
689
+ while len(emb_out.shape) < len(h.shape):
690
+ emb_out = emb_out[..., None]
691
+ if self.use_scale_shift_norm:
692
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
693
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
694
+ h = out_norm(h) * (1 + scale) + shift
695
+ h = out_rest(h)
696
+ else:
697
+ h = h + emb_out
698
+ h = self.out_layers(h)
699
+ h = self.skip_connection(x) + h
700
+
701
+ if self.use_temporal_conv:
702
+ h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size)
703
+ h = self.temopral_conv(h, variant_info=variant_info)
704
+ h = rearrange(h, 'b c f h w -> (b f) c h w')
705
+ return h
706
+
707
+
708
+ class Downsample(nn.Module):
709
+ """
710
+ A downsampling layer with an optional convolution.
711
+ :param channels: channels in the inputs and outputs.
712
+ :param use_conv: a bool determining if a convolution is applied.
713
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
714
+ downsampling occurs in the inner-two dimensions.
715
+ """
716
+
717
+ def __init__(self,
718
+ channels,
719
+ use_conv,
720
+ dims=2,
721
+ out_channels=None,
722
+ padding=(2, 1)):
723
+ super().__init__()
724
+ self.channels = channels
725
+ self.out_channels = out_channels or channels
726
+ self.use_conv = use_conv
727
+ self.dims = dims
728
+ stride = 2 if dims != 3 else (1, 2, 2)
729
+ if use_conv:
730
+ self.op = nn.Conv2d(
731
+ self.channels,
732
+ self.out_channels,
733
+ 3,
734
+ stride=stride,
735
+ padding=padding)
736
+ else:
737
+ assert self.channels == self.out_channels
738
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
739
+
740
+ def forward(self, x):
741
+ assert x.shape[1] == self.channels
742
+ return self.op(x)
743
+
744
+
745
+ class Resample(nn.Module):
746
+
747
+ def __init__(self, in_dim, out_dim, mode):
748
+ assert mode in ['none', 'upsample', 'downsample']
749
+ super(Resample, self).__init__()
750
+ self.in_dim = in_dim
751
+ self.out_dim = out_dim
752
+ self.mode = mode
753
+
754
+ def forward(self, x, reference=None):
755
+ if self.mode == 'upsample':
756
+ assert reference is not None
757
+ x = F.interpolate(x, size=reference.shape[-2:], mode='nearest')
758
+ elif self.mode == 'downsample':
759
+ x = F.adaptive_avg_pool2d(
760
+ x, output_size=tuple(u // 2 for u in x.shape[-2:]))
761
+ return x
762
+
763
+
764
+ class ResidualBlock(nn.Module):
765
+
766
+ def __init__(self,
767
+ in_dim,
768
+ embed_dim,
769
+ out_dim,
770
+ use_scale_shift_norm=True,
771
+ mode='none',
772
+ dropout=0.0):
773
+ super(ResidualBlock, self).__init__()
774
+ self.in_dim = in_dim
775
+ self.embed_dim = embed_dim
776
+ self.out_dim = out_dim
777
+ self.use_scale_shift_norm = use_scale_shift_norm
778
+ self.mode = mode
779
+
780
+ # layers
781
+ self.layer1 = nn.Sequential(
782
+ nn.GroupNorm(32, in_dim), nn.SiLU(),
783
+ nn.Conv2d(in_dim, out_dim, 3, padding=1))
784
+ self.resample = Resample(in_dim, in_dim, mode)
785
+ self.embedding = nn.Sequential(
786
+ nn.SiLU(),
787
+ nn.Linear(embed_dim,
788
+ out_dim * 2 if use_scale_shift_norm else out_dim))
789
+ self.layer2 = nn.Sequential(
790
+ nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
791
+ nn.Conv2d(out_dim, out_dim, 3, padding=1))
792
+ self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(
793
+ in_dim, out_dim, 1)
794
+
795
+ # zero out the last layer params
796
+ nn.init.zeros_(self.layer2[-1].weight)
797
+
798
+ def forward(self, x, e, reference=None):
799
+ identity = self.resample(x, reference)
800
+ x = self.layer1[-1](self.resample(self.layer1[:-1](x), reference))
801
+ e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype)
802
+ if self.use_scale_shift_norm:
803
+ scale, shift = e.chunk(2, dim=1)
804
+ x = self.layer2[0](x) * (1 + scale) + shift
805
+ x = self.layer2[1:](x)
806
+ else:
807
+ x = x + e
808
+ x = self.layer2(x)
809
+ x = x + self.shortcut(identity)
810
+ return x
811
+
812
+
813
+ class AttentionBlock(nn.Module):
814
+
815
+ def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
816
+ # consider head_dim first, then num_heads
817
+ num_heads = dim // head_dim if head_dim else num_heads
818
+ head_dim = dim // num_heads
819
+ assert num_heads * head_dim == dim
820
+ super(AttentionBlock, self).__init__()
821
+ self.dim = dim
822
+ self.context_dim = context_dim
823
+ self.num_heads = num_heads
824
+ self.head_dim = head_dim
825
+ self.scale = math.pow(head_dim, -0.25)
826
+
827
+ # layers
828
+ self.norm = nn.GroupNorm(32, dim)
829
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
830
+ if context_dim is not None:
831
+ self.context_kv = nn.Linear(context_dim, dim * 2)
832
+ self.proj = nn.Conv2d(dim, dim, 1)
833
+
834
+ # zero out the last layer params
835
+ nn.init.zeros_(self.proj.weight)
836
+
837
+ def forward(self, x, context=None):
838
+ r"""x: [B, C, H, W].
839
+ context: [B, L, C] or None.
840
+ """
841
+ identity = x
842
+ b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim
843
+
844
+ # compute query, key, value
845
+ x = self.norm(x)
846
+ q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
847
+ if context is not None:
848
+ ck, cv = self.context_kv(context).reshape(b, -1, n * 2,
849
+ d).permute(0, 2, 3,
850
+ 1).chunk(
851
+ 2, dim=1)
852
+ k = torch.cat([ck, k], dim=-1)
853
+ v = torch.cat([cv, v], dim=-1)
854
+
855
+ # compute attention
856
+ attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
857
+ attn = F.softmax(attn, dim=-1)
858
+
859
+ # gather context
860
+ x = torch.matmul(v, attn.transpose(-1, -2))
861
+ x = x.reshape(b, c, h, w)
862
+
863
+ # output
864
+ x = self.proj(x)
865
+ return x + identity
866
+
867
+
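A minimal shape check (a sketch, not part of the file) for the self-attention block above: `dim` must be divisible by 32 for the GroupNorm, and omitting `context` runs pure self-attention over the spatial positions.

```python
import torch

block = AttentionBlock(dim=64, num_heads=4)
x = torch.randn(2, 64, 16, 16)   # [B, C, H, W]
out = block(x)
print(out.shape)                 # torch.Size([2, 64, 16, 16]): residual output, same shape
```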
868
+ class TemporalAttentionBlock(nn.Module):
869
+
870
+ def __init__(self,
871
+ dim,
872
+ heads=4,
873
+ dim_head=32,
874
+ rotary_emb=None,
875
+ use_image_dataset=False,
876
+ use_sim_mask=False):
877
+ super().__init__()
878
+ # consider num_heads first, as pos_bias needs fixed num_heads
879
+ dim_head = dim // heads
880
+ assert heads * dim_head == dim
881
+ self.use_image_dataset = use_image_dataset
882
+ self.use_sim_mask = use_sim_mask
883
+
884
+ self.scale = dim_head**-0.5
885
+ self.heads = heads
886
+ hidden_dim = dim_head * heads
887
+
888
+ self.norm = nn.GroupNorm(32, dim)
889
+ self.rotary_emb = rotary_emb
890
+ self.to_qkv = nn.Linear(dim, hidden_dim * 3)
891
+ self.to_out = nn.Linear(hidden_dim, dim)
892
+
893
+ def forward(self,
894
+ x,
895
+ pos_bias=None,
896
+ focus_present_mask=None,
897
+ video_mask=None):
898
+
899
+ identity = x
900
+ n, height, device = x.shape[2], x.shape[-2], x.device
901
+
902
+ x = self.norm(x)
903
+ x = rearrange(x, 'b c f h w -> b (h w) f c')
904
+
905
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
906
+
907
+ if exists(focus_present_mask) and focus_present_mask.all():
908
+ # if all batch samples are focusing on present
909
+ # it would be equivalent to passing that token's values (v=qkv[-1]) through to the output
910
+ values = qkv[-1]
911
+ out = self.to_out(values)
912
+ out = rearrange(out, 'b (h w) f c -> b c f h w', h=height)
913
+
914
+ return out + identity
915
+
916
+ # split out heads
917
+ q = rearrange(qkv[0], '... n (h d) -> ... h n d', h=self.heads)
918
+ k = rearrange(qkv[1], '... n (h d) -> ... h n d', h=self.heads)
919
+ v = rearrange(qkv[2], '... n (h d) -> ... h n d', h=self.heads)
920
+
921
+ # scale
922
+
923
+ q = q * self.scale
924
+
925
+ # rotate positions into queries and keys for time attention
926
+ if exists(self.rotary_emb):
927
+ q = self.rotary_emb.rotate_queries_or_keys(q)
928
+ k = self.rotary_emb.rotate_queries_or_keys(k)
929
+
930
+ # similarity
931
+ # shape [b (hw) h n n], n=f
932
+ sim = torch.einsum('... h i d, ... h j d -> ... h i j', q, k)
933
+
934
+ # relative positional bias
935
+
936
+ if exists(pos_bias):
937
+ sim = sim + pos_bias
938
+
939
+ if (focus_present_mask is None and video_mask is not None):
940
+ # video_mask: [B, n]
941
+ mask = video_mask[:, None, :] * video_mask[:, :, None]
942
+ mask = mask.unsqueeze(1).unsqueeze(1)
943
+ sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
944
+ elif exists(focus_present_mask) and not (~focus_present_mask).all():
945
+ attend_all_mask = torch.ones((n, n),
946
+ device=device,
947
+ dtype=torch.bool)
948
+ attend_self_mask = torch.eye(n, device=device, dtype=torch.bool)
949
+
950
+ mask = torch.where(
951
+ rearrange(focus_present_mask, 'b -> b 1 1 1 1'),
952
+ rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),
953
+ rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),
954
+ )
955
+
956
+ sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
957
+
958
+ if self.use_sim_mask:
959
+ sim_mask = torch.tril(
960
+ torch.ones((n, n), device=device, dtype=torch.bool),
961
+ diagonal=0)
962
+ sim = sim.masked_fill(~sim_mask, -torch.finfo(sim.dtype).max)
963
+
964
+ # numerical stability
965
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
966
+ attn = sim.softmax(dim=-1)
967
+
968
+ # aggregate values
969
+
970
+ out = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v)
971
+ out = rearrange(out, '... h n d -> ... n (h d)')
972
+ out = self.to_out(out)
973
+
974
+ out = rearrange(out, 'b (h w) f c -> b c f h w', h=height)
975
+
976
+ if self.use_image_dataset:
977
+ out = identity + 0 * out
978
+ else:
979
+ out = identity + out
980
+ return out
981
+
982
+
983
+ class TemporalTransformer(nn.Module):
984
+ """
985
+ Transformer block for image-like data.
986
+ First, project the input (aka embedding)
987
+ and reshape to b, t, d.
988
+ Then apply standard transformer action.
989
+ Finally, reshape to image
990
+ """
991
+
992
+ def __init__(self,
993
+ in_channels,
994
+ n_heads,
995
+ d_head,
996
+ depth=1,
997
+ dropout=0.,
998
+ context_dim=None,
999
+ disable_self_attn=False,
1000
+ use_linear=False,
1001
+ use_checkpoint=True,
1002
+ only_self_att=True,
1003
+ multiply_zero=False,
1004
+ is_ctrl=False):
1005
+ super().__init__()
1006
+ self.multiply_zero = multiply_zero
1007
+ self.only_self_att = only_self_att
1008
+ self.use_adaptor = False
1009
+ if self.only_self_att:
1010
+ context_dim = None
1011
+ if not isinstance(context_dim, list):
1012
+ context_dim = [context_dim]
1013
+ self.in_channels = in_channels
1014
+ inner_dim = n_heads * d_head
1015
+ self.norm = torch.nn.GroupNorm(
1016
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
1017
+ if not use_linear:
1018
+ self.proj_in = nn.Conv1d(
1019
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
1020
+ else:
1021
+ self.proj_in = nn.Linear(in_channels, inner_dim)
1022
+ if self.use_adaptor:
1023
+ self.adaptor_in = nn.Linear(frames, frames)
1024
+
1025
+ self.transformer_blocks = nn.ModuleList([
1026
+ BasicTransformerBlock(
1027
+ inner_dim,
1028
+ n_heads,
1029
+ d_head,
1030
+ dropout=dropout,
1031
+ context_dim=context_dim[d],
1032
+ checkpoint=use_checkpoint,
1033
+ local_type='temp',
1034
+ is_ctrl=is_ctrl) for d in range(depth)
1035
+ ])
1036
+ if not use_linear:
1037
+ self.proj_out = zero_module(
1038
+ nn.Conv1d(
1039
+ inner_dim, in_channels, kernel_size=1, stride=1,
1040
+ padding=0))
1041
+ else:
1042
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
1043
+ if self.use_adaptor:
1044
+ self.adaptor_out = nn.Linear(frames, frames)
1045
+ self.use_linear = use_linear
1046
+
1047
+ def forward(self, x, context=None):
1048
+ # note: if no context is given, cross-attention defaults to self-attention
1049
+ if self.only_self_att:
1050
+ context = None
1051
+ if not isinstance(context, list):
1052
+ context = [context]
1053
+ b, _, _, h, w = x.shape
1054
+ x_in = x
1055
+ x = self.norm(x)
1056
+
1057
+ if not self.use_linear:
1058
+ x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous()
1059
+ x = self.proj_in(x)
1060
+ if self.use_linear:
1061
+ x = rearrange(
1062
+ x, 'b c f h w -> (b h w) f c').contiguous()
1063
+ x = self.proj_in(x)
1064
+ x = rearrange(
1065
+ x, 'bhw f c -> bhw c f').contiguous()
1066
+
1067
+ # print('x shape:', x.shape) # [28800, 512, 32]
1068
+ if self.only_self_att: # no cross-attention
1069
+ x = rearrange(x, 'bhw c f -> bhw f c').contiguous()
1070
+ for i, block in enumerate(self.transformer_blocks):
1071
+ x = block(x, h=h, w=w)
1072
+ # print('x shape:', x.shape) # [43200, 32, 512]
1073
+ x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
1074
+ else:
1075
+ x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous()
1076
+ for i, block in enumerate(self.transformer_blocks):
1077
+ context[i] = rearrange(
1078
+ context[i], '(b f) l con -> b f l con',
1079
+ f=self.frames).contiguous()
1080
+ # process each batch element separately (some kernels cannot handle a dimension larger than 65,535)
1081
+ for j in range(b):
1082
+ context_i_j = repeat(
1083
+ context[i][j],
1084
+ 'f l con -> (f r) l con',
1085
+ r=(h * w) // self.frames,
1086
+ f=self.frames).contiguous()
1087
+ x[j] = block(x[j], context=context_i_j)
1088
+
1089
+ if self.use_linear:
1090
+ x = rearrange(x, 'b hw f c -> (b hw) f c').contiguous()
1091
+ x = self.proj_out(x)
1092
+ x = rearrange(
1093
+ x, '(b h w) f c -> b c f h w', b=b, h=h, w=w).contiguous()
1094
+ if not self.use_linear:
1095
+ # print('x shape:', x.shape) # [2, 21600, 32, 512]
1096
+ x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous()
1097
+ x = self.proj_out(x)
1098
+ x = rearrange(
1099
+ x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous()
1100
+
1101
+ if self.multiply_zero:
1102
+ x = 0.0 * x + x_in
1103
+ else:
1104
+ x = x + x_in
1105
+ return x
1106
+
1107
+
1108
+ class TemporalAttentionMultiBlock(nn.Module):
1109
+
1110
+ def __init__(
1111
+ self,
1112
+ dim,
1113
+ heads=4,
1114
+ dim_head=32,
1115
+ rotary_emb=None,
1116
+ use_image_dataset=False,
1117
+ use_sim_mask=False,
1118
+ temporal_attn_times=1,
1119
+ ):
1120
+ super().__init__()
1121
+ self.att_layers = nn.ModuleList([
1122
+ TemporalAttentionBlock(dim, heads, dim_head, rotary_emb,
1123
+ use_image_dataset, use_sim_mask)
1124
+ for _ in range(temporal_attn_times)
1125
+ ])
1126
+
1127
+ def forward(self,
1128
+ x,
1129
+ pos_bias=None,
1130
+ focus_present_mask=None,
1131
+ video_mask=None):
1132
+ for layer in self.att_layers:
1133
+ x = layer(x, pos_bias, focus_present_mask, video_mask)
1134
+ return x
1135
+
1136
+
1137
+ class InitTemporalConvBlock(nn.Module):
1138
+
1139
+ def __init__(self,
1140
+ in_dim,
1141
+ out_dim=None,
1142
+ dropout=0.0,
1143
+ use_image_dataset=False):
1144
+ super(InitTemporalConvBlock, self).__init__()
1145
+ if out_dim is None:
1146
+ out_dim = in_dim
1147
+ self.in_dim = in_dim
1148
+ self.out_dim = out_dim
1149
+ self.use_image_dataset = use_image_dataset
1150
+
1151
+ # conv layers
1152
+ self.conv = nn.Sequential(
1153
+ nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
1154
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
1155
+
1156
+ # zero out the last layer params, so the conv block starts as an identity
1157
+ nn.init.zeros_(self.conv[-1].weight)
1158
+ nn.init.zeros_(self.conv[-1].bias)
1159
+
1160
+ def forward(self, x):
1161
+ identity = x
1162
+ x = self.conv(x)
1163
+ if self.use_image_dataset:
1164
+ x = identity + 0 * x
1165
+ else:
1166
+ x = identity + x
1167
+ return x
1168
+
1169
+
1170
+ class TemporalConvBlock(nn.Module):
1171
+
1172
+ def __init__(self,
1173
+ in_dim,
1174
+ out_dim=None,
1175
+ dropout=0.0,
1176
+ use_image_dataset=False):
1177
+ super(TemporalConvBlock, self).__init__()
1178
+ if out_dim is None:
1179
+ out_dim = in_dim
1180
+ self.in_dim = in_dim
1181
+ self.out_dim = out_dim
1182
+ self.use_image_dataset = use_image_dataset
1183
+
1184
+ # conv layers
1185
+ self.conv1 = nn.Sequential(
1186
+ nn.GroupNorm(32, in_dim), nn.SiLU(),
1187
+ nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
1188
+ self.conv2 = nn.Sequential(
1189
+ nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
1190
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
1191
+
1192
+ # zero out the last layer params, so the conv block starts as an identity
1193
+ nn.init.zeros_(self.conv2[-1].weight)
1194
+ nn.init.zeros_(self.conv2[-1].bias)
1195
+
1196
+ def forward(self, x):
1197
+ identity = x
1198
+ x = self.conv1(x)
1199
+ x = self.conv2(x)
1200
+ if self.use_image_dataset:
1201
+ x = identity + 0 * x
1202
+ else:
1203
+ x = identity + x
1204
+ return x
1205
+
1206
+
1207
+ class TemporalConvBlock_v2(nn.Module):
1208
+
1209
+ def __init__(self,
1210
+ in_dim,
1211
+ out_dim=None,
1212
+ dropout=0.0,
1213
+ use_image_dataset=False):
1214
+ super(TemporalConvBlock_v2, self).__init__()
1215
+ if out_dim is None:
1216
+ out_dim = in_dim
1217
+ self.in_dim = in_dim
1218
+ self.out_dim = out_dim
1219
+ self.use_image_dataset = use_image_dataset
1220
+
1221
+ # conv layers
1222
+ self.conv1 = nn.Sequential(
1223
+ nn.GroupNorm(32, in_dim), nn.SiLU(),
1224
+ nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
1225
+ self.conv2 = nn.Sequential(
1226
+ nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
1227
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
1228
+ self.conv3 = nn.Sequential(
1229
+ nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
1230
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
1231
+ self.conv4 = nn.Sequential(
1232
+ nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
1233
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
1234
+
1235
+ # zero out the last layer params, so the conv block starts as an identity
1236
+ nn.init.zeros_(self.conv4[-1].weight)
1237
+ nn.init.zeros_(self.conv4[-1].bias)
1238
+
1239
+ def forward(self, x, variant_info=None):
1240
+ if variant_info is not None and variant_info.get('type') == 'variant2':
1241
+ # print(x.shape) # torch.Size([1, 320, 32, 90, 160])
1242
+ _, _, f, _, _ = x.shape
1243
+ assert f % 4 == 0, "f must be divisible by 4"
1244
+ x_short = rearrange(x, "b c (n s) h w -> (n b) c s h w", n=4)
1245
+ x_short = self.conv1(x_short)
1246
+ x_short = self.conv2(x_short)
1247
+ x_short = self.conv3(x_short)
1248
+ x_short = self.conv4(x_short)
1249
+ x_short = rearrange(x_short, "(n b) c s h w -> b c (n s) h w", n=4)
1250
+
1251
+ identity = x
1252
+ x = self.conv1(x)
1253
+ x = self.conv2(x)
1254
+ x = self.conv3(x)
1255
+ x = self.conv4(x)
1256
+
1257
+ x = x * (1-variant_info['alpha']) + x_short * variant_info['alpha']
1258
+
1259
+
1260
+ elif variant_info is not None and variant_info.get('type') == 'variant1':
1261
+ identity = x
1262
+ x_long, x_short = x.chunk(2, dim=0)
1263
+
1264
+ x_short = rearrange(x_short, "b c (n s) h w -> (n b) c s h w", n=4)
1265
+ x_short = self.conv1(x_short)
1266
+ x_short = self.conv2(x_short)
1267
+ x_short = self.conv3(x_short)
1268
+ x_short = self.conv4(x_short)
1269
+ x_short = rearrange(x_short, "(n b) c s h w -> b c (n s) h w", n=4)
1270
+
1271
+ x_long = self.conv1(x_long)
1272
+ x_long = self.conv2(x_long)
1273
+ x_long = self.conv3(x_long)
1274
+ x_long = self.conv4(x_long)
1275
+
1276
+ x = torch.cat([x_long, x_short], dim=0)
1277
+
1278
+
1279
+ elif variant_info is None:
1280
+ identity = x
1281
+ x = self.conv1(x)
1282
+ x = self.conv2(x)
1283
+ x = self.conv3(x)
1284
+ x = self.conv4(x)
1285
+
1286
+
1287
+ if self.use_image_dataset:
1288
+ x = identity + 0.0 * x
1289
+ else:
1290
+ x = identity + x
1291
+ return x
1292
+
1293
+
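The variant2 branch above relies on an einops regrouping: the f frames are split into 4 equal groups that are folded into the batch axis, so the temporal convolutions only mix frames within a group, and the grouping is undone afterwards. Below is a minimal, self-contained sketch of that reshuffle (an editorial illustration, not part of the committed file; the toy shapes are assumptions).

```python
import torch
from einops import rearrange

# Toy tensor (assumed shapes): batch 1, 320 channels, 32 frames, small spatial size.
x = torch.randn(1, 320, 32, 8, 8)

# Split the 32 frames into 4 independent groups of 8 along the batch axis,
# mirroring rearrange(x, "b c (n s) h w -> (n b) c s h w", n=4) in the forward pass.
x_short = rearrange(x, 'b c (n s) h w -> (n b) c s h w', n=4)
assert x_short.shape == (4, 320, 8, 8, 8)

# Undo the grouping, as the forward pass does after the conv stack.
x_back = rearrange(x_short, '(n b) c s h w -> b c (n s) h w', n=4)
assert torch.equal(x_back, x)
```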
1294
+ class Vid2VidSDUNet(nn.Module):
1295
+
1296
+ def __init__(self,
1297
+ in_dim=4,
1298
+ dim=320,
1299
+ y_dim=1024,
1300
+ context_dim=1024,
1301
+ out_dim=4,
1302
+ dim_mult=[1, 2, 4, 4],
1303
+ num_heads=8,
1304
+ head_dim=64,
1305
+ num_res_blocks=2,
1306
+ attn_scales=[1 / 1, 1 / 2, 1 / 4],
1307
+ use_scale_shift_norm=True,
1308
+ dropout=0.1,
1309
+ temporal_attn_times=1,
1310
+ temporal_attention=True,
1311
+ use_checkpoint=True,
1312
+ use_image_dataset=False,
1313
+ use_fps_condition=False,
1314
+ use_sim_mask=False,
1315
+ training=False,
1316
+ inpainting=True):
1317
+ embed_dim = dim * 4
1318
+ num_heads = num_heads if num_heads else dim // 32
1319
+ super(Vid2VidSDUNet, self).__init__()
1320
+ self.in_dim = in_dim
1321
+ self.dim = dim
1322
+ self.y_dim = y_dim
1323
+ self.context_dim = context_dim
1324
+ self.embed_dim = embed_dim
1325
+ self.out_dim = out_dim
1326
+ self.dim_mult = dim_mult
1327
+ # for temporal attention
1328
+ self.num_heads = num_heads
1329
+ # for spatial attention
1330
+ self.head_dim = head_dim
1331
+ self.num_res_blocks = num_res_blocks
1332
+ self.attn_scales = attn_scales
1333
+ self.use_scale_shift_norm = use_scale_shift_norm
1334
+ self.temporal_attn_times = temporal_attn_times
1335
+ self.temporal_attention = temporal_attention
1336
+ self.use_checkpoint = use_checkpoint
1337
+ self.use_image_dataset = use_image_dataset
1338
+ self.use_fps_condition = use_fps_condition
1339
+ self.use_sim_mask = use_sim_mask
1340
+ self.training = training
1341
+ self.inpainting = inpainting
1342
+
1343
+ use_linear_in_temporal = False
1344
+ transformer_depth = 1
1345
+ disabled_sa = False
1346
+ # params
1347
+ enc_dims = [dim * u for u in [1] + dim_mult]
1348
+ dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
1349
+ shortcut_dims = []
1350
+ scale = 1.0
1351
+
1352
+ # embeddings
1353
+ self.time_embed = nn.Sequential(
1354
+ nn.Linear(dim, embed_dim), nn.SiLU(),
1355
+ nn.Linear(embed_dim, embed_dim))
1356
+
1357
+ if self.use_fps_condition:
1358
+ self.fps_embedding = nn.Sequential(
1359
+ nn.Linear(dim, embed_dim), nn.SiLU(),
1360
+ nn.Linear(embed_dim, embed_dim))
1361
+ nn.init.zeros_(self.fps_embedding[-1].weight)
1362
+ nn.init.zeros_(self.fps_embedding[-1].bias)
1363
+
1364
+ # encoder
1365
+ self.input_blocks = nn.ModuleList()
1366
+ init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)])
1367
+ # need an initial temporal attention?
1368
+ if temporal_attention:
1369
+ if USE_TEMPORAL_TRANSFORMER:
1370
+ init_block.append(
1371
+ TemporalTransformer(
1372
+ dim,
1373
+ num_heads,
1374
+ head_dim,
1375
+ depth=transformer_depth,
1376
+ context_dim=context_dim,
1377
+ disable_self_attn=disabled_sa,
1378
+ use_linear=use_linear_in_temporal,
1379
+ multiply_zero=use_image_dataset,
1380
+ is_ctrl=True
1381
+ ))
1382
+ else:
1383
+ init_block.append(
1384
+ TemporalAttentionMultiBlock(
1385
+ dim,
1386
+ num_heads,
1387
+ head_dim,
1388
+ rotary_emb=self.rotary_emb,
1389
+ temporal_attn_times=temporal_attn_times,
1390
+ use_image_dataset=use_image_dataset))
1391
+ self.input_blocks.append(init_block)
1392
+ shortcut_dims.append(dim)
1393
+ for i, (in_dim,
1394
+ out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
1395
+ for j in range(num_res_blocks):
1396
+ block = nn.ModuleList([
1397
+ ResBlock(
1398
+ in_dim,
1399
+ embed_dim,
1400
+ dropout,
1401
+ out_channels=out_dim,
1402
+ use_scale_shift_norm=False,
1403
+ use_image_dataset=use_image_dataset,
1404
+ )
1405
+ ])
1406
+ if scale in attn_scales:
1407
+ block.append(
1408
+ SpatialTransformer(
1409
+ out_dim,
1410
+ out_dim // head_dim,
1411
+ head_dim,
1412
+ depth=1,
1413
+ context_dim=self.context_dim,
1414
+ disable_self_attn=False,
1415
+ use_linear=True,
1416
+ is_ctrl=True
1417
+ ))
1418
+ if self.temporal_attention:
1419
+ if USE_TEMPORAL_TRANSFORMER:
1420
+ block.append(
1421
+ TemporalTransformer(
1422
+ out_dim,
1423
+ out_dim // head_dim,
1424
+ head_dim,
1425
+ depth=transformer_depth,
1426
+ context_dim=context_dim,
1427
+ disable_self_attn=disabled_sa,
1428
+ use_linear=use_linear_in_temporal,
1429
+ multiply_zero=use_image_dataset,
1430
+ is_ctrl=True
1431
+ ))
1432
+ else:
1433
+ block.append(
1434
+ TemporalAttentionMultiBlock(
1435
+ out_dim,
1436
+ num_heads,
1437
+ head_dim,
1438
+ rotary_emb=self.rotary_emb,
1439
+ use_image_dataset=use_image_dataset,
1440
+ use_sim_mask=use_sim_mask,
1441
+ temporal_attn_times=temporal_attn_times))
1442
+ in_dim = out_dim
1443
+ self.input_blocks.append(block)
1444
+ shortcut_dims.append(out_dim)
1445
+
1446
+ # downsample
1447
+ if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
1448
+ downsample = Downsample(
1449
+ out_dim, True, dims=2, out_channels=out_dim)
1450
+ shortcut_dims.append(out_dim)
1451
+ scale /= 2.0
1452
+ self.input_blocks.append(downsample)
1453
+
1454
+ self.middle_block = nn.ModuleList([
1455
+ ResBlock(
1456
+ out_dim,
1457
+ embed_dim,
1458
+ dropout,
1459
+ use_scale_shift_norm=False,
1460
+ use_image_dataset=use_image_dataset,
1461
+ ),
1462
+ SpatialTransformer(
1463
+ out_dim,
1464
+ out_dim // head_dim,
1465
+ head_dim,
1466
+ depth=1,
1467
+ context_dim=self.context_dim,
1468
+ disable_self_attn=False,
1469
+ use_linear=True,
1470
+ is_ctrl=True
1471
+ )
1472
+ ])
1473
+
1474
+ if self.temporal_attention:
1475
+ if USE_TEMPORAL_TRANSFORMER:
1476
+ self.middle_block.append(
1477
+ TemporalTransformer(
1478
+ out_dim,
1479
+ out_dim // head_dim,
1480
+ head_dim,
1481
+ depth=transformer_depth,
1482
+ context_dim=context_dim,
1483
+ disable_self_attn=disabled_sa,
1484
+ use_linear=use_linear_in_temporal,
1485
+ multiply_zero=use_image_dataset,
1486
+ is_ctrl=True
1487
+
1488
+ ))
1489
+ else:
1490
+ self.middle_block.append(
1491
+ TemporalAttentionMultiBlock(
1492
+ out_dim,
1493
+ num_heads,
1494
+ head_dim,
1495
+ rotary_emb=self.rotary_emb,
1496
+ use_image_dataset=use_image_dataset,
1497
+ use_sim_mask=use_sim_mask,
1498
+ temporal_attn_times=temporal_attn_times))
1499
+
1500
+ self.middle_block.append(
1501
+ ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False))
1502
+
1503
+ # decoder
1504
+ self.output_blocks = nn.ModuleList()
1505
+ for i, (in_dim,
1506
+ out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
1507
+ for j in range(num_res_blocks + 1):
1508
+ block = nn.ModuleList([
1509
+ ResBlock(
1510
+ in_dim + shortcut_dims.pop(),
1511
+ embed_dim,
1512
+ dropout,
1513
+ out_dim,
1514
+ use_scale_shift_norm=False,
1515
+ use_image_dataset=use_image_dataset,
1516
+ )
1517
+ ])
1518
+ if scale in attn_scales:
1519
+ block.append(
1520
+ SpatialTransformer(
1521
+ out_dim,
1522
+ out_dim // head_dim,
1523
+ head_dim,
1524
+ depth=1,
1525
+ context_dim=1024,
1526
+ disable_self_attn=False,
1527
+ use_linear=True,
1528
+ is_ctrl=True))
1529
+ if self.temporal_attention:
1530
+ if USE_TEMPORAL_TRANSFORMER:
1531
+ block.append(
1532
+ TemporalTransformer(
1533
+ out_dim,
1534
+ out_dim // head_dim,
1535
+ head_dim,
1536
+ depth=transformer_depth,
1537
+ context_dim=context_dim,
1538
+ disable_self_attn=disabled_sa,
1539
+ use_linear=use_linear_in_temporal,
1540
+ multiply_zero=use_image_dataset,
1541
+ is_ctrl=True))
1542
+ else:
1543
+ block.append(
1544
+ TemporalAttentionMultiBlock(
1545
+ out_dim,
1546
+ num_heads,
1547
+ head_dim,
1548
+ rotary_emb=self.rotary_emb,
1549
+ use_image_dataset=use_image_dataset,
1550
+ use_sim_mask=use_sim_mask,
1551
+ temporal_attn_times=temporal_attn_times))
1552
+ in_dim = out_dim
1553
+
1554
+ # upsample
1555
+ if i != len(dim_mult) - 1 and j == num_res_blocks:
1556
+ upsample = Upsample(
1557
+ out_dim, True, dims=2, out_channels=out_dim)
1558
+ scale *= 2.0
1559
+ block.append(upsample)
1560
+ self.output_blocks.append(block)
1561
+
1562
+ # head
1563
+ self.out = nn.Sequential(
1564
+ nn.GroupNorm(32, out_dim), nn.SiLU(),
1565
+ nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
1566
+
1567
+ # zero out the last layer params
1568
+ nn.init.zeros_(self.out[-1].weight)
1569
+
1570
+ def forward(self,
1571
+ x,
1572
+ t,
1573
+ y,
1574
+ x_lr=None,
1575
+ fps=None,
1576
+ video_mask=None,
1577
+ focus_present_mask=None,
1578
+ prob_focus_present=0.,
1579
+ mask_last_frame_num=0):
1580
+
1581
+ batch, c, f, h, w = x.shape
1582
+ device = x.device
1583
+ self.batch = batch
1584
+
1585
+ # image and video joint training; if mask_last_frame_num is set, prob_focus_present is ignored
1586
+ if mask_last_frame_num > 0:
1587
+ focus_present_mask = None
1588
+ video_mask[-mask_last_frame_num:] = False
1589
+ else:
1590
+ focus_present_mask = default(
1591
+ focus_present_mask, lambda: prob_mask_like(
1592
+ (batch, ), prob_focus_present, device=device))
1593
+
1594
+ if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
1595
+ time_rel_pos_bias = self.time_rel_pos_bias(
1596
+ x.shape[2], device=x.device)
1597
+ else:
1598
+ time_rel_pos_bias = None
1599
+
1600
+ # embeddings
1601
+ e = self.time_embed(sinusoidal_embedding(t, self.dim))
1602
+ context = y
1603
+
1604
+ # repeat f times for spatial e and context
1605
+ e = e.repeat_interleave(repeats=f, dim=0)
1606
+ context = context.repeat_interleave(repeats=f, dim=0)
1607
+
1608
+ # always in shape (b f) c h w, except for temporal layer
1609
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
1610
+ # encoder
1611
+ xs = []
1612
+ for ind, block in enumerate(self.input_blocks):
1613
+ x = self._forward_single(block, x, e, context, time_rel_pos_bias,
1614
+ focus_present_mask, video_mask)
1615
+ xs.append(x)
1616
+
1617
+ # middle
1618
+ for block in self.middle_block:
1619
+ x = self._forward_single(block, x, e, context, time_rel_pos_bias,
1620
+ focus_present_mask, video_mask)
1621
+
1622
+ # decoder
1623
+ for block in self.output_blocks:
1624
+ x = torch.cat([x, xs.pop()], dim=1)
1625
+ x = self._forward_single(
1626
+ block,
1627
+ x,
1628
+ e,
1629
+ context,
1630
+ time_rel_pos_bias,
1631
+ focus_present_mask,
1632
+ video_mask,
1633
+ reference=xs[-1] if len(xs) > 0 else None)
1634
+
1635
+ # head
1636
+ x = self.out(x)
1637
+
1638
+ # reshape back to (b c f h w)
1639
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=batch)
1640
+ return x
1641
+
1642
+ def _forward_single(self,
1643
+ module,
1644
+ x,
1645
+ e,
1646
+ context,
1647
+ time_rel_pos_bias,
1648
+ focus_present_mask,
1649
+ video_mask,
1650
+ reference=None):
1651
+ if isinstance(module, ResidualBlock):
1652
+ module = checkpoint_wrapper(
1653
+ module) if self.use_checkpoint else module
1654
+ x = x.contiguous()
1655
+ x = module(x, e, reference)
1656
+ elif isinstance(module, ResBlock):
1657
+ module = checkpoint_wrapper(
1658
+ module) if self.use_checkpoint else module
1659
+ x = x.contiguous()
1660
+ x = module(x, e, self.batch)
1661
+ elif isinstance(module, SpatialTransformer):
1662
+ module = checkpoint_wrapper(
1663
+ module) if self.use_checkpoint else module
1664
+ x = module(x, context)
1665
+ elif isinstance(module, TemporalTransformer):
1666
+ module = checkpoint_wrapper(
1667
+ module) if self.use_checkpoint else module
1668
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
1669
+ x = module(x, context)
1670
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
1671
+ elif isinstance(module, CrossAttention):
1672
+ module = checkpoint_wrapper(
1673
+ module) if self.use_checkpoint else module
1674
+ x = module(x, context)
1675
+ elif isinstance(module, MemoryEfficientCrossAttention):
1676
+ module = checkpoint_wrapper(
1677
+ module) if self.use_checkpoint else module
1678
+ x = module(x, context)
1679
+ elif isinstance(module, BasicTransformerBlock):
1680
+ module = checkpoint_wrapper(
1681
+ module) if self.use_checkpoint else module
1682
+ x = module(x, context)
1683
+ elif isinstance(module, FeedForward):
1684
+ x = module(x, context)
1685
+ elif isinstance(module, Upsample):
1686
+ x = module(x)
1687
+ elif isinstance(module, Downsample):
1688
+ x = module(x)
1689
+ elif isinstance(module, Resample):
1690
+ x = module(x, reference)
1691
+ elif isinstance(module, TemporalAttentionBlock):
1692
+ module = checkpoint_wrapper(
1693
+ module) if self.use_checkpoint else module
1694
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
1695
+ x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
1696
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
1697
+ elif isinstance(module, TemporalAttentionMultiBlock):
1698
+ module = checkpoint_wrapper(
1699
+ module) if self.use_checkpoint else module
1700
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
1701
+ x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
1702
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
1703
+ elif isinstance(module, InitTemporalConvBlock):
1704
+ module = checkpoint_wrapper(
1705
+ module) if self.use_checkpoint else module
1706
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
1707
+ x = module(x)
1708
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
1709
+ elif isinstance(module, TemporalConvBlock):
1710
+ module = checkpoint_wrapper(
1711
+ module) if self.use_checkpoint else module
1712
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
1713
+ x = module(x)
1714
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
1715
+ elif isinstance(module, nn.ModuleList):
1716
+ for block in module:
1717
+ x = self._forward_single(block, x, e, context,
1718
+ time_rel_pos_bias, focus_present_mask,
1719
+ video_mask, reference)
1720
+ else:
1721
+ x = module(x)
1722
+ return x
1723
+
1724
+
1725
+ class ControlledV2VUNet(Vid2VidSDUNet):
1726
+ def __init__(self):
1727
+ super(ControlledV2VUNet, self).__init__()
1728
+ self.VideoControlNet = VideoControlNet()
1729
+
1730
+ def forward(self,
1731
+ x,
1732
+ t,
1733
+ y,
1734
+ hint=None,
1735
+ variant_info=None,
1736
+ hint_chunk=None,
1737
+ t_hint=None,
1738
+ s_cond=None,
1739
+ mask_cond=None,
1740
+ x_lr=None,
1741
+ fps=None,
1742
+ mask=None,
1743
+ video_mask=None,
1744
+ focus_present_mask=None,
1745
+ prob_focus_present=0.,
1746
+ mask_last_frame_num=0,
1747
+ ):
1748
+
1749
+ batch, _, f, _, _= x.shape
1750
+ device = x.device
1751
+ self.batch = batch
1752
+
1753
+ # Process text (new added for t5 encoder)
1754
+ # y = self.VideoControlNet.y_embedder(y, self.training).squeeze(1) # [1, 1, 120, 4096] -> [B, 1, 120, 1024].squeeze(1) -> [B, 120, 1024]
1755
+
1756
+ if hint_chunk is not None:
1757
+ hint = hint_chunk
1758
+
1759
+ control = self.VideoControlNet(x, t, y, hint=hint, t_hint=t_hint, \
1760
+ mask_cond=mask_cond, s_cond=s_cond, \
1761
+ variant_info=variant_info)
1762
+
1763
+ # image and video joint training; if mask_last_frame_num is set, prob_focus_present is ignored
1764
+ if mask_last_frame_num > 0:
1765
+ focus_present_mask = None
1766
+ video_mask[-mask_last_frame_num:] = False
1767
+ else:
1768
+ focus_present_mask = default(
1769
+ focus_present_mask, lambda: prob_mask_like(
1770
+ (batch, ), prob_focus_present, device=device))
1771
+
1772
+ if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
1773
+ time_rel_pos_bias = self.time_rel_pos_bias(
1774
+ x.shape[2], device=x.device)
1775
+ else:
1776
+ time_rel_pos_bias = None
1777
+
1778
+ e = self.time_embed(sinusoidal_embedding(t, self.dim))
1779
+ e = e.repeat_interleave(repeats=f, dim=0)
1780
+
1781
+ # context = y
1782
+ context = y.repeat_interleave(repeats=f, dim=0)
1783
+
1784
+ # always in shape (b f) c h w, except for temporal layer
1785
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
1786
+ # encoder
1787
+ xs = []
1788
+ for block in self.input_blocks:
1789
+ x = self._forward_single(block, x, e, context, time_rel_pos_bias,
1790
+ focus_present_mask, video_mask, variant_info=variant_info)
1791
+ xs.append(x)
1792
+ # middle
1793
+ for block in self.middle_block:
1794
+ x = self._forward_single(block, x, e, context, time_rel_pos_bias,
1795
+ focus_present_mask, video_mask, variant_info=variant_info)
1796
+
1797
+ if control is not None:
1798
+ x = control.pop() + x
1799
+
1800
+ # decoder
1801
+ for block in self.output_blocks:
1802
+ if control is None:
1803
+ x = torch.cat([x, xs.pop()], dim=1)
1804
+ else:
1805
+ x = torch.cat([x, xs.pop() + control.pop()], dim=1)
1806
+ x = self._forward_single(
1807
+ block,
1808
+ x,
1809
+ e,
1810
+ context,
1811
+ time_rel_pos_bias,
1812
+ focus_present_mask,
1813
+ video_mask,
1814
+ reference=xs[-1] if len(xs) > 0 else None,
1815
+ variant_info=variant_info)
1816
+
1817
+ # head
1818
+ x = self.out(x)
1819
+
1820
+ # reshape back to (b c f h w)
1821
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=batch)
1822
+ return x
1823
+
1824
+ def _forward_single(self,
1825
+ module,
1826
+ x,
1827
+ e,
1828
+ context,
1829
+ time_rel_pos_bias,
1830
+ focus_present_mask,
1831
+ video_mask,
1832
+ reference=None,
1833
+ variant_info=None):
1834
+ variant_info = None # For Debug: variant branching is disabled in the main UNet path; VideoControlNet still receives it
1835
+ if isinstance(module, ResidualBlock):
1836
+ module = checkpoint_wrapper(
1837
+ module) if self.use_checkpoint else module
1838
+ x = x.contiguous()
1839
+ x = module(x, e, reference)
1840
+ elif isinstance(module, ResBlock):
1841
+ module = checkpoint_wrapper(
1842
+ module) if self.use_checkpoint else module
1843
+ x = x.contiguous()
1844
+ x = module(x, e, self.batch, variant_info)
1845
+ elif isinstance(module, SpatialTransformer):
1846
+ module = checkpoint_wrapper(
1847
+ module) if self.use_checkpoint else module
1848
+ x = module(x, context)
1849
+ elif isinstance(module, TemporalTransformer):
1850
+ module = checkpoint_wrapper(
1851
+ module) if self.use_checkpoint else module
1852
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
1853
+ x = module(x, context)
1854
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
1855
+ elif isinstance(module, CrossAttention):
1856
+ module = checkpoint_wrapper(
1857
+ module) if self.use_checkpoint else module
1858
+ x = module(x, context)
1859
+ elif isinstance(module, MemoryEfficientCrossAttention):
1860
+ module = checkpoint_wrapper(
1861
+ module) if self.use_checkpoint else module
1862
+ x = module(x, context)
1863
+ elif isinstance(module, BasicTransformerBlock):
1864
+ module = checkpoint_wrapper(
1865
+ module) if self.use_checkpoint else module
1866
+ x = module(x, context)
1867
+ elif isinstance(module, FeedForward):
1868
+ x = module(x, context)
1869
+ elif isinstance(module, Upsample):
1870
+ x = module(x)
1871
+ elif isinstance(module, Downsample):
1872
+ x = module(x)
1873
+ elif isinstance(module, Resample):
1874
+ x = module(x, reference)
1875
+ elif isinstance(module, TemporalAttentionBlock):
1876
+ module = checkpoint_wrapper(
1877
+ module) if self.use_checkpoint else module
1878
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
1879
+ x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
1880
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
1881
+ elif isinstance(module, TemporalAttentionMultiBlock):
1882
+ module = checkpoint_wrapper(
1883
+ module) if self.use_checkpoint else module
1884
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
1885
+ x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
1886
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
1887
+ elif isinstance(module, InitTemporalConvBlock):
1888
+ module = checkpoint_wrapper(
1889
+ module) if self.use_checkpoint else module
1890
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
1891
+ x = module(x)
1892
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
1893
+ elif isinstance(module, TemporalConvBlock):
1894
+ module = checkpoint_wrapper(
1895
+ module) if self.use_checkpoint else module
1896
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
1897
+ x = module(x)
1898
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
1899
+ elif isinstance(module, nn.ModuleList):
1900
+ for block in module:
1901
+ x = self._forward_single(block, x, e, context,
1902
+ time_rel_pos_bias, focus_present_mask,
1903
+ video_mask, reference, variant_info)
1904
+ else:
1905
+ x = module(x)
1906
+ return x
1907
+
1908
+
1909
+ class VideoControlNet(nn.Module):
1910
+
1911
+ def __init__(self,
1912
+ in_dim=4,
1913
+ dim=320,
1914
+ y_dim=1024,
1915
+ context_dim=1024,
1916
+ out_dim=4,
1917
+ dim_mult=[1, 2, 4, 4],
1918
+ num_heads=8,
1919
+ head_dim=64,
1920
+ num_res_blocks=2,
1921
+ attn_scales=[1 / 1, 1 / 2, 1 / 4],
1922
+ use_scale_shift_norm=True,
1923
+ dropout=0.1,
1924
+ temporal_attn_times=1,
1925
+ temporal_attention=True,
1926
+ use_checkpoint=True,
1927
+ use_image_dataset=False,
1928
+ use_fps_condition=False,
1929
+ use_sim_mask=False,
1930
+ training=False,
1931
+ inpainting=True):
1932
+ embed_dim = dim * 4
1933
+ num_heads = num_heads if num_heads else dim // 32
1934
+ super(VideoControlNet, self).__init__()
1935
+ self.in_dim = in_dim
1936
+ self.dim = dim
1937
+ self.y_dim = y_dim
1938
+ self.context_dim = context_dim
1939
+ self.embed_dim = embed_dim
1940
+ self.out_dim = out_dim
1941
+ self.dim_mult = dim_mult
1942
+ # for temporal attention
1943
+ self.num_heads = num_heads
1944
+ # for spatial attention
1945
+ self.head_dim = head_dim
1946
+ self.num_res_blocks = num_res_blocks
1947
+ self.attn_scales = attn_scales
1948
+ self.use_scale_shift_norm = use_scale_shift_norm
1949
+ self.temporal_attn_times = temporal_attn_times
1950
+ self.temporal_attention = temporal_attention
1951
+ self.use_checkpoint = use_checkpoint
1952
+ self.use_image_dataset = use_image_dataset
1953
+ self.use_fps_condition = use_fps_condition
1954
+ self.use_sim_mask = use_sim_mask
1955
+ self.training = training
1956
+ self.inpainting = inpainting
1957
+
1958
+ use_linear_in_temporal = False
1959
+ transformer_depth = 1
1960
+ disabled_sa = False
1961
+ # params
1962
+ enc_dims = [dim * u for u in [1] + dim_mult]
1963
+ dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
1964
+ shortcut_dims = []
1965
+ scale = 1.0
1966
+
1967
+ # CaptionEmbedder (new add)
1968
+ # approx_gelu = lambda: nn.GELU(approximate="tanh")
1969
+ # self.y_embedder = CaptionEmbedder(
1970
+ # in_channels=4096,
1971
+ # hidden_size=1024,
1972
+ # uncond_prob=0.1,
1973
+ # act_layer=approx_gelu,
1974
+ # token_num=120,
1975
+ # )
1976
+
1977
+ # embeddings
1978
+ self.time_embed = nn.Sequential(
1979
+ nn.Linear(dim, embed_dim), nn.SiLU(),
1980
+ nn.Linear(embed_dim, embed_dim))
1981
+
1982
+ # self.hint_time_zero_linear = zero_module(nn.Linear(embed_dim, embed_dim))
1983
+
1984
+ # scale prompt
1985
+ # self.scale_cond = nn.Sequential(
1986
+ # nn.Linear(dim, embed_dim), nn.SiLU(),
1987
+ # zero_module(nn.Linear(embed_dim, embed_dim)))
1988
+
1989
+ if self.use_fps_condition:
1990
+ self.fps_embedding = nn.Sequential(
1991
+ nn.Linear(dim, embed_dim), nn.SiLU(),
1992
+ nn.Linear(embed_dim, embed_dim))
1993
+ nn.init.zeros_(self.fps_embedding[-1].weight)
1994
+ nn.init.zeros_(self.fps_embedding[-1].bias)
1995
+
1996
+ # encoder
1997
+ self.input_blocks = nn.ModuleList()
1998
+ init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)])
1999
+ # need an initial temporal attention?
2000
+ if temporal_attention:
2001
+ if USE_TEMPORAL_TRANSFORMER:
2002
+ init_block.append(
2003
+ TemporalTransformer(
2004
+ dim,
2005
+ num_heads,
2006
+ head_dim,
2007
+ depth=transformer_depth,
2008
+ context_dim=context_dim,
2009
+ disable_self_attn=disabled_sa,
2010
+ use_linear=use_linear_in_temporal,
2011
+ multiply_zero=use_image_dataset,
2012
+ is_ctrl=True,))
2013
+ else:
2014
+ init_block.append(
2015
+ TemporalAttentionMultiBlock(
2016
+ dim,
2017
+ num_heads,
2018
+ head_dim,
2019
+ rotary_emb=self.rotary_emb,
2020
+ temporal_attn_times=temporal_attn_times,
2021
+ use_image_dataset=use_image_dataset))
2022
+ self.input_blocks.append(init_block)
2023
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(dim)])
2024
+ shortcut_dims.append(dim)
2025
+ for i, (in_dim,
2026
+ out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
2027
+ for j in range(num_res_blocks):
2028
+ block = nn.ModuleList([
2029
+ ResBlock(
2030
+ in_dim,
2031
+ embed_dim,
2032
+ dropout,
2033
+ out_channels=out_dim,
2034
+ use_scale_shift_norm=False,
2035
+ use_image_dataset=use_image_dataset,
2036
+ )
2037
+ ])
2038
+ if scale in attn_scales:
2039
+ block.append(
2040
+ SpatialTransformer(
2041
+ out_dim,
2042
+ out_dim // head_dim,
2043
+ head_dim,
2044
+ depth=1,
2045
+ context_dim=self.context_dim,
2046
+ disable_self_attn=False,
2047
+ use_linear=True,
2048
+ is_ctrl=True))
2049
+ if self.temporal_attention:
2050
+ if USE_TEMPORAL_TRANSFORMER:
2051
+ block.append(
2052
+ TemporalTransformer(
2053
+ out_dim,
2054
+ out_dim // head_dim,
2055
+ head_dim,
2056
+ depth=transformer_depth,
2057
+ context_dim=context_dim,
2058
+ disable_self_attn=disabled_sa,
2059
+ use_linear=use_linear_in_temporal,
2060
+ multiply_zero=use_image_dataset,
2061
+ is_ctrl=True,))
2062
+ else:
2063
+ block.append(
2064
+ TemporalAttentionMultiBlock(
2065
+ out_dim,
2066
+ num_heads,
2067
+ head_dim,
2068
+ rotary_emb=self.rotary_emb,
2069
+ use_image_dataset=use_image_dataset,
2070
+ use_sim_mask=use_sim_mask,
2071
+ temporal_attn_times=temporal_attn_times))
2072
+ in_dim = out_dim
2073
+ self.input_blocks.append(block)
2074
+ self.zero_convs.append(self.make_zero_conv(out_dim))
2075
+ shortcut_dims.append(out_dim)
2076
+
2077
+ # downsample
2078
+ if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
2079
+ downsample = Downsample(
2080
+ out_dim, True, dims=2, out_channels=out_dim)
2081
+ shortcut_dims.append(out_dim)
2082
+ scale /= 2.0
2083
+ self.input_blocks.append(downsample)
2084
+ self.zero_convs.append(self.make_zero_conv(out_dim))
2085
+
2086
+ self.middle_block = nn.ModuleList([
2087
+ ResBlock(
2088
+ out_dim,
2089
+ embed_dim,
2090
+ dropout,
2091
+ use_scale_shift_norm=False,
2092
+ use_image_dataset=use_image_dataset,
2093
+ ),
2094
+ SpatialTransformer(
2095
+ out_dim,
2096
+ out_dim // head_dim,
2097
+ head_dim,
2098
+ depth=1,
2099
+ context_dim=self.context_dim,
2100
+ disable_self_attn=False,
2101
+ use_linear=True,
2102
+ is_ctrl=True)
2103
+ ])
2104
+
2105
+ if self.temporal_attention:
2106
+ if USE_TEMPORAL_TRANSFORMER:
2107
+ self.middle_block.append(
2108
+ TemporalTransformer(
2109
+ out_dim,
2110
+ out_dim // head_dim,
2111
+ head_dim,
2112
+ depth=transformer_depth,
2113
+ context_dim=context_dim,
2114
+ disable_self_attn=disabled_sa,
2115
+ use_linear=use_linear_in_temporal,
2116
+ multiply_zero=use_image_dataset,
2117
+ is_ctrl=True,
2118
+ ))
2119
+ else:
2120
+ self.middle_block.append(
2121
+ TemporalAttentionMultiBlock(
2122
+ out_dim,
2123
+ num_heads,
2124
+ head_dim,
2125
+ rotary_emb=self.rotary_emb,
2126
+ use_image_dataset=use_image_dataset,
2127
+ use_sim_mask=use_sim_mask,
2128
+ temporal_attn_times=temporal_attn_times))
2129
+
2130
+ self.middle_block.append(
2131
+ ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False))
2132
+
2133
+ self.middle_block_out = self.make_zero_conv(embed_dim)
2134
+
2135
+ '''
2136
+ hint (control-signal) input projection
2137
+ '''
2138
+ add_dim = 320
2139
+ self.add_dim = add_dim
2140
+
2141
+ self.input_hint_block = zero_module(nn.Conv2d(4, add_dim, 3, padding=1))
2142
+
2143
+ def make_zero_conv(self, in_channels, out_channels=None):
2144
+ out_channels = in_channels if out_channels is None else out_channels
2145
+ return TimestepEmbedSequential(zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0)))
2146
+
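`make_zero_conv` wraps a 1x1 convolution whose weights and bias are zeroed by `zero_module`, the usual ControlNet trick: at initialization the control branch adds exactly zero to the UNet features, so training starts from the behaviour of the pretrained backbone. A small sketch of that property (illustrative, not part of the committed file) follows.

```python
import torch
import torch.nn as nn

# A zero-initialised 1x1 conv, as produced by the make_zero_conv pattern above.
conv = nn.Conv2d(320, 320, 1, padding=0)
nn.init.zeros_(conv.weight)
nn.init.zeros_(conv.bias)

# Whatever the input, the zero conv contributes nothing until its weights are trained.
x = torch.randn(2, 320, 16, 16)
assert torch.all(conv(x) == 0)
```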
2147
+ def forward(self,
2148
+ x,
2149
+ t,
2150
+ y,
2151
+ s_cond=None,
2152
+ hint=None,
2153
+ variant_info=None,
2154
+ t_hint=None,
2155
+ mask_cond=None,
2156
+ fps=None,
2157
+ video_mask=None,
2158
+ focus_present_mask=None,
2159
+ prob_focus_present=0.,
2160
+ mask_last_frame_num=0):
2161
+
2162
+ batch, _, f, _, _ = x.shape
2163
+ device = x.device
2164
+ self.batch = batch
2165
+
2166
+ # image and video joint training; if mask_last_frame_num is set, prob_focus_present is ignored
2167
+ if mask_last_frame_num > 0:
2168
+ focus_present_mask = None
2169
+ video_mask[-mask_last_frame_num:] = False
2170
+ else:
2171
+ focus_present_mask = default(
2172
+ focus_present_mask, lambda: prob_mask_like(
2173
+ (batch, ), prob_focus_present, device=device))
2174
+
2175
+ if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
2176
+ time_rel_pos_bias = self.time_rel_pos_bias(
2177
+ x.shape[2], device=x.device)
2178
+ else:
2179
+ time_rel_pos_bias = None
2180
+
2181
+ if hint is not None:
2182
+ # add = x.new_zeros(batch, self.add_dim, f, h, w)
2183
+ hint = rearrange(hint, 'b c f h w -> (b f) c h w')
2184
+ hint = self.input_hint_block(hint)
2185
+ # hint = rearrange(hint, '(b f) c h w -> b c f h w', b = batch)
2186
+
2187
+ e = self.time_embed(sinusoidal_embedding(t, self.dim))
2188
+ e = e.repeat_interleave(repeats=f, dim=0)
2189
+
2190
+ context = y.repeat_interleave(repeats=f, dim=0)
2191
+
2192
+ # always in shape (b f) c h w, except for temporal layer
2193
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
2194
+ # print('before x shape:', x.shape) [64, 320, 90, 160]
2195
+ # print('hint shape:', hint.shape) [32, 320, 90, 160]
2196
+
2197
+ # encoder
2198
+ xs = []
2199
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
2200
+ if hint is not None:
2201
+ for block in module:
2202
+ x = self._forward_single(block, x, e, context, time_rel_pos_bias,
2203
+ focus_present_mask, video_mask, variant_info=variant_info)
2204
+ if not isinstance(block, TemporalTransformer):
2205
+ if hint is not None:
2206
+ x += hint
2207
+ hint = None
2208
+ else:
2209
+ x = self._forward_single(module, x, e, context, time_rel_pos_bias,
2210
+ focus_present_mask, video_mask, variant_info=variant_info)
2211
+ xs.append(zero_conv(x, e, context))
2212
+
2213
+ # middle
2214
+ for block in self.middle_block:
2215
+ x = self._forward_single(block, x, e, context, time_rel_pos_bias,
2216
+ focus_present_mask, video_mask, variant_info=variant_info)
2217
+ xs.append(self.middle_block_out(x, e, context))
2218
+
2219
+ return xs
2220
+
2221
+ def _forward_single(self,
2222
+ module,
2223
+ x,
2224
+ e,
2225
+ context,
2226
+ time_rel_pos_bias,
2227
+ focus_present_mask,
2228
+ video_mask,
2229
+ reference=None,
2230
+ variant_info=None,):
2231
+ # variant_info = None # For Debug
2232
+ if isinstance(module, ResidualBlock):
2233
+ module = checkpoint_wrapper(
2234
+ module) if self.use_checkpoint else module
2235
+ x = x.contiguous()
2236
+ x = module(x, e, reference)
2237
+ elif isinstance(module, ResBlock):
2238
+ module = checkpoint_wrapper(
2239
+ module) if self.use_checkpoint else module
2240
+ x = x.contiguous()
2241
+ x = module(x, e, self.batch, variant_info)
2242
+ elif isinstance(module, SpatialTransformer):
2243
+ module = checkpoint_wrapper(
2244
+ module) if self.use_checkpoint else module
2245
+ x = module(x, context)
2246
+ elif isinstance(module, TemporalTransformer):
2247
+ module = checkpoint_wrapper(
2248
+ module) if self.use_checkpoint else module
2249
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
2250
+ # print("x shape:", x.shape) # [2, 320, 32, 90, 160]
2251
+ x = module(x, context)
2252
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
2253
+ elif isinstance(module, CrossAttention):
2254
+ module = checkpoint_wrapper(
2255
+ module) if self.use_checkpoint else module
2256
+ x = module(x, context)
2257
+ elif isinstance(module, MemoryEfficientCrossAttention):
2258
+ module = checkpoint_wrapper(
2259
+ module) if self.use_checkpoint else module
2260
+ x = module(x, context)
2261
+ elif isinstance(module, BasicTransformerBlock):
2262
+ module = checkpoint_wrapper(
2263
+ module) if self.use_checkpoint else module
2264
+ x = module(x, context)
2265
+ elif isinstance(module, FeedForward):
2266
+ x = module(x, context)
2267
+ elif isinstance(module, Upsample):
2268
+ x = module(x)
2269
+ elif isinstance(module, Downsample):
2270
+ x = module(x)
2271
+ elif isinstance(module, Resample):
2272
+ x = module(x, reference)
2273
+ elif isinstance(module, TemporalAttentionBlock):
2274
+ module = checkpoint_wrapper(
2275
+ module) if self.use_checkpoint else module
2276
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
2277
+ x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
2278
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
2279
+ elif isinstance(module, TemporalAttentionMultiBlock):
2280
+ module = checkpoint_wrapper(
2281
+ module) if self.use_checkpoint else module
2282
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
2283
+ x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
2284
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
2285
+ elif isinstance(module, InitTemporalConvBlock):
2286
+ module = checkpoint_wrapper(
2287
+ module) if self.use_checkpoint else module
2288
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
2289
+ x = module(x)
2290
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
2291
+ elif isinstance(module, TemporalConvBlock):
2292
+ module = checkpoint_wrapper(
2293
+ module) if self.use_checkpoint else module
2294
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
2295
+ x = module(x)
2296
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
2297
+ elif isinstance(module, nn.ModuleList):
2298
+ for block in module:
2299
+ x = self._forward_single(block, x, e, context,
2300
+ time_rel_pos_bias, focus_present_mask,
2301
+ video_mask, reference, variant_info)
2302
+ else:
2303
+ x = module(x)
2304
+ return x
2305
+
2306
+
2307
+ class TimestepBlock(nn.Module):
2308
+ """
2309
+ Any module where forward() takes timestep embeddings as a second argument.
2310
+ """
2311
+
2312
+ @abstractmethod
2313
+ def forward(self, x, emb):
2314
+ """
2315
+ Apply the module to `x` given `emb` timestep embeddings.
2316
+ """
2317
+
2318
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
2319
+ """
2320
+ A sequential module that passes timestep embeddings to the children that
2321
+ support it as an extra input.
2322
+ """
2323
+
2324
+ def forward(self, x, emb, context=None):
2325
+ for layer in self:
2326
+ if isinstance(layer, TimestepBlock):
2327
+ x = layer(x, emb)
2328
+ elif isinstance(layer, SpatialTransformer):
2329
+ x = layer(x, context)
2330
+ else:
2331
+ x = layer(x)
2332
+ return x
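Throughout `_forward_single`, spatial modules (ResBlock, SpatialTransformer) operate on frames folded into the batch axis as '(b f) c h w', while temporal modules (TemporalTransformer, the temporal conv blocks) need the explicit frame axis 'b c f h w'. The sketch below (an editorial illustration, not part of the committed file; the toy shapes are assumptions) shows the round trip performed with `einops.rearrange`.

```python
import torch
from einops import rearrange

# Toy dimensions (assumed): batch, channels, frames, height, width.
b, c, f, h, w = 2, 4, 8, 32, 32
x = torch.randn(b, c, f, h, w)

# Spatial layers see frames folded into the batch axis.
x_spatial = rearrange(x, 'b c f h w -> (b f) c h w')
assert x_spatial.shape == (b * f, c, h, w)

# Temporal layers get the frame axis back, exactly as _forward_single does
# before and after each temporal module.
x_temporal = rearrange(x_spatial, '(b f) c h w -> b c f h w', b=b)
assert torch.equal(x_temporal, x)
```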
video_to_video/utils/__init__.py ADDED
File without changes
video_to_video/utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (158 Bytes).
 
video_to_video/utils/__pycache__/config.cpython-39.pyc ADDED
Binary file (3.43 kB).
 
video_to_video/utils/__pycache__/logger.cpython-39.pyc ADDED
Binary file (2.14 kB).
 
video_to_video/utils/__pycache__/seed.cpython-39.pyc ADDED
Binary file (466 Bytes).
 
video_to_video/utils/config.py ADDED
@@ -0,0 +1,169 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ import logging
4
+ import os
5
+ import os.path as osp
6
+ from datetime import datetime
7
+
8
+ import torch
9
+ from easydict import EasyDict
10
+
11
+ cfg = EasyDict(__name__='Config: VideoLDM Decoder')
12
+
13
+ # ---------------------------work dir--------------------------
14
+ cfg.work_dir = 'workspace/'
15
+
16
+ # ---------------------------Global Variable-----------------------------------
17
+ cfg.resolution = [448, 256]
18
+ cfg.max_frames = 32
19
+ # -----------------------------------------------------------------------------
20
+
21
+ # ---------------------------Dataset Parameter---------------------------------
22
+ cfg.mean = [0.5, 0.5, 0.5]
23
+ cfg.std = [0.5, 0.5, 0.5]
24
+ cfg.max_words = 1000
25
+
26
+ # PlaceHolder
27
+ cfg.vit_out_dim = 1024
28
+ cfg.vit_resolution = [224, 224]
29
+ cfg.depth_clamp = 10.0
30
+ cfg.misc_size = 384
31
+ cfg.depth_std = 20.0
32
+
33
+ cfg.frame_lens = 32
34
+ cfg.sample_fps = 8
35
+
36
+ cfg.batch_sizes = 1
37
+ # -----------------------------------------------------------------------------
38
+
39
+ # ---------------------------Mode Parameters-----------------------------------
40
+ # Diffusion
41
+ cfg.schedule = 'cosine'
42
+ cfg.num_timesteps = 1000
43
+ cfg.mean_type = 'v'
44
+ cfg.var_type = 'fixed_small'
45
+ cfg.loss_type = 'mse'
46
+ cfg.ddim_timesteps = 50
47
+ cfg.ddim_eta = 0.0
48
+ cfg.clamp = 1.0
49
+ cfg.share_noise = False
50
+ cfg.use_div_loss = False
51
+ cfg.noise_strength = 0.1
52
+
53
+ # classifier-free guidance
54
+ cfg.p_zero = 0.1
55
+ cfg.guide_scale = 3.0
56
+
57
+ # clip vision encoder
58
+ cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073]
59
+ cfg.vit_std = [0.26862954, 0.26130258, 0.27577711]
60
+
61
+ # Model
62
+ cfg.scale_factor = 0.18215
63
+ cfg.use_fp16 = True
64
+ cfg.temporal_attention = True
65
+ cfg.decoder_bs = 8
66
+
67
+ cfg.UNet = {
68
+ 'type': 'Vid2VidSDUNet',
69
+ 'in_dim': 4,
70
+ 'dim': 320,
71
+ 'y_dim': cfg.vit_out_dim,
72
+ 'context_dim': 1024,
73
+ 'out_dim': 8 if cfg.var_type.startswith('learned') else 4,
74
+ 'dim_mult': [1, 2, 4, 4],
75
+ 'num_heads': 8,
76
+ 'head_dim': 64,
77
+ 'num_res_blocks': 2,
78
+ 'attn_scales': [1 / 1, 1 / 2, 1 / 4],
79
+ 'dropout': 0.1,
80
+ 'temporal_attention': cfg.temporal_attention,
81
+ 'temporal_attn_times': 1,
82
+ 'use_checkpoint': False,
83
+ 'use_fps_condition': False,
84
+ 'use_sim_mask': False,
85
+ 'num_tokens': 4,
86
+ 'default_fps': 8,
87
+ 'input_dim': 1024
88
+ }
89
+
90
+ cfg.guidances = []
91
+
92
+ # autoencoder from Stable Diffusion
93
+ cfg.auto_encoder = {
94
+ 'type': 'AutoencoderKL',
95
+ 'ddconfig': {
96
+ 'double_z': True,
97
+ 'z_channels': 4,
98
+ 'resolution': 256,
99
+ 'in_channels': 3,
100
+ 'out_ch': 3,
101
+ 'ch': 128,
102
+ 'ch_mult': [1, 2, 4, 4],
103
+ 'num_res_blocks': 2,
104
+ 'attn_resolutions': [],
105
+ 'dropout': 0.0
106
+ },
107
+ 'embed_dim': 4,
108
+ 'pretrained': 'models/v2-1_512-ema-pruned.ckpt'
109
+ }
110
+ # clip embedder
111
+ cfg.embedder = {
112
+ 'type': 'FrozenOpenCLIPEmbedder',
113
+ 'layer': 'penultimate',
114
+ 'vit_resolution': [224, 224],
115
+ 'pretrained': 'open_clip_pytorch_model.bin'
116
+ }
117
+ # -----------------------------------------------------------------------------
118
+
119
+ # ---------------------------Training Settings---------------------------------
120
+ # training and optimizer
121
+ cfg.ema_decay = 0.9999
122
+ cfg.num_steps = 600000
123
+ cfg.lr = 5e-5
124
+ cfg.weight_decay = 0.0
125
+ cfg.betas = (0.9, 0.999)
126
+ cfg.eps = 1.0e-8
127
+ cfg.chunk_size = 16
128
+ cfg.alpha = 0.7
129
+ cfg.save_ckp_interval = 1000
130
+ # -----------------------------------------------------------------------------
131
+
132
+ # ----------------------------Pretrain Settings---------------------------------
133
+ # Default: load 2d pretrain
134
+ cfg.fix_weight = False
135
+ cfg.load_match = False
136
+ cfg.pretrained_checkpoint = 'v2-1_512-ema-pruned.ckpt'
137
+ cfg.pretrained_image_keys = 'stable_diffusion_image_key_temporal_attention_x1.json'
138
+ cfg.resume_checkpoint = 'img2video_ldm_0779000.pth'
139
+ # -----------------------------------------------------------------------------
140
+
141
+ # -----------------------------Visual-------------------------------------------
142
+ # Visual videos
143
+ cfg.viz_interval = 1000
144
+ cfg.visual_train = {
145
+ 'type': 'VisualVideoTextDuringTrain',
146
+ }
147
+ cfg.visual_inference = {
148
+ 'type': 'VisualGeneratedVideos',
149
+ }
150
+ cfg.inference_list_path = ''
151
+
152
+ # logging
153
+ cfg.log_interval = 100
154
+
155
+ # Default log_dir
156
+ cfg.log_dir = 'workspace/output_data'
157
+ # -----------------------------------------------------------------------------
158
+
159
+ # ---------------------------Others--------------------------------------------
160
+ # seed
161
+ cfg.seed = 8888
162
+
163
+ cfg.negative_prompt = 'painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, \
164
+ CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, \
165
+ signature, jpeg artifacts, deformed, lowres, over-smooth'
166
+
167
+ cfg.positive_prompt = 'Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, \
168
+ hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, \
169
+ skin pore detailing, hyper sharpness, perfect without deformations.'
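`cfg` is a module-level EasyDict, so downstream code reads and overrides settings by attribute access; `video_to_video_model.py` assigns `cfg.model_path = opt.model_path` this way before loading weights. A brief usage sketch (illustrative, not part of the committed file; the checkpoint path is a placeholder) follows.

```python
# Assumes the repository root is on PYTHONPATH and `easydict` is installed.
from video_to_video.utils.config import cfg

print(cfg.resolution, cfg.max_frames)            # [448, 256] 32
print(cfg.UNet['dim'], cfg.UNet['context_dim'])  # 320 1024

# EasyDict supports attribute-style overrides; video_to_video_model.py sets
# cfg.model_path the same way before loading the checkpoint.
cfg.model_path = './checkpoints/star_model.pth'  # hypothetical path
```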
video_to_video/utils/logger.py ADDED
@@ -0,0 +1,94 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ import importlib
4
+ import logging
5
+ from typing import Optional
6
+ from torch import distributed as dist
7
+
8
+ init_loggers = {}
9
+
10
+ formatter = logging.Formatter(
11
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
12
+
13
+
14
+ def get_logger(log_file: Optional[str] = None,
15
+ log_level: int = logging.INFO,
16
+ file_mode: str = 'w'):
17
+ """ Get logging logger
18
+
19
+ Args:
20
+ log_file: Log filename, if specified, file handler will be added to
21
+ logger
22
+ log_level: Logging level.
23
+ file_mode: Specifies the mode to open the file, if filename is
24
+ specified (if filemode is unspecified, it defaults to 'w').
25
+ """
26
+
27
+ logger_name = __name__.split('.')[0]
28
+ logger = logging.getLogger(logger_name)
29
+ logger.propagate = False
30
+ if logger_name in init_loggers:
31
+ add_file_handler_if_needed(logger, log_file, file_mode, log_level)
32
+ return logger
33
+
34
+ # handle duplicate logs to the console
35
+ # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)
36
+ # to the root logger. As logger.propagate is True by default, this root
37
+ # level handler causes logging messages from rank>0 processes to
38
+ # unexpectedly show up on the console, creating much unwanted clutter.
39
+ # To fix this issue, we set the root logger's StreamHandler, if any, to log
40
+ # at the ERROR level.
41
+ for handler in logger.root.handlers:
42
+ if type(handler) is logging.StreamHandler:
43
+ handler.setLevel(logging.ERROR)
44
+
45
+ stream_handler = logging.StreamHandler()
46
+ handlers = [stream_handler]
47
+
48
+ if importlib.util.find_spec('torch') is not None:
49
+ is_worker0 = is_master()
50
+ else:
51
+ is_worker0 = True
52
+
53
+ if is_worker0 and log_file is not None:
54
+ file_handler = logging.FileHandler(log_file, file_mode)
55
+ handlers.append(file_handler)
56
+
57
+ for handler in handlers:
58
+ handler.setFormatter(formatter)
59
+ handler.setLevel(log_level)
60
+ logger.addHandler(handler)
61
+
62
+ if is_worker0:
63
+ logger.setLevel(log_level)
64
+ else:
65
+ logger.setLevel(logging.ERROR)
66
+
67
+ init_loggers[logger_name] = True
68
+
69
+ return logger
70
+
71
+
72
+ def add_file_handler_if_needed(logger, log_file, file_mode, log_level):
73
+ for handler in logger.handlers:
74
+ if isinstance(handler, logging.FileHandler):
75
+ return
76
+
77
+ if importlib.util.find_spec('torch') is not None:
78
+ is_worker0 = is_master()
79
+ else:
80
+ is_worker0 = True
81
+
82
+ if is_worker0 and log_file is not None:
83
+ file_handler = logging.FileHandler(log_file, file_mode)
84
+ file_handler.setFormatter(formatter)
85
+ file_handler.setLevel(log_level)
86
+ logger.addHandler(file_handler)
87
+
88
+
89
+ def is_master(group=None):
90
+ return dist.get_rank(group) == 0 if is_dist() else True
91
+
92
+
93
+ def is_dist():
94
+ return dist.is_available() and dist.is_initialized()
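`get_logger` caches one logger per top-level package, demotes the root StreamHandler that recent PyTorch DDP versions attach, and only logs at INFO level (and writes a log file, if requested) on the master process. A short single-process usage sketch (illustrative, not part of the committed file) follows.

```python
from video_to_video.utils.logger import get_logger

logger = get_logger()          # console handler at INFO on the master process
logger.info('building the pipeline')

# Passing log_file adds a FileHandler on rank 0 (the target directory must exist).
# Repeated calls return the same cached logger object.
assert get_logger() is logger
```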
video_to_video/utils/seed.py ADDED
@@ -0,0 +1,14 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+
9
+ def setup_seed(seed):
10
+ torch.manual_seed(seed)
11
+ torch.cuda.manual_seed_all(seed)
12
+ np.random.seed(seed)
13
+ random.seed(seed)
14
+ torch.backends.cudnn.deterministic = True
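`setup_seed` seeds Python's `random`, NumPy, and PyTorch (CPU and CUDA) and enables deterministic cuDNN, so two runs with the same seed reproduce the same random draws; `cfg.seed` defaults to 8888. A quick sketch (illustrative, not part of the committed file) follows.

```python
import torch
from video_to_video.utils.seed import setup_seed

setup_seed(8888)               # same default as cfg.seed in utils/config.py
a = torch.randn(3)

setup_seed(8888)               # re-seeding reproduces the identical draw
b = torch.randn(3)
assert torch.equal(a, b)
```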
video_to_video/video_to_video_model.py ADDED
@@ -0,0 +1,210 @@
1
+ import os
2
+ import os.path as osp
3
+ import random
4
+ from typing import Any, Dict
5
+
6
+ import torch
7
+ import torch.cuda.amp as amp
8
+ import torch.nn.functional as F
9
+
10
+ from video_to_video.modules import *
11
+ from video_to_video.utils.config import cfg
12
+ from video_to_video.diffusion.diffusion_sdedit import GaussianDiffusion
13
+ from video_to_video.diffusion.schedules_sdedit import noise_schedule
14
+ from video_to_video.utils.logger import get_logger
15
+
16
+ from diffusers import AutoencoderKLTemporalDecoder
17
+
18
+ logger = get_logger()
19
+
20
+ class VideoToVideo_sr():
21
+ def __init__(self, opt, device=torch.device('cuda:0')):
22
+ self.opt = opt
23
+ self.device = device # torch.device(f'cuda:0')
24
+
25
+ # text_encoder
26
+ text_encoder = FrozenOpenCLIPEmbedder(device=self.device, pretrained="laion2b_s32b_b79k")
27
+ text_encoder.model.to(self.device)
28
+ self.text_encoder = text_encoder
29
+ logger.info(f'Build encoder with FrozenOpenCLIPEmbedder')
30
+
31
+ # U-Net with ControlNet
32
+ generator = ControlledV2VUNet()
33
+ generator = generator.to(self.device)
34
+ generator.eval()
35
+
36
+ cfg.model_path = opt.model_path
37
+ load_dict = torch.load(cfg.model_path, map_location='cpu')
38
+ if 'state_dict' in load_dict:
39
+ load_dict = load_dict['state_dict']
40
+ ret = generator.load_state_dict(load_dict, strict=False)
41
+
42
+ self.generator = generator.half()
43
+ logger.info('Load model path {}, with local status {}'.format(cfg.model_path, ret))
44
+
45
+ # Noise scheduler
46
+ sigmas = noise_schedule(
47
+ schedule='logsnr_cosine_interp',
48
+ n=1000,
49
+ zero_terminal_snr=True,
50
+ scale_min=2.0,
51
+ scale_max=4.0)
52
+ diffusion = GaussianDiffusion(sigmas=sigmas)
53
+ self.diffusion = diffusion
54
+ logger.info('Build diffusion with GaussianDiffusion')
55
+
56
+ # Temporal VAE
57
+ vae = AutoencoderKLTemporalDecoder.from_pretrained(
58
+ "stabilityai/stable-video-diffusion-img2vid", subfolder="vae", variant="fp16"
59
+ )
60
+ vae.eval()
61
+ vae.requires_grad_(False)
62
+ vae.to(self.device)
63
+ self.vae = vae
64
+ logger.info('Build Temporal VAE')
65
+
66
+ torch.cuda.empty_cache()
67
+
68
+ self.negative_prompt = cfg.negative_prompt
69
+ self.positive_prompt = cfg.positive_prompt
70
+
71
+ negative_y = text_encoder(self.negative_prompt).detach()
72
+ self.negative_y = negative_y
73
+
74
+
75
+ def test(self, input: Dict[str, Any], total_noise_levels=1000, \
76
+ steps=50, solver_mode='fast', guide_scale=7.5, max_chunk_len=32):
77
+ video_data = input['video_data']
78
+ y = input['y']
79
+ (target_h, target_w) = input['target_res']
80
+
81
+ video_data = F.interpolate(video_data, [target_h,target_w], mode='bilinear')
82
+
83
+ logger.info(f'video_data shape: {video_data.shape}')
84
+ frames_num, _, h, w = video_data.shape
85
+
86
+ padding = pad_to_fit(h, w)
87
+ video_data = F.pad(video_data, padding, 'constant', 1)
88
+
89
+ video_data = video_data.unsqueeze(0)
90
+ bs = 1
91
+ video_data = video_data.to(self.device)
92
+
93
+ video_data_feature = self.vae_encode(video_data)
94
+ torch.cuda.empty_cache()
95
+
96
+ y = self.text_encoder(y).detach()
97
+
98
+ with amp.autocast(enabled=True):
99
+
100
+ t = torch.LongTensor([total_noise_levels-1]).to(self.device)
101
+ noised_lr = self.diffusion.diffuse(video_data_feature, t)
102
+
103
+ model_kwargs = [{'y': y}, {'y': self.negative_y}]
104
+ model_kwargs.append({'hint': video_data_feature})
105
+
106
+ torch.cuda.empty_cache()
107
+ chunk_inds = make_chunks(frames_num, interp_f_num=0, max_chunk_len=max_chunk_len) if frames_num > max_chunk_len else None
108
+
109
+ solver = 'dpmpp_2m_sde' # 'heun' | 'dpmpp_2m_sde'
110
+ gen_vid = self.diffusion.sample_sr(
111
+ noise=noised_lr,
112
+ model=self.generator,
113
+ model_kwargs=model_kwargs,
114
+ guide_scale=guide_scale,
115
+ guide_rescale=0.2,
116
+ solver=solver,
117
+ solver_mode=solver_mode,
118
+ return_intermediate=None,
119
+ steps=steps,
120
+ t_max=total_noise_levels - 1,
121
+ t_min=0,
122
+ discretization='trailing',
123
+ chunk_inds=chunk_inds,)
124
+ torch.cuda.empty_cache()
125
+
126
+ logger.info('sampling finished.')
127
+ vid_tensor_gen = self.vae_decode_chunk(gen_vid, chunk_size=3)
128
+
129
+ logger.info('temporal vae decoding finished.')
130
+
131
+ w1, w2, h1, h2 = padding
132
+ vid_tensor_gen = vid_tensor_gen[:,:,h1:h+h1,w1:w+w1]
133
+
134
+ gen_video = rearrange(
135
+ vid_tensor_gen, '(b f) c h w -> b c f h w', b=bs)
136
+
137
+ torch.cuda.empty_cache()
138
+
139
+ return gen_video.type(torch.float32).cpu()
140
+
141
+ def temporal_vae_decode(self, z, num_f):
142
+ return self.vae.decode(z/self.vae.config.scaling_factor, num_frames=num_f).sample
143
+
144
+ def vae_decode_chunk(self, z, chunk_size=3):
145
+ z = rearrange(z, "b c f h w -> (b f) c h w")
146
+ video = []
147
+ for ind in range(0, z.shape[0], chunk_size):
148
+ num_f = z[ind:ind+chunk_size].shape[0]
149
+ video.append(self.temporal_vae_decode(z[ind:ind+chunk_size],num_f))
150
+ video = torch.cat(video)
151
+ return video
152
+
153
+ def vae_encode(self, t, chunk_size=1):
154
+ num_f = t.shape[1]
155
+ t = rearrange(t, "b f c h w -> (b f) c h w")
156
+ z_list = []
157
+ for ind in range(0,t.shape[0],chunk_size):
158
+ z_list.append(self.vae.encode(t[ind:ind+chunk_size]).latent_dist.sample())
159
+ z = torch.cat(z_list, dim=0)
160
+ z = rearrange(z, "(b f) c h w -> b c f h w", f=num_f)
161
+ return z * self.vae.config.scaling_factor
162
+
163
+
164
+ def pad_to_fit(h, w):
165
+ BEST_H, BEST_W = 720, 1280
166
+
167
+ if h < BEST_H:
168
+ h1, h2 = _create_pad(h, BEST_H)
169
+ elif h == BEST_H:
170
+ h1 = h2 = 0
171
+ else:
172
+ h1 = 0
173
+ h2 = int((h + 48) // 64 * 64) + 64 - 48 - h
174
+
175
+ if w < BEST_W:
176
+ w1, w2 = _create_pad(w, BEST_W)
177
+ elif w == BEST_W:
178
+ w1 = w2 = 0
179
+ else:
180
+ w1 = 0
181
+ w2 = int(w // 64 * 64) + 64 - w
182
+ return (w1, w2, h1, h2)
183
+
184
+ def _create_pad(h, max_len):
185
+ h1 = int((max_len - h) // 2)
186
+ h2 = max_len - h1 - h
187
+ return h1, h2
188
+
189
+
190
+ def make_chunks(f_num, interp_f_num, max_chunk_len, chunk_overlap_ratio=0.5):
191
+ MAX_CHUNK_LEN = max_chunk_len
192
+ MAX_O_LEN = MAX_CHUNK_LEN * chunk_overlap_ratio
193
+ chunk_len = int((MAX_CHUNK_LEN-1)//(1+interp_f_num)*(interp_f_num+1)+1)
194
+ o_len = int((MAX_O_LEN-1)//(1+interp_f_num)*(interp_f_num+1)+1)
195
+ chunk_inds = sliding_windows_1d(f_num, chunk_len, o_len)
196
+ return chunk_inds
197
+
198
+
199
+ def sliding_windows_1d(length, window_size, overlap_size):
200
+ stride = window_size - overlap_size
201
+ ind = 0
202
+ coords = []
203
+ while ind<length:
204
+ if ind+window_size*1.25>=length:
205
+ coords.append((ind,length))
206
+ break
207
+ else:
208
+ coords.append((ind,ind+window_size))
209
+ ind += stride
210
+ return coords
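`make_chunks` and `sliding_windows_1d` split a long clip into overlapping windows (50% overlap by default, with the tail merged into the final window once less than about 1.25 windows of frames remain), and `pad_to_fit` returns the symmetric padding that centers a frame inside the 1280x720 working resolution. A worked example (illustrative, not part of the committed file; it assumes the package and its dependencies from requirements.txt are importable) follows.

```python
# Importing this module pulls in the full model stack, so it assumes the
# project's dependencies are installed and the repo root is on PYTHONPATH.
from video_to_video.video_to_video_model import make_chunks, pad_to_fit

# 100 frames, no interpolated frames, 32-frame chunks with 50% overlap -> stride 16.
print(make_chunks(100, interp_f_num=0, max_chunk_len=32))
# [(0, 32), (16, 48), (32, 64), (48, 80), (64, 100)]

# Symmetric padding (w1, w2, h1, h2) that centers a 960x540 frame inside 1280x720.
print(pad_to_fit(540, 960))    # (160, 160, 90, 90)
```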