Hans
commited on
Commit
•
96e9589
0
Parent(s):
Diffusers-compatible TemporalNet2 checkpoint and inference script
Browse files- config.json +43 -0
- diffusion_pytorch_model.safetensors +3 -0
- temporalvideo_hf.py +124 -0
config.json
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "ControlNetModel",
|
3 |
+
"_diffusers_version": "0.18.0.dev0",
|
4 |
+
"act_fn": "silu",
|
5 |
+
"attention_head_dim": 8,
|
6 |
+
"block_out_channels": [
|
7 |
+
320,
|
8 |
+
640,
|
9 |
+
1280,
|
10 |
+
1280
|
11 |
+
],
|
12 |
+
"class_embed_type": null,
|
13 |
+
"conditioning_channels": 6,
|
14 |
+
"conditioning_embedding_out_channels": [
|
15 |
+
16,
|
16 |
+
32,
|
17 |
+
96,
|
18 |
+
256
|
19 |
+
],
|
20 |
+
"controlnet_conditioning_channel_order": "rgb",
|
21 |
+
"cross_attention_dim": 768,
|
22 |
+
"down_block_types": [
|
23 |
+
"CrossAttnDownBlock2D",
|
24 |
+
"CrossAttnDownBlock2D",
|
25 |
+
"CrossAttnDownBlock2D",
|
26 |
+
"DownBlock2D"
|
27 |
+
],
|
28 |
+
"downsample_padding": 1,
|
29 |
+
"flip_sin_to_cos": true,
|
30 |
+
"freq_shift": 0,
|
31 |
+
"global_pool_conditions": false,
|
32 |
+
"in_channels": 4,
|
33 |
+
"layers_per_block": 2,
|
34 |
+
"mid_block_scale_factor": 1,
|
35 |
+
"norm_eps": 1e-05,
|
36 |
+
"norm_num_groups": 32,
|
37 |
+
"num_class_embeds": null,
|
38 |
+
"only_cross_attention": false,
|
39 |
+
"projection_class_embeddings_input_dim": null,
|
40 |
+
"resnet_time_scale_shift": "default",
|
41 |
+
"upcast_attention": false,
|
42 |
+
"use_linear_projection": false
|
43 |
+
}
|
diffusion_pytorch_model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b31fdb59df59d2951354b143bb292de50c01e971aa8b83d70eb3c4e54cdcd7a2
|
3 |
+
size 1445158852
|
temporalvideo_hf.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import warnings
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from diffusers import ControlNetModel, DPMSolverMultistepScheduler, StableDiffusionControlNetImg2ImgPipeline
|
7 |
+
from torch import Tensor
|
8 |
+
from torchvision.io.video import read_video, write_video
|
9 |
+
from torchvision.models.optical_flow import Raft_Large_Weights, raft_large
|
10 |
+
from torchvision.transforms.functional import resize
|
11 |
+
from torchvision.utils import flow_to_image
|
12 |
+
from tqdm import trange
|
13 |
+
|
14 |
+
raft_transform = Raft_Large_Weights.DEFAULT.transforms()
|
15 |
+
|
16 |
+
|
17 |
+
@torch.inference_mode()
|
18 |
+
def stylize_video(
|
19 |
+
input_video: Tensor,
|
20 |
+
prompt: str,
|
21 |
+
strength: float = 0.7,
|
22 |
+
num_steps: int = 20,
|
23 |
+
guidance_scale: float = 7.5,
|
24 |
+
controlnet_scale: float = 1.0,
|
25 |
+
batch_size: int = 4,
|
26 |
+
height: int = 512,
|
27 |
+
width: int = 512,
|
28 |
+
device: str = "cuda",
|
29 |
+
) -> Tensor:
|
30 |
+
"""
|
31 |
+
Stylize a video with temporal coherence (less flickering!) using HuggingFace's Stable Diffusion ControlNet pipeline.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
input_video (Tensor): Input video tensor of shape (T, C, H, W) and range [0, 1].
|
35 |
+
prompt (str): Text prompt to condition the diffusion process.
|
36 |
+
strength (float, optional): How heavily stylization affects the image.
|
37 |
+
num_steps (int, optional): Number of diffusion steps (tradeoff between quality and speed).
|
38 |
+
guidance_scale (float, optional): Scale of the text guidance loss (how closely to adhere to text prompt).
|
39 |
+
controlnet_scale (float, optional): Scale of the ControlNet conditioning (strength of temporal coherence).
|
40 |
+
batch_size (int, optional): Number of frames to diffuse at once (faster but more memory intensive).
|
41 |
+
height (int, optional): Height of the output video.
|
42 |
+
width (int, optional): Width of the output video.
|
43 |
+
device (str, optional): Device to run stylization process on.
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
Tensor: Output video tensor of shape (T, C, H, W) and range [0, 1].
|
47 |
+
"""
|
48 |
+
|
49 |
+
with warnings.catch_warnings():
|
50 |
+
warnings.simplefilter("ignore") # silence annoying TypedStorage warnings
|
51 |
+
|
52 |
+
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
53 |
+
"runwayml/stable-diffusion-v1-5",
|
54 |
+
controlnet=ControlNetModel.from_pretrained("wav/TemporalNet2", torch_dtype=torch.float16),
|
55 |
+
safety_checker=None,
|
56 |
+
torch_dtype=torch.float16,
|
57 |
+
).to(device)
|
58 |
+
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
59 |
+
pipe.enable_xformers_memory_efficient_attention()
|
60 |
+
pipe._progress_bar_config = dict(disable=True)
|
61 |
+
|
62 |
+
raft = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=True).eval().to(device)
|
63 |
+
|
64 |
+
output_video = []
|
65 |
+
for i in trange(1, len(input_video), batch_size, desc="Diffusing...", unit="frame", unit_scale=batch_size):
|
66 |
+
prev = resize(input_video[i - 1 : i - 1 + batch_size], (height, width), antialias=True).to(device)
|
67 |
+
curr = resize(input_video[i : i + batch_size], (height, width), antialias=True).to(device)
|
68 |
+
prev = prev[: curr.shape[0]] # make sure prev and curr have the same batch size (for the last batch)
|
69 |
+
|
70 |
+
flow_img = flow_to_image(raft.forward(*raft_transform(prev, curr))[-1]).div(255)
|
71 |
+
control_img = torch.cat((prev, flow_img), dim=1)
|
72 |
+
|
73 |
+
output, _ = pipe(
|
74 |
+
prompt=[prompt] * curr.shape[0],
|
75 |
+
image=curr,
|
76 |
+
control_image=control_img,
|
77 |
+
height=height,
|
78 |
+
width=width,
|
79 |
+
strength=strength,
|
80 |
+
num_inference_steps=num_steps,
|
81 |
+
guidance_scale=guidance_scale,
|
82 |
+
controlnet_conditioning_scale=controlnet_scale,
|
83 |
+
output_type="pt",
|
84 |
+
return_dict=False,
|
85 |
+
)
|
86 |
+
|
87 |
+
output_video.append(output.permute(0, 2, 3, 1).cpu())
|
88 |
+
|
89 |
+
return torch.cat(output_video)
|
90 |
+
|
91 |
+
|
92 |
+
if __name__ == "__main__":
|
93 |
+
parser = argparse.ArgumentParser(usage=stylize_video.__doc__)
|
94 |
+
parser.add_argument("-i", "--in-file", type=str, required=True)
|
95 |
+
parser.add_argument("-p", "--prompt", type=str, required=True)
|
96 |
+
parser.add_argument("-o", "--out-file", type=str, default=None)
|
97 |
+
parser.add_argument("-s", "--strength", type=float, default=0.7)
|
98 |
+
parser.add_argument("-S", "--num-steps", type=int, default=20)
|
99 |
+
parser.add_argument("-g", "--guidance-scale", type=float, default=7.5)
|
100 |
+
parser.add_argument("-c", "--controlnet-scale", type=float, default=1.0)
|
101 |
+
parser.add_argument("-b", "--batch_size", type=int, default=4)
|
102 |
+
parser.add_argument("-H", "--height", type=int, default=512)
|
103 |
+
parser.add_argument("-W", "--width", type=int, default=512)
|
104 |
+
parser.add_argument("-d", "--device", type=str, default="cuda")
|
105 |
+
args = parser.parse_args()
|
106 |
+
|
107 |
+
input_video, _, info = read_video(args.in_file, pts_unit="sec", output_format="TCHW")
|
108 |
+
input_video = input_video.div(255)
|
109 |
+
|
110 |
+
output_video = stylize_video(
|
111 |
+
input_video=input_video,
|
112 |
+
prompt=args.prompt,
|
113 |
+
strength=args.strength,
|
114 |
+
num_steps=args.num_steps,
|
115 |
+
guidance_scale=args.guidance_scale,
|
116 |
+
controlnet_scale=args.controlnet_scale,
|
117 |
+
height=args.height,
|
118 |
+
width=args.width,
|
119 |
+
device=args.device,
|
120 |
+
batch_size=args.batch_size,
|
121 |
+
)
|
122 |
+
|
123 |
+
out_file = f"{Path(args.in_file).stem} | {args.prompt}.mp4" if args.out_file is None else args.out_file
|
124 |
+
write_video(out_file, output_video.mul(255), fps=info["video_fps"])
|