RealFill-Training-UI / app_training.py
thuanz123's picture
Upload 18 files
ffa45e8
#!/usr/bin/env python
from __future__ import annotations
import os
import gradio as gr
from constants import UploadTarget
from inference import InferencePipeline
from trainer import Trainer
# Help texts rendered as Markdown inside the training UI. They are kept at
# module level so the layout code below stays readable.
_TRAINING_DATA_HELP = '''
- Upload reference images of the scene you are planning on training on.
- For the target image, the inpainting region should be white.
- For the target mask, white for inpainting and black for keeping as is.
'''

_UPLOAD_SETTINGS_HELP = '''
- By default, trained models will be uploaded to [RealFill Library](https://huggingface.co/realfill-library).
- You can also choose "Personal Profile", in which case, the model will be uploaded to https://huggingface.co/{your_username}/{model_name}.
'''

_TRAINING_PARAMETERS_HELP = '''
- The base model must be a model that is compatible with [diffusers](https://github.com/huggingface/diffusers) library.
- It takes a few minutes to download the base model first.
- It will take about 16 minutes to train for 2000 steps with a T4 GPU.
- You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
- You can check the training status by pressing the "Open logs" button if you are running this on your Space.
- You need to set the environment variable `WANDB_API_KEY` if you'd like to use [W&B](https://wandb.ai/site). See [W&B documentation](https://docs.wandb.ai/guides/track/advanced/environment-variables).
- **Note:** Due to [this issue](https://github.com/huggingface/accelerate/issues/944), currently, training will not terminate properly if you use W&B.
'''


def create_training_demo(trainer: Trainer,
                         pipe: InferencePipeline | None = None) -> gr.Blocks:
    """Build the Gradio UI used to configure and start a training run.

    The UI collects the training data, the output model name, the Hub upload
    settings and the training hyper-parameters. Pressing "Start Training"
    calls ``trainer.run`` with the current component values, in the order
    listed in the ``inputs`` of the click handler below, and shows the
    returned value in the "Output message" Markdown component.

    Args:
        trainer: Object whose ``run`` method is bound to the start button.
        pipe: Optional inference pipeline. When given, its ``clear`` method
            is also bound to the start button (presumably to release the
            loaded inference model before training -- confirm against
            ``InferencePipeline``).

    Returns:
        The assembled ``gr.Blocks`` demo (not yet queued or launched).
    """
    with gr.Blocks() as demo:
        with gr.Row():
            # Left column: data, output model and upload settings.
            with gr.Column():
                with gr.Box():
                    gr.Markdown('Training Data')
                    reference_images = gr.Files(label='Reference images')
                    target_image = gr.Files(label='Target image')
                    target_mask = gr.Files(label='Target mask')
                    gr.Markdown(_TRAINING_DATA_HELP)
                with gr.Box():
                    gr.Markdown('Output Model')
                    output_model_name = gr.Text(label='Name of your model',
                                                max_lines=1)
                    delete_existing_model = gr.Checkbox(
                        label='Delete existing model of the same name',
                        value=False)
                with gr.Box():
                    gr.Markdown('Upload Settings')
                    with gr.Row():
                        upload_to_hub = gr.Checkbox(
                            label='Upload model to Hub', value=True)
                        use_private_repo = gr.Checkbox(label='Private',
                                                       value=True)
                        delete_existing_repo = gr.Checkbox(
                            label='Delete existing repo of the same name',
                            value=False)
                        upload_to = gr.Radio(
                            label='Upload to',
                            choices=[target.value for target in UploadTarget],
                            value=UploadTarget.REALFILL_LIBRARY.value)
                    gr.Markdown(_UPLOAD_SETTINGS_HELP)

            # Right side: training hyper-parameters.
            with gr.Box():
                gr.Markdown('Training Parameters')
                with gr.Row():
                    base_model = gr.Text(
                        label='Base Model',
                        value='stabilityai/stable-diffusion-2-inpainting',
                        max_lines=1)
                    resolution = gr.Dropdown(choices=['512', '768'],
                                             value='512',
                                             label='Resolution')
                num_training_steps = gr.Number(
                    label='Number of Training Steps', value=2000, precision=0)
                unet_learning_rate = gr.Number(label='Unet Learning Rate',
                                               value=0.0002)
                text_encoder_learning_rate = gr.Number(
                    label='Text Encoder Learning Rate', value=0.00004)
                lora_rank = gr.Number(label='LoRA rank value',
                                      value=8,
                                      precision=0)
                lora_dropout = gr.Number(label='LoRA dropout rate', value=0.1)
                lora_alpha = gr.Number(label='LoRA alpha value',
                                       value=16,
                                       precision=0)
                gradient_accumulation = gr.Number(
                    label='Number of Gradient Accumulation',
                    value=1,
                    precision=0)
                seed = gr.Slider(label='Seed',
                                 minimum=0,
                                 maximum=100000,
                                 step=1,
                                 value=0)
                fp16 = gr.Checkbox(label='FP16', value=True)
                use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=True)
                checkpointing_steps = gr.Number(label='Checkpointing Steps',
                                                value=100,
                                                precision=0)
                # The W&B toggle is only editable when an API key is set in
                # the environment.
                use_wandb = gr.Checkbox(label='Use W&B',
                                        value=False,
                                        interactive=bool(
                                            os.getenv('WANDB_API_KEY')))
                validation_steps = gr.Number(label='Validation Steps',
                                             value=100,
                                             precision=0)
                gr.Markdown(_TRAINING_PARAMETERS_HELP)

        # Hidden checkbox: it is never shown, but its value (False by
        # default) is still passed to ``trainer.run`` as the last input.
        remove_gpu_after_training = gr.Checkbox(
            label='Remove GPU after training',
            value=False,
            interactive=bool(os.getenv('SPACE_ID')),
            visible=False)
        run_button = gr.Button('Start Training')

        with gr.Box():
            gr.Markdown('Output message')
            output_message = gr.Markdown()

        if pipe is not None:
            run_button.click(fn=pipe.clear)
        # The order of ``inputs`` must match the positional parameters of
        # ``trainer.run``; do not reorder without updating the trainer.
        run_button.click(fn=trainer.run,
                         inputs=[
                             reference_images,
                             target_image,
                             target_mask,
                             output_model_name,
                             delete_existing_model,
                             base_model,
                             resolution,
                             num_training_steps,
                             unet_learning_rate,
                             text_encoder_learning_rate,
                             lora_rank,
                             lora_dropout,
                             lora_alpha,
                             gradient_accumulation,
                             seed,
                             fp16,
                             use_8bit_adam,
                             checkpointing_steps,
                             use_wandb,
                             validation_steps,
                             upload_to_hub,
                             use_private_repo,
                             delete_existing_repo,
                             upload_to,
                             remove_gpu_after_training,
                         ],
                         outputs=output_message)
    return demo
if __name__ == '__main__':
    # Stand-alone entry point: build the training UI without an inference
    # pipeline and serve it locally (no public share link).
    trainer = Trainer(os.environ.get('HF_TOKEN'))
    demo = create_training_demo(trainer)
    # Configure the request queue with max_size=1, then start the server.
    demo.queue(max_size=1)
    demo.launch(share=False)