|
import pytorch_lightning as pl |
|
from torch.utils.data import DataLoader |
|
from datasets.ytb_vos import YoutubeVOSDataset |
|
from datasets.ytb_vis import YoutubeVISDataset |
|
from datasets.saliency_modular import SaliencyDataset |
|
from datasets.vipseg import VIPSegDataset |
|
from datasets.mvimagenet import MVImageNetDataset |
|
from datasets.sam import SAMDataset |
|
from datasets.uvo import UVODataset |
|
from datasets.uvo_val import UVOValDataset |
|
from datasets.mose import MoseDataset |
|
from datasets.vitonhd import VitonHDDataset |
|
from datasets.fashiontryon import FashionTryonDataset |
|
from datasets.lvis import LvisDataset |
|
from cldm.logger import ImageLogger |
|
from cldm.model import create_model, load_state_dict |
|
from torch.utils.data import ConcatDataset |
|
from cldm.hack import disable_verbosity, enable_sliced_attention |
|
from omegaconf import OmegaConf |
|
|
|
# Memory/verbosity knobs applied before anything else touches the model.
# Sliced attention trades throughput for lower peak VRAM; off by default.
save_memory = False

disable_verbosity()  # silence verbose library logging
if save_memory:
    enable_sliced_attention()
|
|
|
|
|
# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
resume_path = 'path/to/weight'   # checkpoint loaded into the model below
batch_size = 16                  # per-GPU batch size
logger_freq = 1000               # ImageLogger batch frequency
learning_rate = 1e-5
sd_locked = False                # forwarded to model.sd_locked (semantics in cldm)
only_mid_control = False         # forwarded to model.only_mid_control
n_gpus = 2
accumulate_grad_batches = 1      # no gradient accumulation
|
|
|
|
|
# ---------------------------------------------------------------------------
# Model: build from config, restore the checkpoint, apply run settings.
# ---------------------------------------------------------------------------
model = create_model('./configs/anydoor.yaml').cpu()

# Restore weights on CPU; the Trainer moves the model to GPU at fit time.
state = load_state_dict(resume_path, location='cpu')
model.load_state_dict(state)

# Settings consumed by the LightningModule during training.
model.learning_rate = learning_rate
model.sd_locked = sd_locked
model.only_mid_control = only_mid_control
|
|
|
|
|
# All dataset constructors are configured from a single YAML file; each
# `DConf.Train.<Name>` node is expanded into constructor keyword arguments.
DConf = OmegaConf.load('./configs/datasets.yaml')

dataset1 = YoutubeVOSDataset(**DConf.Train.YoutubeVOS)        # video segmentation
dataset2 = SaliencyDataset(**DConf.Train.Saliency)            # still images
dataset3 = VIPSegDataset(**DConf.Train.VIPSeg)                # video panoptic seg
dataset4 = YoutubeVISDataset(**DConf.Train.YoutubeVIS)        # video instance seg
dataset5 = MVImageNetDataset(**DConf.Train.MVImageNet)        # multi-view 3D
dataset6 = SAMDataset(**DConf.Train.SAM)                      # still images
# UVO contributes two splits; NOTE(review): the val split (dataset9) is
# also used for TRAINING below — confirm this is intentional.
dataset7 = UVODataset(**DConf.Train.UVO.train)
dataset8 = VitonHDDataset(**DConf.Train.VitonHD)              # virtual try-on
dataset9 = UVOValDataset(**DConf.Train.UVO.val)
dataset10 = MoseDataset(**DConf.Train.Mose)                   # video segmentation
dataset11 = FashionTryonDataset(**DConf.Train.FashionTryon)   # virtual try-on
dataset12 = LvisDataset(**DConf.Train.Lvis)                   # still images
|
|
|
# Group the sources by modality (indices refer to the constructors above).
image_data = [dataset2, dataset6, dataset12]
video_data = [dataset1, dataset3, dataset4, dataset7, dataset9, dataset10]
tryon_data = [dataset8, dataset11]
threed_data = [dataset5]

# The video / try-on / 3D groups appear twice in the concatenation, which
# doubles their sampling weight relative to the image group — presumably
# deliberate oversampling; confirm before changing. The expansion below
# yields exactly the same dataset order as listing the groups twice.
dataset = ConcatDataset(image_data + 2 * (video_data + tryon_data + threed_data))
|
# DataLoader over the concatenated training set.
# NOTE(review): under the "ddp" strategy Lightning normally swaps shuffle=True
# for a DistributedSampler — verify against the pinned Lightning version.
dataloader = DataLoader(dataset, num_workers=8, batch_size=batch_size, shuffle=True)

# Callback that logs image samples every `logger_freq` training batches.
logger = ImageLogger(batch_frequency=logger_freq)

# NOTE(review): `gpus` and `progress_bar_refresh_rate` are legacy Trainer
# arguments (removed in pytorch-lightning >= 1.7 / 2.0), so this script
# targets an older 1.x release — confirm the pinned version before upgrading.
# precision=16 enables mixed-precision training across `n_gpus` GPUs.
trainer = pl.Trainer(gpus=n_gpus, strategy="ddp", precision=16, accelerator="gpu", callbacks=[logger], progress_bar_refresh_rate=1, accumulate_grad_batches=accumulate_grad_batches)

# Launch training on the combined dataset.
trainer.fit(model, dataloader)
|
|