Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
from __future__ import annotations | |
import os | |
import gradio as gr | |
from constants import UploadTarget | |
from inference import InferencePipeline | |
from trainer import Trainer | |
def create_training_demo(trainer: Trainer,
                         pipe: InferencePipeline | None = None) -> gr.Blocks:
    """Build the Gradio UI used to configure and launch a training run.

    The UI collects the training data uploads, output-model options, Hub
    upload settings and training hyper-parameters, and wires the
    "Start Training" button to ``trainer.run``. The value returned by
    ``trainer.run`` is rendered in the "Output message" Markdown component.

    Args:
        trainer: Object whose ``run`` method is invoked on button click. It
            receives the component values positionally, in the order of the
            ``inputs`` list at the bottom of this function.
        pipe: Optional inference pipeline. When provided, its ``clear``
            method is also registered as a click handler on the same button.

    Returns:
        The assembled ``gr.Blocks`` app (not yet queued or launched).
    """
    # NOTE(review): the nesting of Row/Column/Box containers below was
    # reconstructed from statement order -- confirm against the intended
    # layout.
    # NOTE(review): `gr.Box` appears to be unavailable in Gradio 4.x
    # (`gr.Group` is the suggested replacement) -- confirm against the
    # Gradio version pinned for this Space before upgrading.
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                with gr.Box():
                    gr.Markdown('Training Data')
                    reference_images = gr.Files(label='Reference images')
                    target_image = gr.Files(label='Target image')
                    target_mask = gr.Files(label='Target mask')
                    gr.Markdown('''
                        - Upload reference images of the scene you are planning on training on.
                        - For the target image, the inpainting region should be white.
                        - For the target mask, white for inpainting and black for keeping as is.
                        ''')
                with gr.Box():
                    gr.Markdown('Output Model')
                    output_model_name = gr.Text(label='Name of your model',
                                                max_lines=1)
                    delete_existing_model = gr.Checkbox(
                        label='Delete existing model of the same name',
                        value=False)
                with gr.Box():
                    gr.Markdown('Upload Settings')
                    with gr.Row():
                        upload_to_hub = gr.Checkbox(
                            label='Upload model to Hub', value=True)
                        use_private_repo = gr.Checkbox(label='Private',
                                                       value=True)
                        delete_existing_repo = gr.Checkbox(
                            label='Delete existing repo of the same name',
                            value=False)
                        upload_to = gr.Radio(
                            label='Upload to',
                            choices=[_.value for _ in UploadTarget],
                            value=UploadTarget.REALFILL_LIBRARY.value)
                    gr.Markdown('''
                        - By default, trained models will be uploaded to [RealFill Library](https://huggingface.co/realfill-library).
                        - You can also choose "Personal Profile", in which case, the model will be uploaded to https://huggingface.co/{your_username}/{model_name}.
                        ''')

            with gr.Box():
                gr.Markdown('Training Parameters')
                with gr.Row():
                    base_model = gr.Text(
                        label='Base Model',
                        value='stabilityai/stable-diffusion-2-inpainting',
                        max_lines=1)
                    resolution = gr.Dropdown(choices=['512', '768'],
                                             value='512',
                                             label='Resolution')
                num_training_steps = gr.Number(
                    label='Number of Training Steps', value=2000, precision=0)
                unet_learning_rate = gr.Number(label='Unet Learning Rate',
                                               value=0.0002)
                text_encoder_learning_rate = gr.Number(
                    label='Text Encoder Learning Rate', value=0.00004)
                lora_rank = gr.Number(label='LoRA rank value',
                                      value=8,
                                      precision=0)
                lora_dropout = gr.Number(label='LoRA dropout rate', value=0.1)
                lora_alpha = gr.Number(label='LoRA alpha value',
                                       value=16,
                                       precision=0)
                gradient_accumulation = gr.Number(
                    label='Number of Gradient Accumulation',
                    value=1,
                    precision=0)
                seed = gr.Slider(label='Seed',
                                 minimum=0,
                                 maximum=100000,
                                 step=1,
                                 value=0)
                fp16 = gr.Checkbox(label='FP16', value=True)
                use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=True)
                checkpointing_steps = gr.Number(label='Checkpointing Steps',
                                                value=100,
                                                precision=0)
                # The W&B toggle is only editable when an API key is present
                # in the environment; otherwise it stays fixed at False.
                use_wandb = gr.Checkbox(label='Use W&B',
                                        value=False,
                                        interactive=bool(
                                            os.getenv('WANDB_API_KEY')))
                validation_steps = gr.Number(label='Validation Steps',
                                             value=100,
                                             precision=0)
                gr.Markdown('''
                    - The base model must be a model that is compatible with [diffusers](https://github.com/huggingface/diffusers) library.
                    - It takes a few minutes to download the base model first.
                    - It will take about 16 minutes to train for 2000 steps with a T4 GPU.
                    - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
                    - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
                    - You need to set the environment variable `WANDB_API_KEY` if you'd like to use [W&B](https://wandb.ai/site). See [W&B documentation](https://docs.wandb.ai/guides/track/advanced/environment-variables).
                    - **Note:** Due to [this issue](https://github.com/huggingface/accelerate/issues/944), currently, training will not terminate properly if you use W&B.
                    ''')

        # Hidden control: never shown in the UI, but its value (False by
        # default) is still forwarded to `trainer.run` as the last input.
        remove_gpu_after_training = gr.Checkbox(
            label='Remove GPU after training',
            value=False,
            interactive=bool(os.getenv('SPACE_ID')),
            visible=False)
        run_button = gr.Button('Start Training')

        with gr.Box():
            gr.Markdown('Output message')
            output_message = gr.Markdown()

        # Two independent click handlers are registered on the same button:
        # the optional `pipe.clear` first, then the training run itself.
        if pipe is not None:
            run_button.click(fn=pipe.clear)
        run_button.click(fn=trainer.run,
                         inputs=[
                             reference_images,
                             target_image,
                             target_mask,
                             output_model_name,
                             delete_existing_model,
                             base_model,
                             resolution,
                             num_training_steps,
                             unet_learning_rate,
                             text_encoder_learning_rate,
                             lora_rank,
                             lora_dropout,
                             lora_alpha,
                             gradient_accumulation,
                             seed,
                             fp16,
                             use_8bit_adam,
                             checkpointing_steps,
                             use_wandb,
                             validation_steps,
                             upload_to_hub,
                             use_private_repo,
                             delete_existing_repo,
                             upload_to,
                             remove_gpu_after_training,
                         ],
                         outputs=output_message)
    return demo
if __name__ == '__main__':
    # Standalone entry point: build the training UI without an inference
    # pipeline and serve it locally (no public share link).
    # HF_TOKEN may be unset, in which case Trainer receives None.
    trainer = Trainer(os.getenv('HF_TOKEN'))
    demo = create_training_demo(trainer)
    # Enable the request queue with max_size=1, then launch whatever
    # object queue() hands back, exactly as the chained call would.
    queued_demo = demo.queue(max_size=1)
    queued_demo.launch(share=False)