import os
from pathlib import Path

from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPMPipeline, UNet2DModel, DDPMScheduler
from diffusers.optimization import get_cosine_schedule_with_warmup
from diffusers.utils import make_image_grid
from huggingface_hub import create_repo, upload_folder
from peft import LoraConfig, get_peft_model
import torch
import torch.nn.functional as F
from torchvision import transforms
from tqdm import tqdm

from config import TrainingConfig

"""
Ideas: run diffusion on simple images and explore subtly different x_T's and
what the output is. Denoise each x_T multiple times to get a better picture of
the distribution. Maybe use a fixed sequence of seeds for every denoising run
(torch.Generator(seed=__)). Inter-concept space. Consciousness.
"""


def evaluate(config, epoch, pipeline):
    # Sample some images from random noise (this is the backward diffusion process).
    # The default pipeline output type is `List[PIL.Image]`.
    images = pipeline(
        batch_size=config.eval_batch_size,
        generator=torch.manual_seed(config.seed),
        num_inference_steps=50
    ).images

    # Arrange the images in a grid (assumes eval_batch_size == 4).
    image_grid = make_image_grid(images, rows=2, cols=2)

    # Save the grid.
    test_dir = Path(config.output_dir) / 'samples'
    test_dir.mkdir(parents=True, exist_ok=True)
    image_grid.save(test_dir / f'{epoch:04d}.png')


def print_trainable_parameters(model):
    trainable_params = 0
    all_params = 0
    for _, param in model.named_parameters():
        all_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f'trainable params: {trainable_params} || all params: {all_params} '
        f'|| trainable%: {100 * trainable_params / all_params:.2f}'
    )


if __name__ == '__main__':
    config = TrainingConfig()
    config.dataset_name = 'keremberke/painting-style-classification'

    ds_dict = load_dataset(config.dataset_name, name='full')

    preprocess = transforms.Compose([
        transforms.Resize((config.image_size, config.image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    def transform(examples):
        return {
            'images': [preprocess(img.convert('RGB')) for img in examples['image']]
        }

    # Applies the preprocessing lazily, as samples are loaded.
    ds_dict.set_transform(transform)

    train_dataloader = torch.utils.data.DataLoader(
        ds_dict['train'], batch_size=config.train_batch_size, shuffle=True)
    valid_dataloader = torch.utils.data.DataLoader(
        ds_dict['validation'], batch_size=config.eval_batch_size, shuffle=False)
    test_dataloader = torch.utils.data.DataLoader(
        ds_dict['test'], batch_size=config.eval_batch_size, shuffle=False)

    # Alternative starting checkpoints, kept for reference:
    #
    # unet = UNet2DModel.from_pretrained('google/ddpm-celebahq-256').to('mps')
    # scheduler = DDPMScheduler.from_pretrained('google/ddpm-celebahq-256')
    #
    # unet = UNet2DModel.from_pretrained(
    #     'jmemon/ddpm-paintings-128-finetuned-celebahq', use_safetensors=True).to('mps')
    # scheduler = DDPMScheduler.from_pretrained('jmemon/ddpm-paintings-128-finetuned-celebahq')

    unet = UNet2DModel.from_pretrained(
        str(Path(__file__).parent / 'unet'),
        use_safetensors=True
    ).to('mps')
    scheduler = DDPMScheduler.from_pretrained(
        str(Path(__file__).parent / 'scheduler')
    )

    # LoRA adapters only on the attention key/value projections of the UNet.
    lora_config = LoraConfig(
        r=8,
        lora_alpha=8,
        target_modules=['to_k', 'to_v'],
        lora_dropout=0.1,
        bias='none')
    lora_unet = get_peft_model(unet, lora_config)
    print_trainable_parameters(lora_unet)
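    # Optional sanity check (a minimal sketch, not part of the original script):
    # after get_peft_model, only the injected LoRA A/B matrices should require
    # gradients, since target_modules covers just the key/value projections and
    # bias='none' leaves every bias frozen.
    assert all(
        'lora_' in name
        for name, param in lora_unet.named_parameters()
        if param.requires_grad
    ), 'expected only LoRA parameters to be trainable'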
    optimizer = torch.optim.AdamW(lora_unet.parameters(), lr=config.learning_rate)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps,
        num_training_steps=len(train_dataloader) * config.num_epochs
    )

    # Note: the model and data are placed on MPS manually; accelerator.prepare
    # is not used here, so the Accelerator mainly handles gradient accumulation
    # and logging.
    accelerator = Accelerator(
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        mixed_precision=config.mixed_precision,
        log_with='tensorboard',
        project_dir=str(Path(config.output_dir) / 'logs')
    )

    if accelerator.is_main_process:
        if config.push_to_hub:
            repo_id = create_repo(repo_id=config.hub_model_id, exist_ok=True).repo_id
        accelerator.init_trackers('ddpm-paintings-128-finetuned-celebahq')

    global_step = 0

    # Epoch numbering is offset to continue from a previous 6-epoch run.
    start_epoch = 6
    last_epoch = start_epoch + config.num_epochs - 1
    for epoch in range(start_epoch, start_epoch + config.num_epochs):
        pbar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        pbar.set_description(f'Epoch {epoch}')

        for idx, batch in enumerate(train_dataloader):
            clean_images = batch['images'].to('mps')
            noise = torch.randn(clean_images.shape, device=clean_images.device)
            bs = clean_images.shape[0]

            # Sample a random timestep per image and apply the forward process:
            # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.
            ts = torch.randint(
                0, scheduler.config.num_train_timesteps,
                (bs,), device=clean_images.device, dtype=torch.int64)
            noisy_images = scheduler.add_noise(clean_images, noise, ts)

            with accelerator.accumulate(lora_unet):
                # The UNet is trained to predict the added noise (epsilon objective).
                noise_pred = lora_unet(noisy_images, ts, return_dict=False)[0]
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                accelerator.clip_grad_norm_(lora_unet.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            logs = {'loss': loss.detach().item(), 'lr': lr_scheduler.get_last_lr()[0], 'step': global_step}
            pbar.update(1)
            pbar.set_postfix(loss=logs['loss'], step=idx + 1)
            accelerator.log(logs, step=global_step)
            global_step += 1

        pbar.close()

        if accelerator.is_main_process:
            # The PEFT-wrapped UNet can be sampled from directly; merging is only
            # needed when saving a plain UNet2DModel:
            # pipeline = DDPMPipeline(unet=accelerator.unwrap_model(lora_unet).merge_and_unload(), scheduler=scheduler)
            pipeline = DDPMPipeline(unet=accelerator.unwrap_model(lora_unet), scheduler=scheduler)

            if (epoch + 1) % config.save_image_epochs == 0 or epoch == last_epoch:
                # Save sample images from the model as trained at the end of this epoch.
                evaluate(config, epoch, pipeline)

            if (epoch + 1) % config.save_model_epochs == 0 or epoch == last_epoch:
                # Merge the LoRA weights into the base UNet so the saved pipeline
                # is a plain DDPMPipeline.
                _pipeline = DDPMPipeline(
                    unet=accelerator.unwrap_model(lora_unet).merge_and_unload(),
                    scheduler=scheduler)

                if config.push_to_hub:
                    # Read the Hugging Face token from the environment rather
                    # than hardcoding it in the script.
                    hf_token = os.environ['HF_TOKEN']
                    _pipeline.save_pretrained(
                        config.output_dir,
                        push_to_hub=True,
                        repo_id=repo_id,
                        token=hf_token
                    )
                    upload_folder(
                        repo_id=repo_id,
                        folder_path=config.output_dir,
                        commit_message=f'Epoch {epoch}',
                        ignore_patterns=['logs/*', '*/.DS_Store'],
                        token=hf_token
                    )
                else:
                    _pipeline.save_pretrained(config.output_dir)

                # save_pretrained (with or without push_to_hub) also writes the
                # pipeline locally, with the merged UNet in the 'unet' subfolder.
                model_loc = str(Path(config.output_dir) / 'unet')

                # Reload the merged UNet and attach a fresh LoRA adapter, then
                # rebuild the optimizer and LR schedule so they track the new
                # trainable parameters (the old ones still reference the previous
                # adapter's weights). Note this restarts warmup.
                unet = UNet2DModel.from_pretrained(model_loc, use_safetensors=True).to('mps')
                lora_unet = get_peft_model(unet, lora_config)
                optimizer = torch.optim.AdamW(lora_unet.parameters(), lr=config.learning_rate)
                lr_scheduler = get_cosine_schedule_with_warmup(
                    optimizer=optimizer,
                    num_warmup_steps=config.lr_warmup_steps,
                    num_training_steps=len(train_dataloader) * config.num_epochs)
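
# --- Hedged sketch (not part of the original training script) ---
# How a saved checkpoint could be reloaded for sampling. Assumes a pipeline was
# written by save_pretrained above; the checkpoint path, batch size, and grid
# shape are illustrative and should be adjusted to your run.
def sample_from_checkpoint(checkpoint_dir: str, seed: int = 0):
    # Load the full DDPM pipeline saved during training and move it to MPS.
    pipeline = DDPMPipeline.from_pretrained(checkpoint_dir).to('mps')
    # Denoise from fixed-seed Gaussian noise, mirroring evaluate() above.
    images = pipeline(
        batch_size=4,
        generator=torch.manual_seed(seed),
        num_inference_steps=50
    ).images
    return make_image_grid(images, rows=2, cols=2)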