"""Helpers for running LoRA DreamBooth training from a Gradio app."""
from __future__ import annotations

import os
import pathlib
import shlex
import shutil
import subprocess

import gradio as gr
import PIL.Image
import torch

# Make the bundled ``lora`` directory importable by the training subprocess.
# Entries are joined with ``os.pathsep`` and empty values are dropped, so an
# unset PYTHONPATH does not leave a trailing empty entry behind.
os.environ['PYTHONPATH'] = os.pathsep.join(
    entry for entry in ('lora', os.getenv('PYTHONPATH', '')) if entry)


def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
    """Pad ``image`` with black borders so that it becomes square.

    The shorter side is padded symmetrically (an odd leftover pixel goes to
    the bottom/right edge); the original pixels are not rescaled. A square
    image is returned unchanged (the same object).

    The padding colour is the RGB triple ``(0, 0, 0)``, so ``image`` should be
    in a multi-band mode such as ``RGB``. Convert single-band or palette
    images before calling this.
    """
    w, h = image.size
    if w == h:
        return image
    side = max(w, h)
    new_image = PIL.Image.new(image.mode, (side, side), (0, 0, 0))
    # One of the two offsets is always 0: only the shorter axis is padded.
    new_image.paste(image, ((side - w) // 2, (side - h) // 2))
    return new_image


class Trainer:
    """Run a LoRA DreamBooth training job as a subprocess for the Gradio UI.

    Uploaded concept images are written to ``training_data/``. The training
    script writes its outputs to ``results/``. Both paths are relative to the
    current working directory. Only one run is accepted at a time, tracked by
    the ``is_running`` flag.
    """

    def __init__(self):
        self.is_running = False
        self.is_running_message = 'Another training is in progress.'
        self.instance_data_dir = pathlib.Path('training_data')
        self.output_dir = pathlib.Path('results')

    def check_if_running(self) -> dict:
        """Return a Gradio update describing whether a training run is active."""
        if self.is_running:
            return gr.update(value=self.is_running_message)
        else:
            return gr.update(value='No training is running.')

    def cleanup_dirs(self) -> None:
        """Delete the dataset and output directories left by a previous run."""
        shutil.rmtree(self.instance_data_dir, ignore_errors=True)
        shutil.rmtree(self.output_dir, ignore_errors=True)

    def prepare_dataset(self, concept_images: list, resolution: int) -> None:
        """Write the uploaded images as square RGB JPEGs into the dataset dir.

        Args:
            concept_images: Uploaded files. Each item is either a path
                (``str`` / ``os.PathLike``) or an object exposing the file
                path via ``.name`` (e.g. a temporary-file wrapper).
            resolution: Side length, in pixels, of every saved image.

        Raises:
            FileExistsError: If ``instance_data_dir`` already exists. Call
                :meth:`cleanup_dirs` first.
        """
        self.instance_data_dir.mkdir()
        for i, temp_path in enumerate(concept_images):
            if isinstance(temp_path, (str, os.PathLike)):
                path = temp_path
            else:
                path = temp_path.name
            # ``PIL.Image.open`` is lazy and keeps the file open. The ``with``
            # closes it as soon as the pixels have been copied out.
            with PIL.Image.open(path) as source:
                # Convert *before* padding: the RGB fill used by ``pad_image``
                # is invalid for single-band modes (e.g. ``L``). Padding a
                # palette (``P``) image onto a fresh canvas would also drop
                # its palette and corrupt the colours.
                image = source.convert('RGB')
            image = pad_image(image)
            image = image.resize((resolution, resolution))
            out_path = self.instance_data_dir / f'{i:03d}.jpg'
            image.save(out_path, format='JPEG', quality=100)

    def _build_command(
        self,
        base_model: str,
        concept_prompt: str,
        resolution: int,
        n_steps: int,
        learning_rate: float,
    ) -> list[str]:
        """Return the argv list that launches the LoRA training script."""
        # Built as a list and passed to ``subprocess.run`` without a shell.
        # User text such as the prompt is forwarded verbatim, so quotes or
        # backslashes can neither break parsing nor inject extra flags.
        return [
            'accelerate',
            'launch',
            'lora/train_lora_dreambooth.py',
            f'--pretrained_model_name_or_path={base_model}',
            f'--instance_data_dir={self.instance_data_dir}',
            f'--output_dir={self.output_dir}',
            f'--instance_prompt=style of {concept_prompt}',
            f'--resolution={resolution}',
            '--train_batch_size=1',
            '--gradient_accumulation_steps=1',
            f'--learning_rate={learning_rate}',
            '--lr_scheduler=constant',
            '--lr_warmup_steps=0',
            f'--max_train_steps={n_steps}',
        ]

    def run(
        self,
        base_model: str,
        resolution_s: str,
        concept_images: list | None,
        concept_prompt: str,
        n_steps: int,
        learning_rate: float,
    ) -> tuple[dict, str]:
        """Train a LoRA on the uploaded concept images.

        Args:
            base_model: Forwarded as ``--pretrained_model_name_or_path``.
            resolution_s: Training resolution as a string (parsed with
                ``int``).
            concept_images: Uploaded image files; see :meth:`prepare_dataset`.
            concept_prompt: Concept text. The instance prompt becomes
                ``"style of <concept_prompt>"``.
            n_steps: Forwarded as ``--max_train_steps``.
            learning_rate: Forwarded as ``--learning_rate``.

        Returns:
            A ``(status_update, weight_path)`` pair:

            - ``status_update`` is a Gradio update carrying a status message.
            - ``weight_path`` is the POSIX path of
              ``results/lora_weight.pt`` when training exits with code 0 and
              that file exists, otherwise ``''``.

        Raises:
            gr.Error: If CUDA is unavailable, no images were uploaded, or the
                concept prompt is empty.
        """
        if not torch.cuda.is_available():
            raise gr.Error('CUDA is not available.')
        out_path = ''
        if self.is_running:
            return gr.update(value=self.is_running_message), out_path
        # ``not`` also rejects an empty upload list, which would otherwise
        # launch training on an empty dataset.
        if not concept_images:
            raise gr.Error('You need to upload images.')
        if not concept_prompt:
            raise gr.Error('The concept prompt is missing.')
        resolution = int(resolution_s)

        # Claim the trainer *before* touching the shared directories. A
        # request arriving during dataset preparation is then turned away
        # instead of deleting this run's files. This is a plain flag, not a
        # lock. ``finally`` releases it even if preparation or the launch
        # raises (e.g. ``accelerate`` not installed). Without it the trainer
        # would stay "running" forever.
        self.is_running = True
        try:
            self.cleanup_dirs()
            self.prepare_dataset(concept_images, resolution)
            command = self._build_command(base_model, concept_prompt,
                                          resolution, n_steps, learning_rate)
            res = subprocess.run(command)
        finally:
            self.is_running = False

        if res.returncode == 0:
            result_message = 'Training Completed!'
            weight_path = self.output_dir / 'lora_weight.pt'
            if weight_path.exists():
                out_path = weight_path.as_posix()
        else:
            result_message = 'Training Failed!'
        return gr.update(value=result_message), out_path