import os
import os.path as osp
import random
from typing import Any, Dict

import torch
import torch.cuda.amp as amp
import torch.nn.functional as F

from video_to_video.modules import *
from video_to_video.utils.config import cfg
from video_to_video.diffusion.diffusion_sdedit import GaussianDiffusion
from video_to_video.diffusion.schedules_sdedit import noise_schedule
from video_to_video.utils.logger import get_logger
from diffusers import AutoencoderKLTemporalDecoder
import requests

# NOTE(review): ``rearrange`` (used below) is not imported explicitly in this
# file; presumably it is re-exported by ``video_to_video.modules`` through the
# star import above -- confirm.


def download_model(url, model_path, chunk_size=1 << 20, timeout=60):
    """Download ``url`` to ``model_path`` unless a file already exists there.

    The payload is streamed into ``<model_path>.part`` and only renamed to
    ``model_path`` after the transfer completes, so an interrupted download
    never leaves a truncated file that a later call would mistake for a
    finished checkpoint (the existence check below would otherwise skip it).

    Args:
        url: HTTP(S) URL to fetch.
        model_path: Destination path; missing parent directories are created.
        chunk_size: Bytes requested per streamed chunk.
        timeout: Seconds passed to ``requests.get`` as its timeout.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        requests.RequestException: on connection problems / timeouts.
    """
    if os.path.exists(model_path):
        print(f"Model found at {model_path}, skipping download.")
        return

    print(f"Model not found at {model_path}, downloading...")
    parent_dir = osp.dirname(model_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    tmp_path = model_path + '.part'
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this an HTTP error body would be saved as the "model".
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:  # ignore empty chunks
                        f.write(chunk)
        os.replace(tmp_path, model_path)
    finally:
        # Only still present if something failed before the rename.
        if osp.exists(tmp_path):
            os.remove(tmp_path)
    print(f"Model downloaded to {model_path}")


logger = get_logger()


class VideoToVideo_sr():
    """Inference wrapper around a ControlNet-conditioned diffusion video model.

    ``__init__`` builds the OpenCLIP text encoder, the ``ControlledV2VUNet``
    generator (weights downloaded on demand), the ``GaussianDiffusion``
    sampler and the temporal VAE. ``test`` encodes an input clip, noises the
    latents and samples with ``GaussianDiffusion.sample_sr``.
    """

    def __init__(self, opt, device=torch.device('cuda:0')):
        """Build all sub-models on ``device``.

        Args:
            opt: options object; only ``opt.model_path`` is read here.
            device: torch device every sub-model is moved to.
        """
        self.opt = opt
        self.device = device

        # text_encoder
        text_encoder = FrozenOpenCLIPEmbedder(device=self.device, pretrained="laion2b_s32b_b79k")
        text_encoder.model.to(self.device)
        self.text_encoder = text_encoder
        logger.info('Build encoder with FrozenOpenCLIPEmbedder')

        # U-Net with ControlNet
        generator = ControlledV2VUNet()
        generator = generator.to(self.device)
        generator.eval()

        # Note: this mutates the module-global ``cfg``.
        cfg.model_path = opt.model_path

        # download weight (no-op when the file already exists)
        model_url = 'https://huggingface.co/SherryX/STAR/resolve/main/I2VGen-XL-based/heavy_deg.pt'
        download_model(model_url, cfg.model_path)

        # SECURITY NOTE(review): torch.load unpickles the checkpoint, which can
        # execute arbitrary code -- only load files from a trusted source.
        load_dict = torch.load(cfg.model_path, map_location='cpu')
        if 'state_dict' in load_dict:
            load_dict = load_dict['state_dict']
        # strict=False: missing/unexpected keys are reported in ``ret`` (logged
        # below) instead of raising.
        ret = generator.load_state_dict(load_dict, strict=False)

        self.generator = generator.half()
        logger.info('Load model path {}, with local status {}'.format(cfg.model_path, ret))

        # Noise scheduler
        sigmas = noise_schedule(
            schedule='logsnr_cosine_interp',
            n=1000,
            zero_terminal_snr=True,
            scale_min=2.0,
            scale_max=4.0)
        diffusion = GaussianDiffusion(sigmas=sigmas)
        self.diffusion = diffusion
        logger.info('Build diffusion with GaussianDiffusion')

        # Temporal VAE (frozen)
        vae = AutoencoderKLTemporalDecoder.from_pretrained(
            "stabilityai/stable-video-diffusion-img2vid", subfolder="vae", variant="fp16"
        )
        vae.eval()
        vae.requires_grad_(False)
        vae.to(self.device)
        self.vae = vae
        logger.info('Build Temporal VAE')

        torch.cuda.empty_cache()

        self.negative_prompt = cfg.negative_prompt
        self.positive_prompt = cfg.positive_prompt

        # Embedded once here; reused as the negative conditioning in every test().
        negative_y = text_encoder(self.negative_prompt).detach()
        self.negative_y = negative_y

    # Inference only: no autograd graph needs to be recorded.
    @torch.no_grad()
    def test(self, input: Dict[str, Any], total_noise_levels=1000,
             steps=50, solver_mode='fast', guide_scale=7.5, max_chunk_len=32):
        """Run diffusion sampling on one clip and return the decoded video.

        Args:
            input: dict with
                ``'video_data'`` -- 4-D tensor unpacked as ``(F, C, H, W)``;
                ``'y'`` -- prompt passed to the text encoder;
                ``'target_res'`` -- ``(target_h, target_w)`` to resize to.
            total_noise_levels: the encoded input is diffused to timestep
                ``total_noise_levels - 1`` and sampling starts from there.
            steps: number of sampler steps.
            solver_mode: forwarded to ``sample_sr``.
            guide_scale: guidance scale forwarded to ``sample_sr``.
            max_chunk_len: clips with more frames than this are sampled in
                overlapping temporal chunks (see ``make_chunks``).

        Returns:
            float32 CPU tensor shaped ``(1, C, F, target_h, target_w)``.
        """
        video_data = input['video_data']
        y = input['y']
        (target_h, target_w) = input['target_res']

        video_data = F.interpolate(video_data, [target_h, target_w], mode='bilinear')

        logger.info(f'video_data shape: {video_data.shape}')
        frames_num, _, h, w = video_data.shape

        # Pad (with value 1) to a model-friendly size; cropped off after decoding.
        padding = pad_to_fit(h, w)
        video_data = F.pad(video_data, padding, 'constant', 1)

        video_data = video_data.unsqueeze(0)
        bs = 1
        video_data = video_data.to(self.device)

        video_data_feature = self.vae_encode(video_data)
        torch.cuda.empty_cache()

        y = self.text_encoder(y).detach()

        with amp.autocast(enabled=True):
            # Noise the encoded input up to the starting timestep.
            t = torch.LongTensor([total_noise_levels - 1]).to(self.device)
            noised_lr = self.diffusion.diffuse(video_data_feature, t)

            # [prompt conditioning, negative conditioning, control hint]
            model_kwargs = [{'y': y}, {'y': self.negative_y}]
            model_kwargs.append({'hint': video_data_feature})

            torch.cuda.empty_cache()
            chunk_inds = make_chunks(frames_num, interp_f_num=0, max_chunk_len=max_chunk_len) if frames_num > max_chunk_len else None

            solver = 'dpmpp_2m_sde'  # 'heun' | 'dpmpp_2m_sde'
            gen_vid = self.diffusion.sample_sr(
                noise=noised_lr,
                model=self.generator,
                model_kwargs=model_kwargs,
                guide_scale=guide_scale,
                guide_rescale=0.2,
                solver=solver,
                solver_mode=solver_mode,
                return_intermediate=None,
                steps=steps,
                t_max=total_noise_levels - 1,
                t_min=0,
                discretization='trailing',
                chunk_inds=chunk_inds,)
            torch.cuda.empty_cache()

            logger.info(f'sampling, finished.')
            # Decoded inside autocast: the generator runs in half precision,
            # so its output may be fp16 while the VAE weights are not.
            vid_tensor_gen = self.vae_decode_chunk(gen_vid, chunk_size=3)
            logger.info(f'temporal vae decoding, finished.')

        # Undo pad_to_fit: padding is (left, right, top, bottom).
        w1, w2, h1, h2 = padding
        vid_tensor_gen = vid_tensor_gen[:, :, h1:h + h1, w1:w + w1]

        gen_video = rearrange(
            vid_tensor_gen, '(b f) c h w -> b c f h w', b=bs)

        torch.cuda.empty_cache()

        return gen_video.type(torch.float32).cpu()

    def temporal_vae_decode(self, z, num_f):
        """Undo the latent scaling and decode ``num_f`` frames with the temporal VAE."""
        return self.vae.decode(z / self.vae.config.scaling_factor, num_frames=num_f).sample

    def vae_decode_chunk(self, z, chunk_size=3):
        """Decode latents ``(b, c, f, h, w)`` ``chunk_size`` frames at a time.

        Returns the decoded frames concatenated along dim 0, i.e. ``(b*f, C, H, W)``.
        """
        z = rearrange(z, "b c f h w -> (b f) c h w")
        video = []
        for ind in range(0, z.shape[0], chunk_size):
            chunk = z[ind:ind + chunk_size]
            # The last chunk may be shorter; pass its real frame count.
            video.append(self.temporal_vae_decode(chunk, chunk.shape[0]))
        video = torch.cat(video)
        return video

    def vae_encode(self, t, chunk_size=1):
        """Encode frames ``(b, f, c, h, w)`` into scaled latents ``(b, c, f, h', w')``.

        Frames are encoded ``chunk_size`` at a time (presumably to bound peak
        memory); latents are *sampled* from the posterior and multiplied by the
        VAE's ``scaling_factor``.
        """
        num_f = t.shape[1]
        t = rearrange(t, "b f c h w -> (b f) c h w")
        z_list = []
        for ind in range(0, t.shape[0], chunk_size):
            z_list.append(self.vae.encode(t[ind:ind + chunk_size]).latent_dist.sample())
        z = torch.cat(z_list, dim=0)
        z = rearrange(z, "(b f) c h w -> b c f h w", f=num_f)
        return z * self.vae.config.scaling_factor


def pad_to_fit(h, w):
    """Return ``(left, right, top, bottom)`` padding for an ``h x w`` frame.

    The order matches ``F.pad`` for the last two dims. Per axis:
      * smaller than 720 (h) / 1280 (w): pad symmetrically up to exactly that size;
      * equal: no padding;
      * larger: pad only at the end, to the next size strictly greater than
        the input with ``H % 64 == 16`` (height) or ``W % 64 == 0`` (width).
        NOTE(review): this always adds 1-64 px, so an already-aligned size
        still gets 64 px extra -- confirm whether that is intended.
    """
    BEST_H, BEST_W = 720, 1280

    if h < BEST_H:
        h1, h2 = _create_pad(h, BEST_H)
    elif h == BEST_H:
        h1 = h2 = 0
    else:
        h1 = 0
        h2 = int((h + 48) // 64 * 64) + 64 - 48 - h

    if w < BEST_W:
        w1, w2 = _create_pad(w, BEST_W)
    elif w == BEST_W:
        w1 = w2 = 0
    else:
        w1 = 0
        w2 = int(w // 64 * 64) + 64 - w
    return (w1, w2, h1, h2)


def _create_pad(h, max_len):
    """Split ``max_len - h`` into (before, after); any odd pixel goes after."""
    h1 = int((max_len - h) // 2)
    h2 = max_len - h1 - h
    return h1, h2


def make_chunks(f_num, interp_f_num, max_chunk_len, chunk_overlap_ratio=0.5):
    """Return overlapping ``(start, end)`` frame windows covering ``f_num`` frames.

    The window length is the largest value <= ``max_chunk_len`` of the form
    ``k * (interp_f_num + 1) + 1``; the overlap is derived the same way from
    ``max_chunk_len * chunk_overlap_ratio``.
    """
    MAX_CHUNK_LEN = max_chunk_len
    MAX_O_LEN = MAX_CHUNK_LEN * chunk_overlap_ratio
    chunk_len = int((MAX_CHUNK_LEN - 1) // (1 + interp_f_num) * (interp_f_num + 1) + 1)
    o_len = int((MAX_O_LEN - 1) // (1 + interp_f_num) * (interp_f_num + 1) + 1)
    chunk_inds = sliding_windows_1d(f_num, chunk_len, o_len)
    return chunk_inds


def sliding_windows_1d(length, window_size, overlap_size):
    """Split ``[0, length)`` into overlapping half-open ``(start, end)`` windows.

    Windows advance by ``window_size - overlap_size``. As soon as a window
    starting at ``ind`` would reach or pass ``length``, a final window
    ``(ind, length)`` is emitted and iteration stops, so the last window may
    be shorter than ``window_size``. Returns ``[]`` when ``length <= 0``.

    Raises:
        ValueError: if ``overlap_size >= window_size`` (the stride would be
            non-positive and the loop would never terminate).
    """
    stride = window_size - overlap_size
    if stride <= 0:
        raise ValueError(
            f'overlap_size ({overlap_size}) must be smaller than '
            f'window_size ({window_size})')
    ind = 0
    coords = []
    # NOTE(review): reconstructed from a garbled "while ind=length:"; upstream
    # may use a different end-of-sequence threshold -- confirm.
    while ind < length:
        if ind + window_size >= length:
            coords.append((ind, length))
            break
        coords.append((ind, ind + window_size))
        ind += stride
    return coords