import gradio as gr import os from .common_gui import ( get_file_path, get_folder_path, set_pretrained_model_name_or_path_input, scriptdir, list_dirs, list_files, create_refresh_button, ) from .class_gui_config import KohyaSSGUIConfig folder_symbol = "\U0001f4c2" # 📂 refresh_symbol = "\U0001f504" # 🔄 save_style_symbol = "\U0001f4be" # 💾 document_symbol = "\U0001F4C4" # 📄 default_models = [ "stabilityai/stable-diffusion-xl-base-1.0", "stabilityai/stable-diffusion-xl-refiner-1.0", "stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned", "stabilityai/stable-diffusion-2-1-base", "stabilityai/stable-diffusion-2-base", "stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned", "stabilityai/stable-diffusion-2-1", "stabilityai/stable-diffusion-2", "runwayml/stable-diffusion-v1-5", "CompVis/stable-diffusion-v1-4", ] class SourceModel: def __init__( self, save_model_as_choices=[ "same as source model", "ckpt", "diffusers", "diffusers_safetensors", "safetensors", ], save_precision_choices=[ "float", "fp16", "bf16", ], headless=False, finetuning=False, config: KohyaSSGUIConfig = {}, ): self.headless = headless self.save_model_as_choices = save_model_as_choices self.finetuning = finetuning self.config = config # Set default directories if not provided self.current_models_dir = self.config.get( "model.models_dir", os.path.join(scriptdir, "models") ) self.current_train_data_dir = self.config.get( "model.train_data_dir", os.path.join(scriptdir, "data") ) self.current_dataset_config_dir = self.config.get( "model.dataset_config", os.path.join(scriptdir, "dataset_config") ) model_checkpoints = list( list_files( self.current_models_dir, exts=[".ckpt", ".safetensors"], all=True ) ) def list_models(path): self.current_models_dir = ( path if os.path.isdir(path) else os.path.dirname(path) ) return default_models + list( list_files(path, exts=[".ckpt", ".safetensors"], all=True) ) def list_train_data_dirs(path): self.current_train_data_dir = path if not path == "" else "." return list(list_dirs(self.current_train_data_dir)) def list_dataset_config_dirs(path: str) -> list: """ List directories and toml files in the dataset_config directory. Parameters: - path (str): The path to list directories and files from. Returns: - list: A list of directories and files. """ current_dataset_config_dir = path if not path == "" else "." # Lists all .json files in the current configuration directory, used for populating dropdown choices. return list( list_files(current_dataset_config_dir, exts=[".toml"], all=True) ) with gr.Accordion("Model", open=True): with gr.Column(), gr.Group(): model_ext = gr.Textbox(value="*.safetensors *.ckpt", visible=False) model_ext_name = gr.Textbox(value="Model types", visible=False) # Define the input elements with gr.Row(): with gr.Column(), gr.Row(): self.model_list = gr.Textbox(visible=False, value="") self.pretrained_model_name_or_path = gr.Dropdown( label="Pretrained model name or path", choices=default_models + model_checkpoints, value=self.config.get("model.models_dir", "runwayml/stable-diffusion-v1-5"), allow_custom_value=True, visible=True, min_width=100, ) create_refresh_button( self.pretrained_model_name_or_path, lambda: None, lambda: {"choices": list_models(self.current_models_dir)}, "open_folder_small", ) self.pretrained_model_name_or_path_file = gr.Button( document_symbol, elem_id="open_folder_small", elem_classes=["tool"], visible=(not headless), ) self.pretrained_model_name_or_path_file.click( get_file_path, inputs=[self.pretrained_model_name_or_path, model_ext, model_ext_name], outputs=self.pretrained_model_name_or_path, show_progress=False, ) self.pretrained_model_name_or_path_folder = gr.Button( folder_symbol, elem_id="open_folder_small", elem_classes=["tool"], visible=(not headless), ) self.pretrained_model_name_or_path_folder.click( get_folder_path, inputs=self.pretrained_model_name_or_path, outputs=self.pretrained_model_name_or_path, show_progress=False, ) with gr.Column(), gr.Row(): self.output_name = gr.Textbox( label="Trained Model output name", placeholder="(Name of the model to output)", value=self.config.get("model.output_name", "last"), interactive=True, ) with gr.Row(): with gr.Column(), gr.Row(): self.train_data_dir = gr.Dropdown( label=( "Image folder (containing training images subfolders)" if not finetuning else "Image folder (containing training images)" ), choices=[""] + list_train_data_dirs(self.current_train_data_dir), value=self.config.get("model.train_data_dir", ""), interactive=True, allow_custom_value=True, ) create_refresh_button( self.train_data_dir, lambda: None, lambda: { "choices": [""] + list_train_data_dirs(self.current_train_data_dir) }, "open_folder_small", ) self.train_data_dir_folder = gr.Button( "📂", elem_id="open_folder_small", elem_classes=["tool"], visible=(not self.headless), ) self.train_data_dir_folder.click( get_folder_path, outputs=self.train_data_dir, show_progress=False, ) with gr.Column(), gr.Row(): # Toml directory dropdown self.dataset_config = gr.Dropdown( label="Dataset config file (Optional. Select the toml configuration file to use for the dataset)", choices=[self.config.get("model.dataset_config", "")] + list_dataset_config_dirs(self.current_dataset_config_dir), value=self.config.get("model.dataset_config", ""), interactive=True, allow_custom_value=True, ) # Refresh button for dataset_config directory create_refresh_button( self.dataset_config, lambda: None, lambda: { "choices": [""] + list_dataset_config_dirs( self.current_dataset_config_dir ) }, "open_folder_small", ) # Toml directory button self.dataset_config_folder = gr.Button( document_symbol, elem_id="open_folder_small", elem_classes=["tool"], visible=(not self.headless), ) # Toml directory button click event self.dataset_config_folder.click( get_file_path, inputs=[ self.dataset_config, gr.Textbox(value="*.toml", visible=False), gr.Textbox(value="Dataset config types", visible=False), ], outputs=self.dataset_config, show_progress=False, ) # Change event for dataset_config directory dropdown self.dataset_config.change( fn=lambda path: gr.Dropdown( choices=[""] + list_dataset_config_dirs(path) ), inputs=self.dataset_config, outputs=self.dataset_config, show_progress=False, ) with gr.Row(): with gr.Column(): with gr.Row(): self.v2 = gr.Checkbox( label="v2", value=False, visible=False, min_width=60 ) self.v_parameterization = gr.Checkbox( label="v_parameterization", value=False, visible=False, min_width=130, ) self.sdxl_checkbox = gr.Checkbox( label="SDXL", value=False, visible=False, min_width=60, ) with gr.Column(): gr.Group(visible=False) with gr.Row(): self.training_comment = gr.Textbox( label="Training comment", placeholder="(Optional) Add training comment to be included in metadata", interactive=True, value=self.config.get("model.training_comment", ""), ) with gr.Row(): self.save_model_as = gr.Radio( save_model_as_choices, label="Save trained model as", value=self.config.get("model.save_model_as", "safetensors"), ) self.save_precision = gr.Radio( save_precision_choices, label="Save precision", value=self.config.get("model.save_precision", "fp16"), ) self.pretrained_model_name_or_path.change( fn=lambda path: set_pretrained_model_name_or_path_input( path, refresh_method=list_models ), inputs=[ self.pretrained_model_name_or_path, ], outputs=[ self.pretrained_model_name_or_path, self.v2, self.v_parameterization, self.sdxl_checkbox, ], show_progress=False, ) self.train_data_dir.change( fn=lambda path: gr.Dropdown( choices=[""] + list_train_data_dirs(path) ), inputs=self.train_data_dir, outputs=self.train_data_dir, show_progress=False, )