# FRESCO / run_fresco.py
# SingleZombie
# upload files
# ff715ca
import os
#os.environ['CUDA_VISIBLE_DEVICES'] = "6"
# In China, set this to use huggingface
# os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
import cv2
import io
import gc
import yaml
import argparse
import torch
import torchvision
import diffusers
from diffusers import StableDiffusionPipeline, AutoencoderKL, DDPMScheduler, ControlNetModel
from src.utils import *
from src.keyframe_selection import get_keyframe_ind
from src.diffusion_hacked import apply_FRESCO_attn, apply_FRESCO_opt, disable_FRESCO_opt
from src.diffusion_hacked import get_flow_and_interframe_paras, get_intraframe_paras
from src.pipe_FRESCO import inference
def get_models(config):
    """Create all models used by the pipeline and return them.

    Reads ``gmflow_path``, ``sod_path``, ``controlnet_type``, ``sd_path``,
    ``num_inference_steps`` and ``use_freeu`` from ``config``.  When
    ``controlnet_type`` is not one of ``'hed'``, ``'depth'`` or ``'canny'``
    it is overwritten in ``config`` with ``'hed'``.

    The optical flow model, the saliency model, the ControlNet and the
    Stable Diffusion pipeline are moved to ``'cuda'``; the parameters of the
    flow model, saliency model, ControlNet and UNet have ``requires_grad``
    switched off.

    Returns:
        ``(pipe, frescoProc, controlnet, detector, flow_model, sod_model)``.
    """
    print('\n' + '=' * 100)
    print('creating models...')

    # The bundled dependencies are imported by their top-level module names,
    # so their directories have to be on sys.path before the imports below.
    import sys
    for dependency_dir in ("./src/ebsynth/deps/gmflow/",
                           "./src/EGNet/",
                           "./src/ControlNet/"):
        sys.path.append(dependency_dir)
    from gmflow.gmflow import GMFlow
    from model import build_model
    from annotator.hed import HEDdetector
    from annotator.canny import CannyDetector
    from annotator.midas import MidasDetector

    # ---- optical flow ----
    gmflow_settings = dict(
        feature_channels=128,
        num_scales=1,
        upsample_factor=8,
        num_head=1,
        attention_type='swin',
        ffn_dim_expansion=4,
        num_transformer_layers=6,
    )
    flow_model = GMFlow(**gmflow_settings).to('cuda')
    checkpoint = torch.load(config['gmflow_path'],
                            map_location=lambda storage, loc: storage)
    # The checkpoint either wraps the state dict under 'model' or is the
    # state dict itself.
    if 'model' in checkpoint:
        weights = checkpoint['model']
    else:
        weights = checkpoint
    flow_model.load_state_dict(weights, strict=False)
    flow_model.eval()
    print('create optical flow estimation model successfully!')

    # ---- saliency detection ----
    sod_model = build_model('resnet')
    sod_weights = torch.load(config['sod_path'])
    sod_model.load_state_dict(sod_weights)
    sod_model.to("cuda").eval()
    print('create saliency detection model successfully!')

    # ---- controlnet ----
    if config['controlnet_type'] not in ('hed', 'depth', 'canny'):
        print('unsupported control type, set to hed')
        config['controlnet_type'] = 'hed'
    control_type = config['controlnet_type']
    controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/sd-controlnet-" + control_type,
        torch_dtype=torch.float16)
    controlnet.to("cuda")
    # One annotator class per supported control type.
    detector_classes = {
        'depth': MidasDetector,
        'canny': CannyDetector,
        'hed': HEDdetector,
    }
    detector = detector_classes[control_type]()
    print('create controlnet model-' + control_type + ' successfully!')

    # ---- diffusion model ----
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse",
                                        torch_dtype=torch.float16)
    pipe = StableDiffusionPipeline.from_pretrained(config['sd_path'], vae=vae,
                                                   torch_dtype=torch.float16)
    pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
    pipe.to("cuda")
    pipe.scheduler.set_timesteps(config['num_inference_steps'],
                                 device=pipe._execution_device)

    if config['use_freeu']:
        from src.free_lunch_utils import apply_freeu
        apply_freeu(pipe, b1=1.2, b2=1.5, s1=1.0, s2=1.0)

    # Install the FRESCO attention processors with the controller switched
    # off, then install the FRESCO optimization hooks with default arguments.
    frescoProc = apply_FRESCO_attn(pipe)
    frescoProc.controller.disable_controller()
    apply_FRESCO_opt(pipe)
    print('create diffusion model ' + config['sd_path'] + ' successfully!')

    # None of these networks is trained here: switch gradients off.
    for frozen_module in (flow_model, sod_model, controlnet, pipe.unet):
        for param in frozen_module.parameters():
            param.requires_grad = False

    return pipe, frescoProc, controlnet, detector, flow_model, sod_model
def apply_control(x, detector, config):
    """Run the annotator selected by ``config['controlnet_type']`` on ``x``.

    Args:
        x: Input image, passed to ``detector`` unchanged.
        detector: Callable annotator.  For ``'depth'`` it is called as
            ``detector(x)`` and must return a pair whose first item is the
            control map.  For ``'canny'`` it is called as
            ``detector(x, low, high)``.  For any other type it is called as
            ``detector(x)``.
        config: Mapping holding ``'controlnet_type'``.  The optional keys
            ``'canny_low'`` and ``'canny_high'`` override the Canny
            thresholds, which default to 50 and 100.

    Returns:
        The control map produced by ``detector``.
    """
    control_type = config['controlnet_type']
    if control_type == 'depth':
        # The depth annotator returns two values; only the first is used.
        detected_map, _ = detector(x)
        return detected_map
    if control_type == 'canny':
        low_threshold = config.get('canny_low', 50)
        high_threshold = config.get('canny_high', 100)
        return detector(x, low_threshold, high_threshold)
    return detector(x)
def _split_keyframes_into_batches(keys, batch_size):
    """Group keyframe indexes into the batches that are translated in turn.

    The first batch receives up to ``batch_size`` keyframes.  Every later
    batch receives up to ``batch_size - 2`` keyframes, which leaves room for
    the two reference frames carried over from the previous batch.  If the
    last batch would hold fewer than three keyframes it is topped up with
    keyframes taken from the end of the batch before it.

    Raises:
        ValueError: if ``batch_size`` is below 3 or fewer than 3 keyframes
            are given (the split cannot be built in either case).
    """
    if batch_size < 3:
        raise ValueError(
            "config['batch_size'] must be at least 3, got %d" % batch_size)
    if len(keys) < 3:
        raise ValueError(
            'at least 3 keyframes are required, got %d' % len(keys))

    step = batch_size - 2
    sublists = [list(keys[i:i + step]) for i in range(2, len(keys), step)]
    sublists[0] = [keys[0], keys[1]] + sublists[0]

    if len(sublists) > 1 and len(sublists[-1]) < 3:
        add_num = 3 - len(sublists[-1])
        sublists[-1] = sublists[-2][-add_num:] + sublists[-1]
        sublists[-2] = sublists[-2][:-add_num]
        # Only the truncation above can empty a batch, so the check lives
        # here; with a single batch there is no sublists[-2] to look at.
        if not sublists[-2]:
            del sublists[-2]
    return sublists


def run_keyframe_translation(config):
    """Translate the keyframes of the input video batch by batch.

    Builds the models with ``get_models(config)``, selects keyframes with
    ``get_keyframe_ind``, writes every frame read from the video to
    ``<save_path>/video/%04d.png`` and every translated keyframe to
    ``<save_path>/keys/%04d.png``.  Reading stops once the last batch has
    been processed.

    Besides the keys used by ``get_models``, ``config`` must provide
    ``cond_scale``, ``num_inference_steps``, ``prompt``, ``sd_path``,
    ``file_path``, ``mininterv``, ``maxinterv``, ``save_path``,
    ``batch_size``, ``use_salinecy``, ``seed``, ``end_opt_step``,
    ``num_warmup_steps`` and ``use_controlnet``.

    Returns:
        The keyframe indexes returned by ``get_keyframe_ind``.

    Raises:
        ValueError: if the keyframes cannot be split into batches (see
            ``_split_keyframes_into_batches``).
        RuntimeError: if the video ends before every batch was processed.
    """
    pipe, frescoProc, controlnet, detector, flow_model, sod_model = get_models(config)
    device = pipe._execution_device
    guidance_scale = 7.5
    do_classifier_free_guidance = guidance_scale > 1
    assert do_classifier_free_guidance
    timesteps = pipe.scheduler.timesteps
    cond_scale = [config['cond_scale']] * config['num_inference_steps']
    dilate = Dilate(device=device)

    base_prompt = config['prompt']
    if 'Realistic' in config['sd_path'] or 'realistic' in config['sd_path']:
        a_prompt = ', RAW photo, subject, (high detailed skin:1.2), 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3, '
        n_prompt = '(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers:1.4), (deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation'
    else:
        a_prompt = ', best quality, extremely detailed, '
        n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing finger, extra digit, fewer digits, cropped, worst quality, low quality'

    print('\n' + '=' * 100)
    print('key frame selection for \"%s\"...'%(config['file_path']))

    # This capture is only needed for the frame count; release it right away.
    video_cap = cv2.VideoCapture(config['file_path'])
    try:
        frame_num = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        video_cap.release()

    # you can set extra_prompts for individual keyframe
    # for example, extra_prompts[38] = ', closed eyes' to specify the person frame38 closes the eyes
    extra_prompts = [''] * frame_num

    keys = get_keyframe_ind(config['file_path'], frame_num, config['mininterv'], config['maxinterv'])

    # Build the output directories with os.path.join so they match the paths
    # the frames are saved to, whether or not save_path ends with a separator.
    save_path = config['save_path']
    keys_dir = os.path.join(save_path, 'keys')
    video_dir = os.path.join(save_path, 'video')
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(keys_dir, exist_ok=True)
    os.makedirs(video_dir, exist_ok=True)

    sublists = _split_keyframes_into_batches(keys, config['batch_size'])

    print('processing %d batches:\nkeyframe indexes'%(len(sublists)), sublists)

    print('\n' + '=' * 100)
    print('video to video translation...')

    batch_ind = 0
    propagation_mode = batch_ind > 0
    imgs = []
    ref_prompt = []  # prompts of the two reference frames; set after each batch
    record_latents = []
    video_cap = cv2.VideoCapture(config['file_path'])
    try:
        for i in range(frame_num):
            # prepare a batch of frame based on sublists
            success, frame = video_cap.read()
            if not success:
                # The loop leaves as soon as the last batch is done, so a
                # failed read here means keyframes are still outstanding.
                raise RuntimeError(
                    'failed to read frame %d of "%s" before all keyframe '
                    'batches were processed' % (i, config['file_path']))
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img = resize_image(frame, 512)
            Image.fromarray(img).save(os.path.join(video_dir, '%04d.png' % i))

            current_keys = sublists[batch_ind]
            if i not in current_keys:
                continue

            imgs += [img]
            if i != current_keys[-1]:
                continue

            print('processing batch [%d/%d] with %d frames'%(batch_ind+1, len(sublists), len(current_keys)))

            # prepare input
            n_prompts = [n_prompt] * len(imgs)
            prompts = [base_prompt + a_prompt + extra_prompts[ind] for ind in current_keys]
            if propagation_mode:  # restore the extra_prompts from previous batch
                assert len(imgs) == len(current_keys) + 2
                prompts = ref_prompt + prompts

            prompt_embeds = pipe._encode_prompt(
                prompts,
                device,
                1,
                do_classifier_free_guidance,
                n_prompts,
            )

            imgs_torch = torch.cat([numpy2tensor(img) for img in imgs], dim=0)

            edges = torch.cat([numpy2tensor(apply_control(img, detector, config)[:, :, None]) for img in imgs], dim=0)
            # NOTE(review): the * 0.5 + 0.5 presumably maps numpy2tensor's
            # output range to [0, 1] -- confirm against src.utils.
            edges = edges.repeat(1, 3, 1, 1).cuda() * 0.5 + 0.5
            if do_classifier_free_guidance:
                # Duplicate the control maps along the batch axis.
                edges = torch.cat([edges.to(pipe.unet.dtype)] * 2)

            # The key is spelled 'use_salinecy' in this code; it is kept
            # unchanged because existing config files presumably use it.
            if config['use_salinecy']:
                saliency = get_saliency(imgs, sod_model, dilate)
            else:
                saliency = None

            # prepare parameters for inter-frame and intra-frame consistency
            flows, occs, attn_mask, interattn_paras = get_flow_and_interframe_paras(flow_model, imgs)
            correlation_matrix = get_intraframe_paras(pipe, imgs_torch, frescoProc,
                                                      prompt_embeds, seed=config['seed'])

            # Flexible settings for attention:
            # * Turn off FRESCO-guided attention: frescoProc.controller.disable_controller()
            #   Then you can turn on one specific attention submodule
            #   * Turn on Cross-frame attention: frescoProc.controller.enable_cfattn(attn_mask)
            #   * Turn on Spatial-guided attention: frescoProc.controller.enable_intraattn()
            #   * Turn on Temporal-guided attention: frescoProc.controller.enable_interattn(interattn_paras)
            # Flexible settings for optimization:
            # * Turn off Spatial-guided optimization: set optimize_temporal = False in apply_FRESCO_opt()
            # * Turn off Temporal-guided optimization: set correlation_matrix = [] in apply_FRESCO_opt()
            # * Turn off FRESCO-guided optimization: disable_FRESCO_opt(pipe)
            # Flexible settings for background smoothing:
            # * Turn off background smoothing: set saliency = None in apply_FRESCO_opt()

            # Turn on all FRESCO support
            frescoProc.controller.enable_controller(interattn_paras=interattn_paras, attn_mask=attn_mask)
            apply_FRESCO_opt(pipe, steps=timesteps[:config['end_opt_step']],
                             flows=flows, occs=occs, correlation_matrix=correlation_matrix,
                             saliency=saliency, optimize_temporal=True)

            gc.collect()
            torch.cuda.empty_cache()

            # run!
            latents = inference(pipe, controlnet, frescoProc,
                                imgs_torch, prompt_embeds, edges, timesteps,
                                cond_scale, config['num_inference_steps'], config['num_warmup_steps'],
                                do_classifier_free_guidance, config['seed'], guidance_scale, config['use_controlnet'],
                                record_latents, propagation_mode,
                                flows=flows, occs=occs, saliency=saliency, repeat_noise=True)

            gc.collect()
            torch.cuda.empty_cache()

            with torch.no_grad():
                image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
                image = torch.clamp(image, -1, 1)
                save_imgs = tensor2numpy(image)
                # In propagation mode the first two images are the reference
                # frames of the previous batch and are not saved again.
                bias = 2 if propagation_mode else 0
                for ind, num in enumerate(current_keys):
                    Image.fromarray(save_imgs[ind + bias]).save(os.path.join(keys_dir, '%04d.png' % num))

            gc.collect()
            torch.cuda.empty_cache()

            batch_ind += 1
            # The next batch reuses the first and last frame of this batch
            # (and their prompts) as references.
            ref_prompt = [prompts[0], prompts[-1]]
            imgs = [imgs[0], imgs[-1]]
            propagation_mode = batch_ind > 0
            if batch_ind == len(sublists):
                gc.collect()
                torch.cuda.empty_cache()
                break
    finally:
        video_cap.release()
    return keys
def run_full_video_translation(config, keys):
    """Print, and optionally run, the ``video_blend.py`` command.

    The command is always printed.  It is executed only when
    ``config['run_ebsynth']`` is true.

    Args:
        config: Mapping providing ``run_ebsynth``, ``file_path``,
            ``save_path`` and ``max_process``.
        keys: Iterable of integer keyframe indexes, passed after
            ``--key_ind``.
    """
    # Imported here so the block does not depend on file-level imports
    # that are not already present.
    import shlex
    import subprocess

    print('\n' + '=' * 100)
    if not config['run_ebsynth']:
        print('to translate full video with ebsynth, install ebsynth and run:')
    else:
        print('translating full video with:')

    video_cap = cv2.VideoCapture(config['file_path'])
    try:
        fps = int(video_cap.get(cv2.CAP_PROP_FPS))
    finally:
        video_cap.release()

    save_path = config['save_path']
    o_video = os.path.join(save_path, 'blend.mp4')

    # Keep the command as an argument list so paths containing spaces or
    # shell metacharacters reach video_blend.py intact.
    args = ['python', 'video_blend.py', save_path, '--key', 'keys', '--key_ind']
    args += ['%d' % k for k in keys]
    args += ['--output', o_video,
             '--fps', str(fps),
             '--n_proc', str(config['max_process']),
             '-ps']
    # Quoted form for display, so it can be pasted into a POSIX shell.
    cmd = ' '.join(shlex.quote(arg) for arg in args)

    print('\n```')
    print(cmd)
    print('```')

    if config['run_ebsynth']:
        try:
            result = subprocess.run(args)
        except OSError as err:
            # os.system never raised; keep this step best-effort but say why.
            print('failed to start video_blend.py: %s' % err)
        else:
            if result.returncode != 0:
                print('video_blend.py exited with status %d' % result.returncode)

    print('\n' + '=' * 100)
    print('Done')
if __name__ == '__main__':
    # Script entry point: load the YAML configuration, translate the
    # keyframes, then print/run the full-video blending command.
    parser = argparse.ArgumentParser()
    # nargs='?' makes the positional optional.  Without it argparse ignores
    # `default` and always demands the argument on the command line.
    parser.add_argument('config_path', type=str, nargs='?',
                        default='./config/config_carturn.yaml',
                        help='The configuration file.')
    opt = parser.parse_args()

    print('=' * 100)
    print('loading configuration...')
    # Read as UTF-8 explicitly so non-ASCII prompts do not depend on the
    # platform's default encoding.
    with open(opt.config_path, "r", encoding='utf-8') as f:
        config = yaml.safe_load(f)

    for name, value in sorted(config.items()):
        print('%s: %s' % (str(name), str(value)))

    keys = run_keyframe_translation(config)
    run_full_video_translation(config, keys)