# kohya_ss / kohya_gui / class_basic_training.py
# Uploaded by zengxi123 using huggingface_hub (commit fb83c5b, verified)
from typing import Optional, Tuple, Union

import gradio as gr

from .custom_logging import setup_logging
# Set up logging: module-level logger obtained from the project's custom_logging helper.
# NOTE(review): `log` is not referenced by any code visible in this file section.
log = setup_logging()
class BasicTraining:
    """
    This class configures and initializes the basic training settings for a machine learning model,
    including options for SDXL, learning rate, learning rate scheduler, and training epochs.

    Attributes:
        sdxl_checkbox (gr.Checkbox): Checkbox to enable SDXL training.
        learning_rate_value (str | float): Initial learning rate value.
        lr_scheduler_value (str): Initial learning rate scheduler value.
        lr_warmup_value (str | float): Initial learning rate warmup value.
        finetuning (bool): If True, enables fine-tuning of the model.
        dreambooth (bool): If True, enables Dreambooth training.
        config (dict): Saved settings; keys of the form "basic.<field>" override
            the default value of the matching widget.
    """

    def __init__(
        self,
        sdxl_checkbox: gr.Checkbox,
        learning_rate_value: Union[str, float] = "1e-6",
        lr_scheduler_value: str = "constant",
        lr_warmup_value: Union[str, float] = "0",
        finetuning: bool = False,
        dreambooth: bool = False,
        config: Optional[dict] = None,
    ) -> None:
        """
        Initializes the BasicTraining object with the given parameters.

        Args:
            sdxl_checkbox (gr.Checkbox): Checkbox to enable SDXL training.
            learning_rate_value (str | float): Initial learning rate value.
            lr_scheduler_value (str): Initial learning rate scheduler value.
            lr_warmup_value (str | float): Initial learning rate warmup value.
            finetuning (bool): If True, enables fine-tuning of the model.
            dreambooth (bool): If True, enables Dreambooth training.
            config (dict, optional): Saved settings looked up with "basic.*" keys.
                Defaults to an empty dict (a fresh one per instance).
        """
        self.sdxl_checkbox = sdxl_checkbox
        self.learning_rate_value = learning_rate_value
        self.lr_scheduler_value = lr_scheduler_value
        self.lr_warmup_value = lr_warmup_value
        self.finetuning = finetuning
        self.dreambooth = dreambooth
        # Avoid a shared mutable default: each instance gets its own dict.
        self.config = config if config is not None else {}
        # Warmup value remembered while the "constant" scheduler forces it to 0,
        # so it can be restored when another scheduler is chosen.
        self.old_lr_warmup = 0

        # Initialize the UI components
        self.initialize_ui_components()

    def initialize_ui_components(self) -> None:
        """
        Initializes the UI components for the training settings.
        """
        # Initialize the training controls
        self.init_training_controls()
        # Initialize the precision and resources controls
        self.init_precision_and_resources_controls()
        # Initialize the learning rate and optimizer controls
        self.init_lr_and_optimizer_controls()
        # Initialize the gradient and learning rate controls
        self.init_grad_and_lr_controls()
        # Initialize the learning rate controls
        self.init_learning_rate_controls()
        # Initialize the scheduler controls
        self.init_scheduler_controls()
        # Initialize the resolution and bucket controls
        self.init_resolution_and_bucket_controls()
        # Setup the behavior of the SDXL checkbox
        self.setup_sdxl_checkbox_behavior()

    def init_training_controls(self) -> None:
        """
        Initializes the training controls for the model.
        """
        # Create a row for the training controls
        with gr.Row():
            # Initialize the train batch size slider. The configured value is the
            # slider's value; the increment is always 1.
            self.train_batch_size = gr.Slider(
                minimum=1,
                maximum=64,
                label="Train batch size",
                value=self.config.get("basic.train_batch_size", 1),
                step=1,
            )
            # Initialize the epoch number input
            self.epoch = gr.Number(
                label="Epoch", value=self.config.get("basic.epoch", 1), precision=0
            )
            # Initialize the maximum train epochs input
            self.max_train_epochs = gr.Number(
                label="Max train epoch",
                info="training epochs (overrides max_train_steps). 0 = no override",
                step=1,
                minimum=0,
                value=self.config.get("basic.max_train_epochs", 0),
            )
            # Initialize the maximum train steps input (0 means "no override",
            # so negative values are meaningless)
            self.max_train_steps = gr.Number(
                label="Max train steps",
                info="Overrides # training steps. 0 = no override",
                step=1,
                minimum=0,
                value=self.config.get("basic.max_train_steps", 1600),
            )
            # Initialize the save every N epochs input
            self.save_every_n_epochs = gr.Number(
                label="Save every N epochs",
                value=self.config.get("basic.save_every_n_epochs", 1),
                precision=0,
            )
            # Initialize the caption extension input (restored from config like
            # the other fields, defaulting to ".txt")
            self.caption_extension = gr.Dropdown(
                label="Caption file extension",
                choices=["", ".cap", ".caption", ".txt"],
                value=self.config.get("basic.caption_extension", ".txt"),
                interactive=True,
            )

    def init_precision_and_resources_controls(self) -> None:
        """
        Initializes the precision and resources controls for the model.
        """
        with gr.Row():
            # Initialize the seed number input
            self.seed = gr.Number(
                label="Seed",
                step=1,
                minimum=0,
                value=self.config.get("basic.seed", 0),
                info="Set to 0 to make random",
            )
            # Initialize the cache latents checkbox
            self.cache_latents = gr.Checkbox(
                label="Cache latents",
                value=self.config.get("basic.cache_latents", True),
            )
            # Initialize the cache latents to disk checkbox
            self.cache_latents_to_disk = gr.Checkbox(
                label="Cache latents to disk",
                value=self.config.get("basic.cache_latents_to_disk", False),
            )

    def init_lr_and_optimizer_controls(self) -> None:
        """
        Initializes the learning rate and optimizer controls for the model.
        """
        with gr.Row():
            # Initialize the learning rate scheduler dropdown
            self.lr_scheduler = gr.Dropdown(
                label="LR Scheduler",
                choices=[
                    "adafactor",
                    "constant",
                    "constant_with_warmup",
                    "cosine",
                    "cosine_with_restarts",
                    "linear",
                    "polynomial",
                ],
                value=self.config.get("basic.lr_scheduler", self.lr_scheduler_value),
            )
            # Initialize the optimizer dropdown
            self.optimizer = gr.Dropdown(
                label="Optimizer",
                choices=[
                    "AdamW",
                    "AdamW8bit",
                    "Adafactor",
                    "DAdaptation",
                    "DAdaptAdaGrad",
                    "DAdaptAdam",
                    "DAdaptAdan",
                    "DAdaptAdanIP",
                    "DAdaptAdamPreprint",
                    "DAdaptLion",
                    "DAdaptSGD",
                    "Lion",
                    "Lion8bit",
                    "PagedAdamW8bit",
                    "PagedAdamW32bit",
                    "PagedLion8bit",
                    "Prodigy",
                    "SGDNesterov",
                    "SGDNesterov8bit",
                ],
                value=self.config.get("basic.optimizer", "AdamW8bit"),
                interactive=True,
            )

    def init_grad_and_lr_controls(self) -> None:
        """
        Initializes the gradient and learning rate controls for the model.
        """
        with gr.Row():
            # Initialize the maximum gradient norm slider
            self.max_grad_norm = gr.Slider(
                label="Max grad norm",
                value=self.config.get("basic.max_grad_norm", 1.0),
                minimum=0.0,
                maximum=1.0,
                interactive=True,
            )
            # Initialize the learning rate scheduler extra arguments textbox
            self.lr_scheduler_args = gr.Textbox(
                label="LR scheduler extra arguments",
                lines=2,
                placeholder="(Optional) eg: milestones=[1,10,30,50] gamma=0.1",
                value=self.config.get("basic.lr_scheduler_args", ""),
            )
            # Initialize the optimizer extra arguments textbox
            self.optimizer_args = gr.Textbox(
                label="Optimizer extra arguments",
                lines=2,
                placeholder="(Optional) eg: relative_step=True scale_parameter=True warmup_init=True",
                value=self.config.get("basic.optimizer_args", ""),
            )

    def init_learning_rate_controls(self) -> None:
        """
        Initializes the learning rate controls for the model and wires the
        LR scheduler dropdown to enable/disable the warmup slider.
        """
        with gr.Row():
            # Finetuning/Dreambooth train the Unet and TE separately, so the main
            # field is labelled as the Unet learning rate in those modes.
            lr_label = (
                "Learning rate Unet"
                if self.finetuning or self.dreambooth
                else "Learning rate"
            )
            # Initialize the learning rate number input
            self.learning_rate = gr.Number(
                label=lr_label,
                value=self.config.get("basic.learning_rate", self.learning_rate_value),
                minimum=0,
                maximum=1,
                info="Set to 0 to not train the Unet",
            )
            # Initialize the learning rate TE number input (non-SDXL only)
            self.learning_rate_te = gr.Number(
                label="Learning rate TE",
                value=self.config.get(
                    "basic.learning_rate_te", self.learning_rate_value
                ),
                visible=self.finetuning or self.dreambooth,
                minimum=0,
                maximum=1,
                info="Set to 0 to not train the Text Encoder",
            )
            # Initialize the learning rate TE1 number input (shown for SDXL)
            self.learning_rate_te1 = gr.Number(
                label="Learning rate TE1",
                value=self.config.get(
                    "basic.learning_rate_te1", self.learning_rate_value
                ),
                visible=False,
                minimum=0,
                maximum=1,
                info="Set to 0 to not train the Text Encoder 1",
            )
            # Initialize the learning rate TE2 number input (shown for SDXL)
            self.learning_rate_te2 = gr.Number(
                label="Learning rate TE2",
                value=self.config.get(
                    "basic.learning_rate_te2", self.learning_rate_value
                ),
                visible=False,
                minimum=0,
                maximum=1,
                info="Set to 0 to not train the Text Encoder 2",
            )
            # Initialize the learning rate warmup slider
            self.lr_warmup = gr.Slider(
                label="LR warmup (% of total steps)",
                value=self.config.get("basic.lr_warmup", self.lr_warmup_value),
                minimum=0,
                maximum=100,
                step=1,
            )

        def lr_scheduler_changed(scheduler, value):
            """Zero and lock the warmup slider for "constant"; restore it otherwise."""
            if scheduler == "constant":
                # Remember the user's warmup so it can be restored later.
                self.old_lr_warmup = value
                value = 0
                interactive = False
                info = "Can't use LR warmup with LR Scheduler constant... setting to 0 and disabling field..."
            else:
                if self.old_lr_warmup != 0:
                    value = self.old_lr_warmup
                    self.old_lr_warmup = 0
                interactive = True
                info = ""
            return gr.Slider(value=value, interactive=interactive, info=info)

        self.lr_scheduler.change(
            lr_scheduler_changed,
            inputs=[self.lr_scheduler, self.lr_warmup],
            outputs=self.lr_warmup,
        )

    def init_scheduler_controls(self) -> None:
        """
        Initializes the scheduler controls for the model.
        """
        with gr.Row(visible=not self.finetuning):
            # Initialize the learning rate scheduler number of cycles input
            self.lr_scheduler_num_cycles = gr.Number(
                label="LR # cycles",
                minimum=1,
                step=1,  # Increment value by 1
                info="Number of restarts for cosine scheduler with restarts",
                value=self.config.get("basic.lr_scheduler_num_cycles", 1),
            )
            # Initialize the learning rate scheduler power input
            self.lr_scheduler_power = gr.Number(
                label="LR power",
                minimum=0.0,
                step=0.01,
                info="Polynomial power for polynomial scheduler",
                value=self.config.get("basic.lr_scheduler_power", 1.0),
            )

    def init_resolution_and_bucket_controls(self) -> None:
        """
        Initializes the resolution and bucket controls for the model.
        """
        with gr.Row(visible=not self.finetuning):
            # Initialize the maximum resolution textbox
            self.max_resolution = gr.Textbox(
                label="Max resolution",
                value=self.config.get("basic.max_resolution", "512,512"),
                placeholder="512,512",
            )
            # Initialize the stop text encoder training slider
            self.stop_text_encoder_training = gr.Slider(
                minimum=-1,
                maximum=100,
                value=self.config.get("basic.stop_text_encoder_training", 0),
                step=1,
                label="Stop TE (% of total steps)",
            )
            # Initialize the enable buckets checkbox
            self.enable_bucket = gr.Checkbox(
                label="Enable buckets",
                value=self.config.get("basic.enable_bucket", True),
            )
            # Initialize the minimum bucket resolution slider
            self.min_bucket_reso = gr.Slider(
                label="Minimum bucket resolution",
                value=self.config.get("basic.min_bucket_reso", 256),
                minimum=64,
                maximum=4096,
                step=64,
                info="Minimum size in pixel a bucket can be (>= 64)",
            )
            # Initialize the maximum bucket resolution slider
            self.max_bucket_reso = gr.Slider(
                label="Maximum bucket resolution",
                value=self.config.get("basic.max_bucket_reso", 2048),
                minimum=64,
                maximum=4096,
                step=64,
                info="Maximum size in pixel a bucket can be (>= 64)",
            )

    def setup_sdxl_checkbox_behavior(self) -> None:
        """
        Sets up the behavior of the SDXL checkbox based on the finetuning and dreambooth flags.
        """
        # The finetuning/dreambooth flags are plain instance attributes, so they
        # are passed via a closure instead of creating hidden Checkbox components
        # in the layout just to carry two constants.
        self.sdxl_checkbox.change(
            lambda sdxl: self.update_learning_rate_te(
                sdxl, self.finetuning, self.dreambooth
            ),
            inputs=[self.sdxl_checkbox],
            outputs=[
                self.learning_rate_te,
                self.learning_rate_te1,
                self.learning_rate_te2,
            ],
        )

    def update_learning_rate_te(
        self,
        sdxl_checkbox: bool,
        finetuning: bool,
        dreambooth: bool,
    ) -> Tuple[gr.Number, gr.Number, gr.Number]:
        """
        Updates the visibility of the learning rate TE, TE1, and TE2 based on the SDXL checkbox and finetuning/dreambooth flags.

        Args:
            sdxl_checkbox (bool): Current value of the SDXL checkbox.
            finetuning (bool): Whether finetuning is enabled.
            dreambooth (bool): Whether dreambooth is enabled.

        Returns:
            Tuple[gr.Number, gr.Number, gr.Number]: Updates for learning rate TE, TE1, and TE2.
            TE is shown for non-SDXL; TE1/TE2 are shown for SDXL; all stay hidden
            unless finetuning or dreambooth is enabled.
        """
        # Determine the visibility condition based on finetuning and dreambooth flags
        visibility_condition = finetuning or dreambooth
        # Return a tuple of gr.Number instances with updated visibility
        return (
            gr.Number(visible=(not sdxl_checkbox and visibility_condition)),
            gr.Number(visible=(sdxl_checkbox and visibility_condition)),
            gr.Number(visible=(sdxl_checkbox and visibility_condition)),
        )