from __future__ import annotations import datetime import os import pathlib import shlex import shutil import subprocess import gradio as gr import slugify import torch from PIL import Image from huggingface_hub import HfApi from app_upload import ModelUploader from utils import save_model_card URL_TO_JOIN_LIBRARY_ORG = 'https://huggingface.co/organizations/realfill-library/share/WctmaLvDHWxnuWoJxagTrzVXbGwxoqoJoG' class Trainer: def __init__(self, hf_token: str | None = None): self.hf_token = hf_token self.api = HfApi(token=hf_token) self.model_uploader = ModelUploader(hf_token) def prepare_dataset(self, reference_images: list, target_image: Image.Image, target_mask: Image.Image, train_data_dir: pathlib.Path, output_dir: pathlib.Path) -> None: shutil.rmtree(train_data_dir, ignore_errors=True) train_data_dir.mkdir(parents=True) (train_data_dir / 'ref').mkdir(parents=True) (train_data_dir / 'target').mkdir(parents=True) for i, temp_path in enumerate(reference_images): image = Image.open(temp_path.name) image = image.convert('RGB') out_path = train_data_dir / 'ref' / f'{i:03d}.jpg' image.save(out_path, format='JPEG', quality=100) target_image = Image.open(target_image[0].name) target_image = target_image.convert('RGB') out_path = train_data_dir / 'target' / f'target.jpg' target_image.save(out_path, format='JPEG', quality=100) out_path = output_dir / f'target.jpg' target_image.save(out_path, format='JPEG', quality=100) target_mask = Image.open(target_mask[0].name) target_mask = target_mask.convert('L') out_path = train_data_dir / 'target' / f'mask.jpg' target_mask.save(out_path, format='JPEG', quality=100) out_path = output_dir / f'mask.jpg' target_mask.save(out_path, format='JPEG', quality=100) def join_library_org(self) -> None: subprocess.run( shlex.split( f'curl -X POST -H "Authorization: Bearer {self.hf_token}" -H "Content-Type: application/json" {URL_TO_JOIN_LIBRARY_ORG}' )) def run( self, reference_images: list | None, target_image: Image.Image | None, target_mask: Image.Image | None, output_model_name: str, overwrite_existing_model: bool, base_model: str, resolution_s: str, n_steps: int, unet_learning_rate: float, text_encoder_learning_rate: float, lora_rank: int, lora_dropout: float, lora_alpha: int, gradient_accumulation: int, seed: int, fp16: bool, use_8bit_adam: bool, checkpointing_steps: int, use_wandb: bool, validation_steps: int, upload_to_hub: bool, use_private_repo: bool, delete_existing_repo: bool, upload_to: str, remove_gpu_after_training: bool, ) -> str: if not torch.cuda.is_available(): raise gr.Error('CUDA is not available.') if reference_images is None: raise gr.Error('You need to upload reference images.') if target_image is None: raise gr.Error('The instance prompt is missing.') resolution = int(resolution_s) if not output_model_name: timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') output_model_name = f'realfill-{timestamp}' output_model_name = slugify.slugify(output_model_name) repo_dir = pathlib.Path(__file__).parent output_dir = repo_dir / 'experiments' / output_model_name if overwrite_existing_model or upload_to_hub: shutil.rmtree(output_dir, ignore_errors=True) output_dir.mkdir(parents=True) train_data_dir = repo_dir / 'training_data' / output_model_name self.prepare_dataset(reference_images, target_image, target_mask, train_data_dir, output_dir) if upload_to_hub: self.join_library_org() command = f''' python train_realfill.py \ --pretrained_model_name_or_path={base_model} \ --train_data_dir={train_data_dir} \ --output_dir={output_dir} \ --resolution={resolution} \ --train_batch_size=16 \ --gradient_accumulation_steps={gradient_accumulation} --gradient_checkpointing \ --unet_learning_rate={unet_learning_rate} \ --text_encoder_learning_rate={text_encoder_learning_rate} \ --lr_scheduler=constant \ --lr_warmup_steps=100 \ --set_grads_to_none \ --max_train_steps={n_steps} \ --checkpointing_steps={checkpointing_steps} \ --validation_steps={validation_steps} \ --lora_rank={lora_rank} \ --lora_dropout={lora_dropout} \ --lora_alpha={lora_alpha} \ --seed={seed} ''' if fp16: command += ' --mixed_precision fp16' if use_8bit_adam: command += ' --use_8bit_adam' if use_wandb: command += ' --report_to wandb' with open(output_dir / 'train.sh', 'w') as f: command_s = ' '.join(command.split()) f.write(command_s) subprocess.run(shlex.split(command)) save_model_card(save_dir=output_dir, base_model=base_model, target_image=output_dir / 'target.jpg', target_mask=output_dir / 'mask.jpg') message = 'Training completed!' print(message) if upload_to_hub: upload_message = self.model_uploader.upload_model( folder_path=output_dir.as_posix(), repo_name=output_model_name, upload_to=upload_to, private=use_private_repo, delete_existing_repo=delete_existing_repo) print(upload_message) message = message + '\n' + upload_message if remove_gpu_after_training: space_id = os.getenv('SPACE_ID') if space_id: self.api.request_space_hardware(repo_id=space_id, hardware='cpu-basic') return message