# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os
import os.path as osp
from datetime import datetime
import torch
from easydict import EasyDict
cfg = EasyDict(__name__='Config: VideoLDM Decoder')
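# EasyDict lets the fields below be read either as attributes (cfg.max_frames)
# or as plain dict items (cfg['max_frames']).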
# ---------------------------work dir--------------------------
cfg.work_dir = 'workspace/'
# ---------------------------Global Variable-----------------------------------
cfg.resolution = [448, 256]
cfg.max_frames = 32
# -----------------------------------------------------------------------------
# ---------------------------Dataset Parameter---------------------------------
cfg.mean = [0.5, 0.5, 0.5]
cfg.std = [0.5, 0.5, 0.5]
cfg.max_words = 1000
# Placeholder
cfg.vit_out_dim = 1024
cfg.vit_resolution = [224, 224]
cfg.depth_clamp = 10.0
cfg.misc_size = 384
cfg.depth_std = 20.0
cfg.frame_lens = 32
cfg.sample_fps = 8
cfg.batch_sizes = 1
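# Note: with frame_lens = 32 frames sampled at sample_fps = 8, a training clip
# covers roughly 4 seconds of video (the exact behavior depends on the dataloader).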
# -----------------------------------------------------------------------------
# ---------------------------Mode Parameters-----------------------------------
# Diffusion
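# Cosine noise schedule over 1000 training timesteps. mean_type = 'v' presumably
# selects the v-prediction parameterization (the network predicts velocity rather
# than raw noise); sampling uses 50-step DDIM with eta = 0.0, i.e. deterministic.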
cfg.schedule = 'cosine'
cfg.num_timesteps = 1000
cfg.mean_type = 'v'
cfg.var_type = 'fixed_small'
cfg.loss_type = 'mse'
cfg.ddim_timesteps = 50
cfg.ddim_eta = 0.0
cfg.clamp = 1.0
cfg.share_noise = False
cfg.use_div_loss = False
cfg.noise_strength = 0.1
# classifier-free guidance
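# p_zero is presumably the probability of dropping the conditioning during training;
# at sampling time the usual classifier-free guidance combination is
#     out = uncond + guide_scale * (cond - uncond)
# with guide_scale = 3.0.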
cfg.p_zero = 0.1
cfg.guide_scale = 3.0
# clip vision encoder
cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073]
cfg.vit_std = [0.26862954, 0.26130258, 0.27577711]
# Model
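# scale_factor = 0.18215 is the usual Stable Diffusion latent scaling constant
# applied to the VAE latents before/after diffusion.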
cfg.scale_factor = 0.18215
cfg.use_fp16 = True
cfg.temporal_attention = True
cfg.decoder_bs = 8
cfg.UNet = {
    'type': 'Vid2VidSDUNet',
    'in_dim': 4,
    'dim': 320,
    'y_dim': cfg.vit_out_dim,
    'context_dim': 1024,
    'out_dim': 8 if cfg.var_type.startswith('learned') else 4,
    'dim_mult': [1, 2, 4, 4],
    'num_heads': 8,
    'head_dim': 64,
    'num_res_blocks': 2,
    'attn_scales': [1 / 1, 1 / 2, 1 / 4],
    'dropout': 0.1,
    'temporal_attention': cfg.temporal_attention,
    'temporal_attn_times': 1,
    'use_checkpoint': False,
    'use_fps_condition': False,
    'use_sim_mask': False,
    'num_tokens': 4,
    'default_fps': 8,
    'input_dim': 1024
}
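# With dim = 320 and dim_mult = [1, 2, 4, 4] above, the UNet channel widths are
# 320/640/1280/1280; attn_scales enables attention at the three largest feature
# resolutions, and out_dim stays 4 because var_type is 'fixed_small', not learned.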
cfg.guidances = []
# autoencoder from Stable Diffusion
cfg.auto_encoder = {
    'type': 'AutoencoderKL',
    'ddconfig': {
        'double_z': True,
        'z_channels': 4,
        'resolution': 256,
        'in_channels': 3,
        'out_ch': 3,
        'ch': 128,
        'ch_mult': [1, 2, 4, 4],
        'num_res_blocks': 2,
        'attn_resolutions': [],
        'dropout': 0.0
    },
    'embed_dim': 4,
    'pretrained': 'models/v2-1_512-ema-pruned.ckpt'
}
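# The ddconfig above matches the Stable Diffusion 2.1 VAE (4-channel latents,
# 8x spatial downsampling), so 448x256 frames are encoded to 56x32 latent maps.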
# clip embedder
cfg.embedder = {
    'type': 'FrozenOpenCLIPEmbedder',
    'layer': 'penultimate',
    'vit_resolution': [224, 224],
    'pretrained': 'open_clip_pytorch_model.bin'
}
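# Text conditioning uses the penultimate OpenCLIP transformer layer, as in
# Stable Diffusion 2.x; its 1024-dim output matches context_dim in cfg.UNet
# (the checkpoint is presumably an OpenCLIP ViT-H/14 model).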
# -----------------------------------------------------------------------------
# ---------------------------Training Settings---------------------------------
# training and optimizer
cfg.ema_decay = 0.9999
cfg.num_steps = 600000
cfg.lr = 5e-5
cfg.weight_decay = 0.0
cfg.betas = (0.9, 0.999)
cfg.eps = 1.0e-8
cfg.chunk_size = 16
cfg.alpha = 0.7
cfg.save_ckp_interval = 1000
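# betas / eps / weight_decay are Adam-style optimizer hyperparameters (the optimizer
# itself is built elsewhere); ema_decay = 0.9999 maintains an exponential moving
# average of the weights, typically the copy used for evaluation and checkpoints.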
# -----------------------------------------------------------------------------
# ----------------------------Pretrain Settings---------------------------------
# Default: load 2d pretrain
cfg.fix_weight = False
cfg.load_match = False
cfg.pretrained_checkpoint = 'v2-1_512-ema-pruned.ckpt'
cfg.pretrained_image_keys = 'stable_diffusion_image_key_temporal_attention_x1.json'
cfg.resume_checkpoint = 'img2video_ldm_0779000.pth'
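# pretrained_checkpoint holds the 2D Stable Diffusion 2.1 weights, mapped into the
# temporal-attention UNet via the JSON key file above; resume_checkpoint is
# presumably a previously trained video LDM checkpoint to resume or run inference from.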
# -----------------------------------------------------------------------------
# -----------------------------Visual-------------------------------------------
# Visual videos
cfg.viz_interval = 1000
cfg.visual_train = {
    'type': 'VisualVideoTextDuringTrain',
}
cfg.visual_inference = {
    'type': 'VisualGeneratedVideos',
}
cfg.inference_list_path = ''
# logging
cfg.log_interval = 100
# Default log_dir
cfg.log_dir = 'workspace/output_data'
# -----------------------------------------------------------------------------
# ---------------------------Others--------------------------------------------
# seed
cfg.seed = 8888
cfg.negative_prompt = 'painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, \
CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, \
signature, jpeg artifacts, deformed, lowres, over-smooth'
cfg.positive_prompt = 'Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, \
hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, \
skin pore detailing, hyper sharpness, perfect without deformations.'