# Spaces: Running on Zero
import os | |
import torch | |
from argparse import ArgumentParser, Namespace | |
import json | |
from typing import Any, Dict, List, Mapping, Tuple | |
from easydict import EasyDict | |
from video_to_video.video_to_video_model import VideoToVideo_sr | |
from video_to_video.utils.seed import setup_seed | |
from video_to_video.utils.logger import get_logger | |
from video_super_resolution.color_fix import adain_color_fix | |
from inference_utils import * | |
# Module-wide logger shared by the class and functions in this file.
logger = get_logger()
class VEnhancer_sr():
    """Video super-resolution wrapper around ``VideoToVideo_sr``.

    The model is built once in ``__init__``; ``enhance_a_video`` then
    upscales a single video and writes the result into ``result_dir``.
    """

    # Output name used when the caller does not supply one.
    _DEFAULT_FILE_NAME = '000_video.mp4'

    def __init__(self,
                 result_dir='./results/',
                 file_name=_DEFAULT_FILE_NAME,
                 model_path='',
                 solver_mode='fast',
                 steps=15,
                 guide_scale=7.5,
                 upscale=4,
                 max_chunk_len=32,
                 variant_info=None,
                 ):
        """Create the output directory and build the model.

        Args:
            result_dir: Directory the enhanced video is written to; created
                if it does not exist.
            file_name: Name of the output file inside ``result_dir``.
                ``None`` falls back to ``'000_video.mp4'``.
            model_path: Checkpoint path handed to ``VideoToVideo_sr`` through
                its config object.
            solver_mode: Forwarded to ``model.test``. ``'fast'`` pins the
                step count to 15 regardless of ``steps``.
            steps: Step count forwarded to ``model.test``; only honoured when
                ``solver_mode`` is not ``'fast'``.
            guide_scale: Forwarded to ``model.test``.
            upscale: Factor applied to the input height and width to obtain
                the target resolution.
            max_chunk_len: Forwarded to ``model.test``.
            variant_info: Stored on the instance; not otherwise used by this
                class.
        """
        self.model_path = model_path
        logger.info(f'checkpoint_path: {self.model_path}')

        self.result_dir = result_dir
        # A None name would only blow up in os.path.join() after the whole
        # inference run has finished, so normalise it up front.
        self.file_name = self._DEFAULT_FILE_NAME if file_name is None else file_name
        os.makedirs(self.result_dir, exist_ok=True)

        model_cfg = EasyDict(__name__='model_cfg')
        model_cfg.model_path = self.model_path
        self.model = VideoToVideo_sr(model_cfg)

        # 'fast' mode always runs with 15 steps; the argument is ignored.
        steps = 15 if solver_mode == 'fast' else steps
        self.solver_mode = solver_mode
        self.steps = steps
        self.guide_scale = guide_scale
        self.upscale = upscale
        self.max_chunk_len = max_chunk_len
        self.variant_info = variant_info

    def enhance_a_video(self, video_path, prompt):
        """Upscale one video and save it under ``result_dir``.

        Args:
            video_path: Path of the input video, read with ``load_video``.
            prompt: Text prompt; the model's ``positive_prompt`` is appended
                to it before inference.

        Returns:
            The path ``os.path.join(result_dir, file_name)`` of the saved
            video.
        """
        logger.info(f'input video path: {video_path}')
        text = prompt
        logger.info(f'text: {text}')
        caption = text + self.model.positive_prompt

        input_frames, input_fps = load_video(video_path)
        in_f_num = len(input_frames)
        logger.info(f'input frames length: {in_f_num}')
        logger.info(f'input fps: {input_fps}')

        video_data = preprocess(input_frames)
        # The last two dimensions of the preprocessed tensor are height/width.
        _, _, h, w = video_data.shape
        logger.info(f'input resolution: {(h, w)}')
        # Scale both dimensions by the configured factor.
        target_h, target_w = h * self.upscale, w * self.upscale
        logger.info(f'target resolution: {(target_h, target_w)}')

        pre_data = {'video_data': video_data, 'y': caption}
        pre_data['target_res'] = (target_h, target_w)

        total_noise_levels = 900
        # Fixed seed -- presumably so repeated runs give the same output.
        setup_seed(666)

        with torch.no_grad():
            # NOTE: the device is hard-coded to the first CUDA device.
            data_tensor = collate_fn(pre_data, 'cuda:0')
            output = self.model.test(
                data_tensor,
                total_noise_levels,
                steps=self.steps,
                solver_mode=self.solver_mode,
                guide_scale=self.guide_scale,
                max_chunk_len=self.max_chunk_len,
            )

        output = tensor2vid(output)

        # Colour-correct the result against the preprocessed input frames.
        output = adain_color_fix(output, video_data)

        # Keep the input frame rate for the saved file.
        save_video(output, self.result_dir, self.file_name, fps=input_fps)
        return os.path.join(self.result_dir, self.file_name)
def parse_args(argv=None) -> Namespace:
    """Parse the command-line options of the enhancement script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        The parsed ``argparse.Namespace``. ``file_name`` and
        ``variant_info`` are ``None`` when not given on the command line.
    """
    parser = ArgumentParser()

    parser.add_argument("--input_path", required=True, type=str, help="input video path")
    parser.add_argument("--save_dir", type=str, default='results', help="save directory")
    parser.add_argument("--file_name", type=str, help="file name")
    parser.add_argument("--model_path", type=str, default='./pretrained_weight/model.pt', help="model path")
    parser.add_argument("--prompt", type=str, default='a good video', help="prompt")
    parser.add_argument("--upscale", type=int, default=4, help='up-scale')
    parser.add_argument("--max_chunk_len", type=int, default=32, help='max_chunk_len')
    parser.add_argument("--variant_info", type=str, default=None, help='information of inference strategy')

    parser.add_argument("--cfg", type=float, default=7.5, help='guidance scale (passed on as guide_scale)')
    # Let argparse reject unknown modes with a usage message instead of
    # relying on a later assert.
    parser.add_argument("--solver_mode", type=str, default='fast', choices=('fast', 'normal'),
                        help='fast | normal')
    parser.add_argument("--steps", type=int, default=15,
                        help='number of steps; forced to 15 when solver_mode is fast')

    return parser.parse_args(argv)
def main():
    """Script entry point: parse the CLI options and enhance one video.

    Raises:
        ValueError: If ``--solver_mode`` is neither ``'fast'`` nor
            ``'normal'``.
    """
    args = parse_args()

    # Explicit raise rather than assert: asserts vanish under `python -O`.
    if args.solver_mode not in ('fast', 'normal'):
        raise ValueError(
            "solver_mode must be 'fast' or 'normal', got {!r}".format(args.solver_mode)
        )

    enhancer_kwargs = dict(
        result_dir=args.save_dir,
        model_path=args.model_path,
        solver_mode=args.solver_mode,
        steps=args.steps,
        guide_scale=args.cfg,
        upscale=args.upscale,
        max_chunk_len=args.max_chunk_len,
        # Forward the CLI value instead of silently dropping it.
        variant_info=args.variant_info,
    )
    # --file_name has no default; only override the class default when the
    # user actually supplied a name, otherwise None would reach
    # os.path.join() and raise TypeError.
    if args.file_name is not None:
        enhancer_kwargs['file_name'] = args.file_name

    venhancer_sr = VEnhancer_sr(**enhancer_kwargs)
    venhancer_sr.enhance_a_video(args.input_path, args.prompt)
# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()