File size: 4,186 Bytes
fb83c5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import gradio as gr
import subprocess
import os
import sys
from .common_gui import (
    get_file_path,
    scriptdir,
    list_files,
    create_refresh_button, setup_environment
)

from .custom_logging import setup_logging

# Set up logging
log = setup_logging()

# Button glyphs, written as escapes so the source stays ASCII-safe.
folder_symbol = "\U0001f4c2"  # OPEN FILE FOLDER emoji
refresh_symbol = "\U0001f504"  # ANTICLOCKWISE ARROWS emoji
save_style_symbol = "\U0001f4be"  # FLOPPY DISK emoji
document_symbol = "\U0001F4C4"  # PAGE FACING UP emoji
# Interpreter running this UI; reused to launch the sd-scripts helper scripts.
PYTHON = sys.executable


def verify_lora(
    lora_model,
):
    """Run sd-scripts' ``networks/check_lora_weights.py`` on a LoRA file.

    Args:
        lora_model: Path of the LoRA model file selected in the UI. May be
            ``""`` or ``None`` when nothing has been selected.

    Returns:
        A ``(stdout, stderr)`` tuple of text. When the checker runs, these are
        its decoded output streams. When the input is rejected, or the
        interpreter cannot be launched, stdout is ``""`` and stderr holds the
        reason. Both Gradio output boxes therefore always receive a value.
    """
    # A cleared Dropdown can yield None as well as "".
    if not lora_model:
        message = "Invalid LoRA model file"
        log.error(message)
        return ("", message)

    # The checker script needs an existing file to inspect.
    if not os.path.isfile(lora_model):
        message = f"The provided LoRA model is not a file: {lora_model}"
        log.error(message)
        return ("", message)

    # Argument list (no shell), so paths with spaces need no quoting.
    run_cmd = [
        PYTHON,
        os.path.join(scriptdir, "sd-scripts", "networks", "check_lora_weights.py"),
        lora_model,
    ]

    # Joined only for display in the log.
    command_to_run = " ".join(run_cmd)
    log.info(f"Executing command: {command_to_run}")

    # Environment for the child process, per the original intent: sets the
    # Python path (built by the project helper).
    env = setup_environment()

    # Blocks until the script exits; both streams are captured.
    try:
        completed = subprocess.run(
            run_cmd,
            capture_output=True,
            env=env,
            check=False,
        )
    except OSError as exc:
        log.error(f"Failed to launch LoRA verification: {exc}")
        return ("", str(exc))

    # Decode leniently: console output that is not valid UTF-8 (common on
    # Windows code pages) must not crash the handler.
    return (
        completed.stdout.decode("utf-8", errors="replace"),
        completed.stderr.decode("utf-8", errors="replace"),
    )


###
# Gradio UI
###


def gradio_verify_lora_tab(headless=False):
    """Build the "Verify LoRA" tab.

    The tab contains:

    - a LoRA file selector: a dropdown, a refresh button, and a file-browser
      button;
    - a "Verify" button that passes the selected path to ``verify_lora``;
    - two read-only text boxes that receive the checker's output and error
      text.

    Args:
        headless: When True, the file-browser button is created hidden.
    """
    # Location scanned most recently for models. It starts at
    # <scriptdir>/outputs and is overwritten on every scan, so the refresh
    # button re-reads the same place.
    model_dir = os.path.join(scriptdir, "outputs")

    def scan_models(path):
        """Record ``path`` as the scan location; return list_files' .pt/.safetensors hits."""
        nonlocal model_dir
        model_dir = path
        found = list_files(path, exts=[".pt", ".safetensors"], all=True)
        return list(found)

    def readonly_box(label, placeholder):
        """Create a non-interactive text box that grows up to 10 lines."""
        return gr.Textbox(
            label=label,
            placeholder=placeholder,
            interactive=False,
            lines=1,
            max_lines=10,
        )

    with gr.Tab("Verify LoRA"):
        gr.Markdown(
            "This utility can verify a LoRA network to make sure it is properly trained."
        )

        # Invisible inputs forwarded to get_file_path: extension filter and its label.
        lora_ext = gr.Textbox(value="*.pt *.safetensors", visible=False)
        lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)

        with gr.Group():
            with gr.Row():
                starting_choices = [""] + scan_models(model_dir)
                lora_model = gr.Dropdown(
                    label="LoRA model (path to the LoRA model to verify)",
                    interactive=True,
                    choices=starting_choices,
                    value="",
                    allow_custom_value=True,
                )
                # Refresh re-scans whichever location was scanned last.
                create_refresh_button(
                    lora_model,
                    lambda: None,
                    lambda: {"choices": scan_models(model_dir)},
                    "open_folder_small",
                )
                browse_button = gr.Button(
                    folder_symbol,
                    elem_id="open_folder_small",
                    elem_classes=["tool"],
                    visible=not headless,
                )
                browse_button.click(
                    get_file_path,
                    inputs=[lora_model, lora_ext, lora_ext_name],
                    outputs=lora_model,
                    show_progress=False,
                )
                verify_button = gr.Button("Verify", variant="primary")

                # Each new selection/typed value rebuilds the choice list from it.
                lora_model.change(
                    fn=lambda selected: gr.Dropdown(
                        choices=[""] + scan_models(selected)
                    ),
                    inputs=lora_model,
                    outputs=lora_model,
                    show_progress=False,
                )

        verification_stdout = readonly_box("Output", "Verification output")
        verification_stderr = readonly_box("Error", "Verification error")

        verify_button.click(
            verify_lora,
            inputs=[lora_model],
            outputs=[verification_stdout, verification_stderr],
            show_progress=False,
        )