# RealFill-Training-UI / trainer.py
# Last updated by thuanz123 in commit 1c5be4e ("Update trainer.py").
from __future__ import annotations
import datetime
import os
import pathlib
import shlex
import shutil
import subprocess
import gradio as gr
import slugify
import torch
from PIL import Image
from huggingface_hub import HfApi
from app_upload import ModelUploader
from utils import save_model_card
# Share link for the `realfill-library` Hugging Face organization. Trainer's
# join_library_org() sends an authenticated POST to this URL before training
# when uploading to the Hub. The trailing path segment appears to be an
# invitation token, so treat this value as semi-secret.
URL_TO_JOIN_LIBRARY_ORG = 'https://huggingface.co/organizations/realfill-library/share/WctmaLvDHWxnuWoJxagTrzVXbGwxoqoJoG'
class Trainer:
    """Drive RealFill LoRA training from the Gradio UI.

    Writes the uploaded images to an on-disk dataset and launches
    ``train_realfill.py`` as a subprocess. Afterwards it writes a model card
    and can optionally upload the result to the Hugging Face Hub.
    """

    def __init__(self, hf_token: str | None = None):
        self.hf_token = hf_token
        self.api = HfApi(token=hf_token)
        self.model_uploader = ModelUploader(hf_token)

    def prepare_dataset(self, reference_images: list,
                        target_image: list, target_mask: list,
                        train_data_dir: pathlib.Path,
                        output_dir: pathlib.Path) -> None:
        """Write the training images to disk.

        Produces::

            train_data_dir/ref/000.jpg, 001.jpg, ...   reference images (RGB)
            train_data_dir/target/target.jpg            target image (RGB)
            train_data_dir/target/mask.jpg              target mask (grayscale)
            output_dir/target.jpg, output_dir/mask.jpg  copies for the model card

        Args:
            reference_images: Uploaded files. Each item must expose a ``.name``
                attribute holding a filesystem path (presumably Gradio upload
                wrappers; confirm against the UI).
            target_image: Sequence of uploaded files. Only the first is used.
            target_mask: Sequence of uploaded files. Only the first is used.
            train_data_dir: Deleted and recreated from scratch.
            output_dir: Must already exist.
        """
        shutil.rmtree(train_data_dir, ignore_errors=True)
        ref_dir = train_data_dir / 'ref'
        target_dir = train_data_dir / 'target'
        ref_dir.mkdir(parents=True)
        target_dir.mkdir(parents=True)

        # Open each image in a `with` block so the source file handle is closed;
        # convert() returns a fully loaded copy that outlives the handle.
        for i, upload in enumerate(reference_images):
            with Image.open(upload.name) as img:
                rgb = img.convert('RGB')
            rgb.save(ref_dir / f'{i:03d}.jpg', format='JPEG', quality=100)

        with Image.open(target_image[0].name) as img:
            target = img.convert('RGB')
        for out_path in (target_dir / 'target.jpg', output_dir / 'target.jpg'):
            target.save(out_path, format='JPEG', quality=100)

        # NOTE(review): JPEG compression turns hard mask edges into
        # intermediate gray values. The `mask.jpg` name is kept because the
        # training script presumably looks for exactly this file.
        with Image.open(target_mask[0].name) as img:
            mask = img.convert('L')
        for out_path in (target_dir / 'mask.jpg', output_dir / 'mask.jpg'):
            mask.save(out_path, format='JPEG', quality=100)

    def join_library_org(self) -> None:
        """Send the join request for the realfill-library org on the Hub.

        This is best effort. Failures are printed and never abort training,
        just as the original fire-and-forget ``curl`` call did. The request is
        made in-process so the token never appears on a subprocess command
        line, where other local users could read it via ``ps``.
        """
        import urllib.request

        if not self.hf_token:
            print('No Hugging Face token set; skipping joining the library org.')
            return
        request = urllib.request.Request(
            URL_TO_JOIN_LIBRARY_ORG,
            data=b'',
            method='POST',
            headers={
                'Authorization': f'Bearer {self.hf_token}',
                'Content-Type': 'application/json',
            },
        )
        try:
            with urllib.request.urlopen(request, timeout=30) as response:
                print(response.read().decode('utf-8', errors='replace'))
        except OSError as err:
            # URLError, HTTPError and socket timeouts all derive from OSError.
            print(f'Failed to join the library org: {err}')

    def run(
        self,
        reference_images: list | None,
        target_image: list | None,
        target_mask: list | None,
        output_model_name: str,
        overwrite_existing_model: bool,
        base_model: str,
        resolution_s: str,
        n_steps: int,
        unet_learning_rate: float,
        text_encoder_learning_rate: float,
        lora_rank: int,
        lora_dropout: float,
        lora_alpha: int,
        gradient_accumulation: int,
        seed: int,
        fp16: bool,
        use_8bit_adam: bool,
        checkpointing_steps: int,
        use_wandb: bool,
        validation_steps: int,
        upload_to_hub: bool,
        use_private_repo: bool,
        delete_existing_repo: bool,
        upload_to: str,
        remove_gpu_after_training: bool,
    ) -> str:
        """Validate the UI inputs, train a RealFill LoRA and optionally upload it.

        The image arguments are sequences of uploaded files (see
        ``prepare_dataset``). The hyperparameters are forwarded verbatim as
        ``train_realfill.py`` flags. ``resolution_s`` is parsed with ``int``.

        Results go to ``experiments/<slugified output_model_name>`` next to
        this file. A timestamped name is used when the given name is empty or
        slugifies to nothing.

        Returns:
            A status message. When ``upload_to_hub`` is set, the uploader's
            message is appended.

        Raises:
            gr.Error: CUDA is missing, an upload is missing, the output
                directory already exists without overwrite, or the training
                subprocess exits with a non-zero status.
        """
        if not torch.cuda.is_available():
            raise gr.Error('CUDA is not available.')
        if not reference_images:
            raise gr.Error('You need to upload reference images.')
        if not target_image:
            raise gr.Error('You need to upload a target image.')
        if not target_mask:
            raise gr.Error('You need to upload a target mask.')

        resolution = int(resolution_s)

        # slugify() can return '' (e.g. for a name made only of punctuation).
        # An empty name would point output_dir/train_data_dir at the shared
        # parent directories, which are rmtree'd below, so fall back to a
        # timestamp (already slug-safe) instead.
        output_model_name = (slugify.slugify(output_model_name)
                             if output_model_name else '')
        if not output_model_name:
            timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
            output_model_name = f'realfill-{timestamp}'

        repo_dir = pathlib.Path(__file__).parent
        output_dir = repo_dir / 'experiments' / output_model_name
        if overwrite_existing_model or upload_to_hub:
            shutil.rmtree(output_dir, ignore_errors=True)
        elif output_dir.exists():
            # Report a UI-visible error instead of a bare FileExistsError
            # from mkdir() below.
            raise gr.Error(
                f'The model "{output_model_name}" already exists. Choose a '
                'different name or enable overwriting the existing model.')
        output_dir.mkdir(parents=True)

        train_data_dir = repo_dir / 'training_data' / output_model_name
        self.prepare_dataset(reference_images, target_image, target_mask,
                             train_data_dir, output_dir)

        if upload_to_hub:
            self.join_library_org()

        # Build argv directly (no string splitting) so paths and model ids
        # containing spaces or quotes are passed through intact.
        command = [
            'python', 'train_realfill.py',
            f'--pretrained_model_name_or_path={base_model}',
            f'--train_data_dir={train_data_dir}',
            f'--output_dir={output_dir}',
            f'--resolution={resolution}',
            '--train_batch_size=16',
            f'--gradient_accumulation_steps={gradient_accumulation}',
            '--gradient_checkpointing',
            f'--unet_learning_rate={unet_learning_rate}',
            f'--text_encoder_learning_rate={text_encoder_learning_rate}',
            '--lr_scheduler=constant',
            '--lr_warmup_steps=100',
            '--set_grads_to_none',
            f'--max_train_steps={n_steps}',
            f'--checkpointing_steps={checkpointing_steps}',
            f'--validation_steps={validation_steps}',
            f'--lora_rank={lora_rank}',
            f'--lora_dropout={lora_dropout}',
            f'--lora_alpha={lora_alpha}',
            f'--seed={seed}',
        ]
        if fp16:
            command += ['--mixed_precision', 'fp16']
        if use_8bit_adam:
            command.append('--use_8bit_adam')
        if use_wandb:
            command += ['--report_to', 'wandb']

        # Record the exact invocation, shell-quoted, for reproducibility.
        (output_dir / 'train.sh').write_text(shlex.join(command))

        completed = subprocess.run(command)
        if completed.returncode != 0:
            # Stop here so a failed run is neither reported as completed nor
            # uploaded to the Hub.
            raise gr.Error('Training failed: train_realfill.py exited with '
                           f'status {completed.returncode}.')

        save_model_card(save_dir=output_dir,
                        base_model=base_model,
                        target_image=output_dir / 'target.jpg',
                        target_mask=output_dir / 'mask.jpg')

        message = 'Training completed!'
        print(message)

        if upload_to_hub:
            upload_message = self.model_uploader.upload_model(
                folder_path=output_dir.as_posix(),
                repo_name=output_model_name,
                upload_to=upload_to,
                private=use_private_repo,
                delete_existing_repo=delete_existing_repo)
            print(upload_message)
            message = message + '\n' + upload_message

        if remove_gpu_after_training:
            # SPACE_ID is only set when running inside a Hugging Face Space.
            space_id = os.getenv('SPACE_ID')
            if space_id:
                self.api.request_space_hardware(repo_id=space_id,
                                                hardware='cpu-basic')

        return message