Hans committed on
Commit
96e9589
0 Parent(s):

Diffusers-compatible TemporalNet2 checkpoint and inference script

config.json ADDED
@@ -0,0 +1,43 @@
+ {
+   "_class_name": "ControlNetModel",
+   "_diffusers_version": "0.18.0.dev0",
+   "act_fn": "silu",
+   "attention_head_dim": 8,
+   "block_out_channels": [
+     320,
+     640,
+     1280,
+     1280
+   ],
+   "class_embed_type": null,
+   "conditioning_channels": 6,
+   "conditioning_embedding_out_channels": [
+     16,
+     32,
+     96,
+     256
+   ],
+   "controlnet_conditioning_channel_order": "rgb",
+   "cross_attention_dim": 768,
+   "down_block_types": [
+     "CrossAttnDownBlock2D",
+     "CrossAttnDownBlock2D",
+     "CrossAttnDownBlock2D",
+     "DownBlock2D"
+   ],
+   "downsample_padding": 1,
+   "flip_sin_to_cos": true,
+   "freq_shift": 0,
+   "global_pool_conditions": false,
+   "in_channels": 4,
+   "layers_per_block": 2,
+   "mid_block_scale_factor": 1,
+   "norm_eps": 1e-05,
+   "norm_num_groups": 32,
+   "num_class_embeds": null,
+   "only_cross_attention": false,
+   "projection_class_embeddings_input_dim": null,
+   "resnet_time_scale_shift": "default",
+   "upcast_attention": false,
+   "use_linear_projection": false
+ }
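
Per this config, the ControlNet takes a 6-channel conditioning image rather than the usual 3: the previous RGB frame concatenated with an RGB optical-flow visualization, which is how the inference script below builds its control image. A minimal loading sketch, assuming the checkpoint is published as wav/TemporalNet2 as referenced in temporalvideo_hf.py; the random tensors are placeholders only:

import torch
from diffusers import ControlNetModel

# Load the ControlNet added in this commit (fp16 to match the inference script below).
controlnet = ControlNetModel.from_pretrained("wav/TemporalNet2", torch_dtype=torch.float16)
assert controlnet.config.conditioning_channels == 6

# Conditioning image = previous frame (3 ch) + flow visualization (3 ch), both in [0, 1].
prev_frame = torch.rand(1, 3, 512, 512)  # placeholder for the previous video frame
flow_rgb = torch.rand(1, 3, 512, 512)    # placeholder for flow_to_image(...) / 255
control_image = torch.cat((prev_frame, flow_rgb), dim=1)  # shape (1, 6, 512, 512)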
diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b31fdb59df59d2951354b143bb292de50c01e971aa8b83d70eb3c4e54cdcd7a2
+ size 1445158852
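
The weights file itself lives in Git LFS; only the pointer above (oid and size, roughly 1.4 GB) is stored in the repository. ControlNetModel.from_pretrained resolves it automatically, but as a sketch it can also be fetched directly with huggingface_hub, with the repo id assumed from the inference script:

from huggingface_hub import hf_hub_download

# Downloads (or reuses from cache) the ~1.4 GB weights file behind the LFS pointer.
weights_path = hf_hub_download(repo_id="wav/TemporalNet2", filename="diffusion_pytorch_model.safetensors")
print(weights_path)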
temporalvideo_hf.py ADDED
@@ -0,0 +1,124 @@
+ import argparse
+ import warnings
+ from pathlib import Path
+
+ import torch
+ from diffusers import ControlNetModel, DPMSolverMultistepScheduler, StableDiffusionControlNetImg2ImgPipeline
+ from torch import Tensor
+ from torchvision.io.video import read_video, write_video
+ from torchvision.models.optical_flow import Raft_Large_Weights, raft_large
+ from torchvision.transforms.functional import resize
+ from torchvision.utils import flow_to_image
+ from tqdm import trange
+
+ raft_transform = Raft_Large_Weights.DEFAULT.transforms()
+
+
+ @torch.inference_mode()
+ def stylize_video(
+     input_video: Tensor,
+     prompt: str,
+     strength: float = 0.7,
+     num_steps: int = 20,
+     guidance_scale: float = 7.5,
+     controlnet_scale: float = 1.0,
+     batch_size: int = 4,
+     height: int = 512,
+     width: int = 512,
+     device: str = "cuda",
+ ) -> Tensor:
+     """
+     Stylize a video with temporal coherence (less flickering!) using Hugging Face's Stable Diffusion ControlNet pipeline.
+
+     Args:
+         input_video (Tensor): Input video tensor of shape (T, C, H, W) in the range [0, 1].
+         prompt (str): Text prompt to condition the diffusion process.
+         strength (float, optional): How heavily stylization affects the image.
+         num_steps (int, optional): Number of diffusion steps (trade-off between quality and speed).
+         guidance_scale (float, optional): Scale of the text guidance loss (how closely to adhere to the text prompt).
+         controlnet_scale (float, optional): Scale of the ControlNet conditioning (strength of temporal coherence).
+         batch_size (int, optional): Number of frames to diffuse at once (faster but more memory intensive).
+         height (int, optional): Height of the output video.
+         width (int, optional): Width of the output video.
+         device (str, optional): Device to run the stylization process on.
+
+     Returns:
+         Tensor: Output video tensor of shape (T, H, W, C) in the range [0, 1].
+     """
+
+     with warnings.catch_warnings():
+         warnings.simplefilter("ignore")  # silence annoying TypedStorage warnings
+
+         pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
+             "runwayml/stable-diffusion-v1-5",
+             controlnet=ControlNetModel.from_pretrained("wav/TemporalNet2", torch_dtype=torch.float16),
+             safety_checker=None,
+             torch_dtype=torch.float16,
+         ).to(device)
+         pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+         pipe.enable_xformers_memory_efficient_attention()
+         pipe._progress_bar_config = dict(disable=True)
+
+         raft = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=True).eval().to(device)
+
+     output_video = []
+     for i in trange(1, len(input_video), batch_size, desc="Diffusing...", unit="frame", unit_scale=batch_size):
+         prev = resize(input_video[i - 1 : i - 1 + batch_size], (height, width), antialias=True).to(device)
+         curr = resize(input_video[i : i + batch_size], (height, width), antialias=True).to(device)
+         prev = prev[: curr.shape[0]]  # make sure prev and curr have the same batch size (for the last batch)
+
+         flow_img = flow_to_image(raft.forward(*raft_transform(prev, curr))[-1]).div(255)
+         control_img = torch.cat((prev, flow_img), dim=1)
+
+         output, _ = pipe(
+             prompt=[prompt] * curr.shape[0],
+             image=curr,
+             control_image=control_img,
+             height=height,
+             width=width,
+             strength=strength,
+             num_inference_steps=num_steps,
+             guidance_scale=guidance_scale,
+             controlnet_conditioning_scale=controlnet_scale,
+             output_type="pt",
+             return_dict=False,
+         )
+
+         output_video.append(output.permute(0, 2, 3, 1).cpu())
+
+     return torch.cat(output_video)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(usage=stylize_video.__doc__)
+     parser.add_argument("-i", "--in-file", type=str, required=True)
+     parser.add_argument("-p", "--prompt", type=str, required=True)
+     parser.add_argument("-o", "--out-file", type=str, default=None)
+     parser.add_argument("-s", "--strength", type=float, default=0.7)
+     parser.add_argument("-S", "--num-steps", type=int, default=20)
+     parser.add_argument("-g", "--guidance-scale", type=float, default=7.5)
+     parser.add_argument("-c", "--controlnet-scale", type=float, default=1.0)
+     parser.add_argument("-b", "--batch-size", type=int, default=4)
+     parser.add_argument("-H", "--height", type=int, default=512)
+     parser.add_argument("-W", "--width", type=int, default=512)
+     parser.add_argument("-d", "--device", type=str, default="cuda")
+     args = parser.parse_args()
+
+     input_video, _, info = read_video(args.in_file, pts_unit="sec", output_format="TCHW")
+     input_video = input_video.div(255)
+
+     output_video = stylize_video(
+         input_video=input_video,
+         prompt=args.prompt,
+         strength=args.strength,
+         num_steps=args.num_steps,
+         guidance_scale=args.guidance_scale,
+         controlnet_scale=args.controlnet_scale,
+         height=args.height,
+         width=args.width,
+         device=args.device,
+         batch_size=args.batch_size,
+     )
+
+     out_file = f"{Path(args.in_file).stem} | {args.prompt}.mp4" if args.out_file is None else args.out_file
+     write_video(out_file, output_video.mul(255), fps=info["video_fps"])
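
For completeness, a minimal programmatic usage sketch of stylize_video; the file names and prompt are illustrative, and it assumes this script is importable as temporalvideo_hf with its dependencies (diffusers, torchvision, xformers) installed:

import torch
from torchvision.io.video import read_video, write_video

from temporalvideo_hf import stylize_video

# Read an example clip as (T, C, H, W) uint8 and scale to [0, 1].
frames, _, info = read_video("input.mp4", pts_unit="sec", output_format="TCHW")
styled = stylize_video(frames.div(255), prompt="an oil painting of a busy street", batch_size=2)

# stylize_video returns (T, H, W, C) in [0, 1], which is the layout write_video expects.
write_video("output.mp4", styled.mul(255).round().to(torch.uint8), fps=info["video_fps"])

The -i/-p flags in the __main__ block above cover the same flow from the command line without writing any Python.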