from model.unet import ScaleAt
from model.latentnet import *
from diffusion.resample import UniformSampler
from diffusion.diffusion import space_timesteps
from typing import Tuple
from dataclasses import dataclass

from torch import distributed
from torch.utils.data import DataLoader

from config_base import BaseConfig
from diffusion import *
from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule
from model import *
from choices import *
from multiprocessing import get_context
import os
from dataset_util import *
from torch.utils.data.distributed import DistributedSampler
from dataset import LatentDataLoader


@dataclass
class PretrainConfig(BaseConfig):
    name: str
    path: str


@dataclass
class TrainConfig(BaseConfig):
    # random seed
    seed: int = 0
    train_mode: TrainMode = TrainMode.diffusion
    train_cond0_prob: float = 0
    train_pred_xstart_detach: bool = True
    train_interpolate_prob: float = 0
    train_interpolate_img: bool = False
    manipulate_mode: ManipulateMode = ManipulateMode.celebahq_all
    manipulate_cls: str = None
    manipulate_shots: int = None
    manipulate_loss: ManipulateLossType = ManipulateLossType.bce
    manipulate_znormalize: bool = False
    manipulate_seed: int = 0
    accum_batches: int = 1
    autoenc_mid_attn: bool = True
    batch_size: int = 16
    batch_size_eval: int = None
    beatgans_gen_type: GenerativeType = GenerativeType.ddim
    beatgans_loss_type: LossType = LossType.mse
    beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
    beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
    beatgans_rescale_timesteps: bool = False
    latent_infer_path: str = None
    latent_znormalize: bool = False
    latent_gen_type: GenerativeType = GenerativeType.ddim
    latent_loss_type: LossType = LossType.mse
    latent_model_mean_type: ModelMeanType = ModelMeanType.eps
    latent_model_var_type: ModelVarType = ModelVarType.fixed_large
    latent_rescale_timesteps: bool = False
    latent_T_eval: int = 1_000
    latent_clip_sample: bool = False
    latent_beta_scheduler: str = 'linear'
    beta_scheduler: str = 'linear'
    data_name: str = ''
    data_val_name: str = None
    diffusion_type: str = None
    dropout: float = 0.1
    ema_decay: float = 0.9999
    eval_num_images: int = 5_000
    eval_every_samples: int = 200_000
    eval_ema_every_samples: int = 200_000
    fid_use_torch: bool = True
    fp16: bool = False
    grad_clip: float = 1
    img_size: int = 64
    lr: float = 0.0001
    optimizer: OptimizerType = OptimizerType.adam
    weight_decay: float = 0
    model_conf: ModelConfig = None
    model_name: ModelName = None
    model_type: ModelType = None
    net_attn: Tuple[int] = None
    net_beatgans_attn_head: int = 1
    # not necessarily the same as the number of style channels
    net_beatgans_embed_channels: int = 512
    net_resblock_updown: bool = True
    net_enc_use_time: bool = False
    net_enc_pool: str = 'adaptivenonzero'
    net_beatgans_gradient_checkpoint: bool = False
    net_beatgans_resnet_two_cond: bool = False
    net_beatgans_resnet_use_zero_module: bool = True
    net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm
    net_beatgans_resnet_cond_channels: int = None
    net_ch_mult: Tuple[int] = None
    net_ch: int = 64
    net_enc_attn: Tuple[int] = None
    net_enc_k: int = None
    # number of resblocks for the encoder (half-unet)
    net_enc_num_res_blocks: int = 2
    net_enc_channel_mult: Tuple[int] = None
    net_enc_grad_checkpoint: bool = False
    net_autoenc_stochastic: bool = False
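    # The net_latent_* fields below configure the optional latent net
    # (MLPSkipNetConfig, see make_model_conf); they are ignored when
    # net_latent_net_type == LatentNetType.none.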
    net_latent_activation: Activation = Activation.silu
    net_latent_channel_mult: Tuple[int] = (1, 2, 4)
    net_latent_condition_bias: float = 0
    net_latent_dropout: float = 0
    net_latent_layers: int = None
    net_latent_net_last_act: Activation = Activation.none
    net_latent_net_type: LatentNetType = LatentNetType.none
    net_latent_num_hid_channels: int = 1024
    net_latent_num_time_layers: int = 2
    net_latent_skip_layers: Tuple[int] = None
    net_latent_time_emb_channels: int = 64
    net_latent_use_norm: bool = False
    net_latent_time_last_act: bool = False
    net_num_res_blocks: int = 2
    # number of resblocks for the UNET
    net_num_input_res_blocks: int = None
    net_enc_num_cls: int = None
    num_workers: int = 4
    parallel: bool = False
    postfix: str = ''
    sample_size: int = 64
    sample_every_samples: int = 20_000
    save_every_samples: int = 100_000
    style_ch: int = 512
    T_eval: int = 1_000
    T_sampler: str = 'uniform'
    T: int = 1_000
    total_samples: int = 10_000_000
    warmup: int = 0
    pretrain: PretrainConfig = None
    continue_from: PretrainConfig = None
    eval_programs: Tuple[str] = None
    # if present load the checkpoint from this path instead
    eval_path: str = None
    base_dir: str = 'checkpoints'
    use_cache_dataset: bool = False
    data_cache_dir: str = os.path.expanduser('~/cache')
    work_cache_dir: str = os.path.expanduser('~/mycache')
    # to be overridden
    name: str = ''

    def __post_init__(self):
        self.batch_size_eval = self.batch_size_eval or self.batch_size
        self.data_val_name = self.data_val_name or self.data_name

    def scale_up_gpus(self, num_gpus, num_nodes=1):
        self.eval_ema_every_samples *= num_gpus * num_nodes
        self.eval_every_samples *= num_gpus * num_nodes
        self.sample_every_samples *= num_gpus * num_nodes
        self.batch_size *= num_gpus * num_nodes
        self.batch_size_eval *= num_gpus * num_nodes
        return self

    @property
    def batch_size_effective(self):
        return self.batch_size * self.accum_batches

    @property
    def fid_cache(self):
        # we try to use the local dirs to reduce the load over network drives
        # hopefully, this would reduce the disconnection problems with sshfs
        return f'{self.work_cache_dir}/eval_images/{self.data_name}_size{self.img_size}_{self.eval_num_images}'

    @property
    def data_path(self):
        # may use the cache dir
        path = data_paths[self.data_name]
        if self.use_cache_dataset and path is not None:
            path = use_cached_dataset_path(
                path, f'{self.data_cache_dir}/{self.data_name}')
        return path

    @property
    def logdir(self):
        return f'{self.base_dir}/{self.name}'

    @property
    def generate_dir(self):
        # we try to use the local dirs to reduce the load over network drives
        # hopefully, this would reduce the disconnection problems with sshfs
        return f'{self.work_cache_dir}/gen_images/{self.name}'

    def _make_diffusion_conf(self, T=None):
        if self.diffusion_type == 'beatgans':
            # can use T < self.T for evaluation
            # follows the guided-diffusion repo conventions
            # t's are evenly spaced
            if self.beatgans_gen_type == GenerativeType.ddpm:
                section_counts = [T]
            elif self.beatgans_gen_type == GenerativeType.ddim:
                section_counts = f'ddim{T}'
            else:
                raise NotImplementedError()

            return SpacedDiffusionBeatGansConfig(
                gen_type=self.beatgans_gen_type,
                model_type=self.model_type,
                betas=get_named_beta_schedule(self.beta_scheduler, self.T),
                model_mean_type=self.beatgans_model_mean_type,
                model_var_type=self.beatgans_model_var_type,
                loss_type=self.beatgans_loss_type,
                rescale_timesteps=self.beatgans_rescale_timesteps,
                use_timesteps=space_timesteps(num_timesteps=self.T,
                                              section_counts=section_counts),
                fp16=self.fp16,
            )
        else:
            raise NotImplementedError()
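
    # The latent (semantic) diffusion below reuses the same
    # SpacedDiffusionBeatGans machinery as the image diffusion above, but is
    # driven by the latent_* hyperparameters and always models the latent
    # codes with a plain DDPM.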

    def _make_latent_diffusion_conf(self, T=None):
        # can use T < self.T for evaluation
        # follows the guided-diffusion repo conventions
        # t's are evenly spaced
        if self.latent_gen_type == GenerativeType.ddpm:
            section_counts = [T]
        elif self.latent_gen_type == GenerativeType.ddim:
            section_counts = f'ddim{T}'
        else:
            raise NotImplementedError()

        return SpacedDiffusionBeatGansConfig(
            train_pred_xstart_detach=self.train_pred_xstart_detach,
            gen_type=self.latent_gen_type,
            # latent's model is always ddpm
            model_type=ModelType.ddpm,
            # latent shares the beta scheduler and full T
            betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T),
            model_mean_type=self.latent_model_mean_type,
            model_var_type=self.latent_model_var_type,
            loss_type=self.latent_loss_type,
            rescale_timesteps=self.latent_rescale_timesteps,
            use_timesteps=space_timesteps(num_timesteps=self.T,
                                          section_counts=section_counts),
            fp16=self.fp16,
        )

    @property
    def model_out_channels(self):
        return 3

    def make_T_sampler(self):
        if self.T_sampler == 'uniform':
            return UniformSampler(self.T)
        else:
            raise NotImplementedError()

    def make_diffusion_conf(self):
        return self._make_diffusion_conf(self.T)

    def make_eval_diffusion_conf(self):
        return self._make_diffusion_conf(T=self.T_eval)

    def make_latent_diffusion_conf(self):
        return self._make_latent_diffusion_conf(T=self.T)

    def make_latent_eval_diffusion_conf(self):
        # latent can have a different eval T
        return self._make_latent_diffusion_conf(T=self.latent_T_eval)

    def make_dataset(self, path=None, **kwargs):
        return LatentDataLoader(self.window_size,
                                self.frame_jpgs,
                                self.lmd_feats_prefix,
                                self.audio_prefix,
                                self.raw_audio_prefix,
                                self.motion_latents_prefix,
                                self.pose_prefix,
                                self.db_name,
                                audio_hz=self.audio_hz)

    def make_loader(self,
                    dataset,
                    shuffle: bool,
                    num_worker: int = None,
                    drop_last: bool = True,
                    batch_size: int = None,
                    parallel: bool = False):
        if parallel and distributed.is_initialized():
            # drop last to make sure that there are no added special indexes
            sampler = DistributedSampler(dataset,
                                         shuffle=shuffle,
                                         drop_last=True)
        else:
            sampler = None
        return DataLoader(
            dataset,
            batch_size=batch_size or self.batch_size,
            sampler=sampler,
            # with a sampler, shuffling is handled by the sampler instead of this option
            shuffle=False if sampler else shuffle,
            num_workers=num_worker or self.num_workers,
            pin_memory=True,
            drop_last=drop_last,
            multiprocessing_context=get_context('fork'),
        )

    def make_model_conf(self):
        if self.model_name == ModelName.beatgans_ddpm:
            self.model_type = ModelType.ddpm
            self.model_conf = BeatGANsUNetConfig(
                attention_resolutions=self.net_attn,
                channel_mult=self.net_ch_mult,
                conv_resample=True,
                dims=2,
                dropout=self.dropout,
                embed_channels=self.net_beatgans_embed_channels,
                image_size=self.img_size,
                in_channels=3,
                model_channels=self.net_ch,
                num_classes=None,
                num_head_channels=-1,
                num_heads_upsample=-1,
                num_heads=self.net_beatgans_attn_head,
                num_res_blocks=self.net_num_res_blocks,
                num_input_res_blocks=self.net_num_input_res_blocks,
                out_channels=self.model_out_channels,
                resblock_updown=self.net_resblock_updown,
                use_checkpoint=self.net_beatgans_gradient_checkpoint,
                use_new_attention_order=False,
                resnet_two_cond=self.net_beatgans_resnet_two_cond,
                resnet_use_zero_module=self.net_beatgans_resnet_use_zero_module,
            )
        elif self.model_name in [
                ModelName.beatgans_autoenc,
        ]:
            cls = BeatGANsAutoencConfig
            # supports both autoenc and vaeddpm
            if self.model_name == ModelName.beatgans_autoenc:
                self.model_type = ModelType.autoencoder
            else:
                raise NotImplementedError()

            if self.net_latent_net_type == LatentNetType.none:
                latent_net_conf = None
            elif self.net_latent_net_type == LatentNetType.skip:
                latent_net_conf = MLPSkipNetConfig(
                    num_channels=self.style_ch,
                    skip_layers=self.net_latent_skip_layers,
                    num_hid_channels=self.net_latent_num_hid_channels,
                    num_layers=self.net_latent_layers,
                    num_time_emb_channels=self.net_latent_time_emb_channels,
                    activation=self.net_latent_activation,
                    use_norm=self.net_latent_use_norm,
                    condition_bias=self.net_latent_condition_bias,
                    dropout=self.net_latent_dropout,
                    last_act=self.net_latent_net_last_act,
                    num_time_layers=self.net_latent_num_time_layers,
                    time_last_act=self.net_latent_time_last_act,
                )
            else:
                raise NotImplementedError()

            self.model_conf = cls(
                attention_resolutions=self.net_attn,
                channel_mult=self.net_ch_mult,
                conv_resample=True,
                dims=2,
                dropout=self.dropout,
                embed_channels=self.net_beatgans_embed_channels,
                enc_out_channels=self.style_ch,
                enc_pool=self.net_enc_pool,
                enc_num_res_block=self.net_enc_num_res_blocks,
                enc_channel_mult=self.net_enc_channel_mult,
                enc_grad_checkpoint=self.net_enc_grad_checkpoint,
                enc_attn_resolutions=self.net_enc_attn,
                image_size=self.img_size,
                in_channels=3,
                model_channels=self.net_ch,
                num_classes=None,
                num_head_channels=-1,
                num_heads_upsample=-1,
                num_heads=self.net_beatgans_attn_head,
                num_res_blocks=self.net_num_res_blocks,
                num_input_res_blocks=self.net_num_input_res_blocks,
                out_channels=self.model_out_channels,
                resblock_updown=self.net_resblock_updown,
                use_checkpoint=self.net_beatgans_gradient_checkpoint,
                use_new_attention_order=False,
                resnet_two_cond=self.net_beatgans_resnet_two_cond,
                resnet_use_zero_module=self.net_beatgans_resnet_use_zero_module,
                latent_net_conf=latent_net_conf,
                resnet_cond_channels=self.net_beatgans_resnet_cond_channels,
            )
        else:
            raise NotImplementedError(self.model_name)

        return self.model_conf
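

# ---------------------------------------------------------------------------
# Illustrative usage sketch (an addition for demonstration, not part of the
# original training pipeline).  It only exercises the config-building methods
# defined above; the experiment name and the concrete field values (channel
# multipliers, attention resolutions, eval T, ...) are arbitrary examples.
if __name__ == '__main__':
    conf = TrainConfig(
        name='example_ddpm_64',          # hypothetical experiment name
        diffusion_type='beatgans',
        model_name=ModelName.beatgans_ddpm,
        img_size=64,
        net_ch=64,
        net_ch_mult=(1, 2, 4, 8),
        net_attn=(16, ),
        T=1_000,
        T_eval=20,
        batch_size=8,
    )

    model_conf = conf.make_model_conf()          # BeatGANsUNetConfig for the UNet
    train_diff = conf.make_diffusion_conf()      # full-T training diffusion config
    eval_diff = conf.make_eval_diffusion_conf()  # DDIM config with T_eval spaced steps
    t_sampler = conf.make_T_sampler()            # UniformSampler over T timesteps

    print(conf.logdir)                # checkpoints/example_ddpm_64
    print(conf.batch_size_effective)  # batch_size * accum_batches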