# File size: 7,516 Bytes (revision 37aeb5b)
import torch
from accelerate import Accelerator
from accelerate.logging import MultiProcessAdapter
from dataclasses import dataclass, field
from typing import Optional, Union
from datasets import load_dataset
import json
import abc
from diffusers.utils import make_image_grid
import numpy as np
import wandb
from custum_3d_diffusion.trainings.utils import load_config
from custum_3d_diffusion.custum_modules.unifield_processor import ConfigurableUNet2DConditionModel, AttnConfig
class BasicTrainer(torch.nn.Module, abc.ABC):
    """Abstract base class for one trainer operating on a shared UNet.

    The base class loads the trainer and attention configs, decides when
    training is finished, and implements validation logging, model saving
    and debug dumps. Subclasses provide the dataset, the training step and
    the inference pipeline by implementing the abstract methods.
    """

    accelerator: Accelerator
    logger: MultiProcessAdapter
    unet: ConfigurableUNet2DConditionModel
    train_dataloader: torch.utils.data.DataLoader
    test_dataset: torch.utils.data.Dataset
    attn_config: AttnConfig

    @dataclass
    class TrainerConfig:
        """Options read by the trainer; loaded through ``load_config``."""

        trainer_name: str = "basic"
        pretrained_model_name_or_path: str = ""

        attn_config: dict = field(default_factory=dict)

        dataset_name: str = ""
        dataset_config_name: Optional[str] = None
        # JSON-encoded: a single int ("1024") or a two-element list; see get_HW().
        resolution: str = "1024"
        dataloader_num_workers: int = 4
        pair_sampler_group_size: int = 1
        num_views: int = 4

        max_train_steps: int = -1  # -1 means infinity, otherwise [0, max_train_steps)
        training_step_interval: int = 1  # train on step i*interval, stop at max_train_steps
        max_train_samples: Optional[int] = None
        seed: Optional[int] = None  # For dataset related operations and validation stuff
        train_batch_size: int = 1

        validation_interval: int = 5000
        debug: bool = False

    cfg: TrainerConfig  # only enable_xxx is used

    def __init__(
        self,
        accelerator: Accelerator,
        logger: MultiProcessAdapter,
        unet: ConfigurableUNet2DConditionModel,
        config: Union[dict, str],
        weight_dtype: torch.dtype,
        index: int,
    ):
        """Store the shared objects, load the configs and run the setup hooks.

        Args:
            accelerator: accelerate ``Accelerator`` used for process checks
                and tracker access.
            logger: multi-process aware logger.
            unet: the UNet this trainer configures and trains.
            config: trainer configuration, passed to ``load_config`` together
                with ``TrainerConfig``.
            weight_dtype: stored on the instance as ``self.weight_dtype``.
            index: index of this trainer among all trainers.
        """
        super().__init__()
        self.index = index  # index in all trainers
        self.accelerator = accelerator
        self.logger = logger
        self.unet = unet
        self.weight_dtype = weight_dtype
        # Extra values flushed to wandb trackers by log_validation().
        self.ext_logs = {}
        self.cfg = load_config(self.TrainerConfig, config)
        self.attn_config = load_config(AttnConfig, self.cfg.attn_config)
        self.test_dataset = None
        self.validate_trainer_config()
        self.configure()

    def get_HW(self):
        """Return the configured resolution as an ``(H, W)`` pair.

        ``cfg.resolution`` is normally a JSON string holding either a single
        integer (square resolution) or a two-element list ``[H, W]``. A value
        that is already an int or a two-element list/tuple is accepted as is.

        Raises:
            ValueError: if the value is neither an int nor a two-element
                sequence (previously this surfaced as an unbound-variable
                error), or if the JSON string cannot be decoded.
        """
        resolution = self.cfg.resolution
        if isinstance(resolution, (str, bytes, bytearray)):
            resolution = json.loads(resolution)
        if isinstance(resolution, int):
            return resolution, resolution
        if isinstance(resolution, (list, tuple)) and len(resolution) == 2:
            H, W = resolution
            return H, W
        raise ValueError(
            "resolution must be an int or a two-element [H, W] list, "
            f"got {self.cfg.resolution!r}"
        )

    def unet_update(self):
        """Apply this trainer's attention config to the UNet."""
        self.unet.update_config(self.attn_config)

    def validate_trainer_config(self):
        """Hook called from ``__init__`` after configs load; no-op here."""
        pass

    def is_train_finished(self, current_step):
        """Return True once ``current_step`` reaches ``cfg.max_train_steps``.

        A ``max_train_steps`` of -1 means training never finishes.
        """
        assert isinstance(self.cfg.max_train_steps, int)
        return self.cfg.max_train_steps != -1 and current_step >= self.cfg.max_train_steps

    def next_train_step(self, current_step):
        """Return the next step to train on, or None when training is finished."""
        if self.is_train_finished(current_step):
            return None
        return current_step + self.cfg.training_step_interval

    @classmethod
    def make_image_into_grid(cls, all_imgs, rows=2, columns=2):
        """Tile images into ``rows`` x ``columns`` grids laid out side by side.

        Consecutive chunks of ``rows * columns`` images each become one grid;
        the resulting grids are then joined into a single one-row image.
        """
        per_grid = rows * columns
        catted = [
            make_image_grid(all_imgs[start:start + per_grid], rows=rows, cols=columns)
            for start in range(0, len(all_imgs), per_grid)
        ]
        return make_image_grid(catted, rows=1, cols=len(catted))

    def configure(self) -> None:
        """Hook called at the end of ``__init__``; no-op here."""
        pass

    @abc.abstractmethod
    def init_shared_modules(self, shared_modules: dict) -> dict:
        """Subclass hook taking the shared-modules dict and returning a dict."""
        pass

    def load_dataset(self):
        """Load the dataset named by ``cfg.dataset_name`` / ``cfg.dataset_config_name``.

        ``trust_remote_code=True`` is passed to ``datasets.load_dataset``.
        """
        dataset = load_dataset(
            self.cfg.dataset_name,
            self.cfg.dataset_config_name,
            trust_remote_code=True,
        )
        return dataset

    @abc.abstractmethod
    def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader:
        """Both init train_dataloader and test_dataset, but returns train_dataloader only"""
        pass

    @abc.abstractmethod
    def forward_step(
        self,
        *args,
        **kwargs
    ) -> torch.Tensor:
        """
        input a batch
        return a loss

        The base body only applies this trainer's attention config to the UNet.
        """
        self.unet_update()
        pass

    @abc.abstractmethod
    def construct_pipeline(self, shared_modules, unet):
        """Build the pipeline object used for validation and saving."""
        pass

    @abc.abstractmethod
    def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
        """
        For inference time forward.
        """
        pass

    @abc.abstractmethod
    def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
        """Run validation inference; ``do_validation`` unpacks the result as ``(titles, images)``."""
        pass

    def do_validation(
        self,
        shared_modules,
        unet,
        global_step,
    ):
        """Run validation inference and log the images to every tracker.

        Images go to tensorboard as one stacked NHWC array and to wandb as
        captioned JPEGs; any other tracker only triggers a warning.

        Returns:
            The validation images. When a wandb tracker is present, images
            whose title lacks ``'noresize'`` have been shrunk in place to fit
            512x512.
        """
        self.unet_update()
        self.logger.info("Running validation... ")
        pipeline = self.construct_pipeline(shared_modules, unet)
        pipeline.set_progress_bar_config(disable=True)
        titles, images = self.batched_validation_forward(pipeline, guidance_scale=[1., 3.])
        for tracker in self.accelerator.trackers:
            if tracker.name == "tensorboard":
                np_images = np.stack([np.asarray(img) for img in images])
                tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
            elif tracker.name == "wandb":
                # thumbnail() resizes in place, so run it as a plain loop.
                for image, title in zip(images, titles):
                    if 'noresize' not in title:
                        image.thumbnail((512, 512))
                tracker.log({"validation": [
                    wandb.Image(image, caption=f"{i}: {titles[i]}", file_type="jpg")
                    for i, image in enumerate(images)]})
            else:
                # Logger.warn is deprecated; warning() is the supported name.
                self.logger.warning(f"image logging not implemented for {tracker.name}")
        del pipeline
        torch.cuda.empty_cache()
        return images

    @torch.no_grad()
    def log_validation(
        self,
        shared_modules,
        unet,
        global_step,
        force=False
    ):
        """Flush ``ext_logs`` to wandb and run validation when it is due.

        Validation runs when ``global_step`` is a multiple of
        ``cfg.validation_interval`` and training is not finished, or whenever
        ``force`` is true. Only the main process logs and validates.
        """
        if self.accelerator.is_main_process:
            for tracker in self.accelerator.trackers:
                if tracker.name == "wandb":
                    tracker.log(self.ext_logs)
        # NOTE(review): reset on every process, not just the main one - confirm
        # against the upstream file, whose indentation was lost.
        self.ext_logs = {}
        validation_due = (
            global_step % self.cfg.validation_interval == 0
            and not self.is_train_finished(global_step)
        )
        if validation_due or force:
            self.unet_update()
            if self.accelerator.is_main_process:
                self.do_validation(shared_modules, self.accelerator.unwrap_model(unet), global_step)

    def save_model(self, unwrap_unet, shared_modules, save_dir):
        """On the main process, build the pipeline and save it to ``save_dir``."""
        if self.accelerator.is_main_process:
            pipeline = self.construct_pipeline(shared_modules, unwrap_unet)
            pipeline.save_pretrained(save_dir)
            self.logger.info(f"{self.cfg.trainer_name} Model saved at {save_dir}")

    def save_debug_info(self, save_name="debug", **kwargs):
        """Pickle ``kwargs`` to ``{save_name}.pkl`` when ``cfg.debug`` is set.

        Tensors are detached and moved to CPU first. If the target file
        already exists, the first free ``{save_name}_v{i}.pkl`` is used
        instead, so earlier dumps are never overwritten.
        """
        if not self.cfg.debug:
            return
        import os
        import pickle

        to_saves = {
            key: value.detach().cpu() if isinstance(value, torch.Tensor) else value
            for key, value in kwargs.items()
        }
        if os.path.exists(f"{save_name}.pkl"):
            # Unbounded search: a fixed cap would fall through and overwrite
            # the base file once every versioned name was taken.
            base_name = save_name
            version = 0
            while os.path.exists(f"{base_name}_v{version}.pkl"):
                version += 1
            save_name = f"{base_name}_v{version}"
        with open(f"{save_name}.pkl", "wb") as f:
            pickle.dump(to_saves, f)