File size: 4,020 Bytes
d1d1942
 
 
0259e4a
d1d1942
0259e4a
 
 
d1d1942
 
0259e4a
 
d1d1942
0259e4a
 
d1d1942
0259e4a
 
d1d1942
 
 
 
0259e4a
 
 
 
 
 
d1d1942
0259e4a
 
 
 
 
 
 
 
 
 
d1d1942
 
 
0259e4a
d1d1942
0259e4a
 
 
 
 
 
d1d1942
 
0259e4a
 
d1d1942
 
 
 
0259e4a
 
 
 
d1d1942
0259e4a
 
d1d1942
0259e4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1d1942
0259e4a
 
d1d1942
 
 
 
 
 
 
0259e4a
 
d1d1942
0259e4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from tqdm.auto import tqdm
from huggingface_hub import hf_hub_url, login, HfApi, create_repo
import os
import traceback
from peft import PeftModel
import gradio as gr

def display_image(image):
    """Return *image* unchanged so the caller (e.g. a Gradio output) can show it."""
    shown = image
    return shown

def load_and_merge_lora(base_model_id, lora_id, lora_adapter_name,
                        device="cuda", variant="fp16"):
    """Load a diffusers pipeline and attach a PEFT LoRA adapter to its UNet.

    The pipeline is loaded in float16 from safetensors weights. Its scheduler
    is swapped for ``DPMSolverMultistepScheduler``, built from the existing
    scheduler config. ``pipe.unet`` is then replaced by the ``PeftModel``
    returned by ``PeftModel.from_pretrained``. The adapter weights are NOT
    folded into the UNet by this function.

    Args:
        base_model_id: Repo id or local path accepted by
            ``DiffusionPipeline.from_pretrained``.
        lora_id: Repo id or local path of an adapter loadable by
            ``PeftModel.from_pretrained``.
        lora_adapter_name: Value passed as ``adapter_name`` to
            ``PeftModel.from_pretrained``.
        device: Device the pipeline is moved to (default ``"cuda"``, as before).
        variant: Weight variant requested from the base repo (default
            ``"fp16"``). Pass ``None`` for repos that do not publish one.

    Returns:
        The pipeline, or ``None`` if any step raised. On failure the full
        traceback is written to ``errors.txt`` in the current working
        directory, overwriting any previous contents.
    """
    try:
        pipe = DiffusionPipeline.from_pretrained(
            base_model_id,
            torch_dtype=torch.float16,
            variant=variant,
            use_safetensors=True,
        ).to(device)

        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config
        )

        # Wrap the pipeline's UNet with the LoRA adapter.
        unet = PeftModel.from_pretrained(
            pipe.unet,
            lora_id,
            torch_dtype=torch.float16,
            adapter_name=lora_adapter_name,
        )

        # Replace the original UNet in the pipeline with the PEFT-wrapped one.
        pipe.unet = unet

        print("LoRA merged successfully!")
        return pipe

    except Exception as e:
        # Deliberate best-effort boundary: callers check for None. The full
        # traceback is kept on disk because the console only gets a summary.
        error_msg = traceback.format_exc()
        print(f"Error merging LoRA: {e}\n\nFull traceback saved to errors.txt")

        # Explicit encoding so a non-ASCII traceback (e.g. paths) cannot raise
        # UnicodeEncodeError inside this handler on non-UTF-8 locales.
        with open("errors.txt", "w", encoding="utf-8") as f:
            f.write(error_msg)

        return None

def save_merged_model(pipe, save_path, push_to_hub=False, hf_token=None,
                      repo_name=None):
    """Save the pipeline to disk and optionally upload it to the Hugging Face Hub.

    If ``pipe.unet`` is still a ``PeftModel`` wrapper, the LoRA weights are first
    merged with ``merge_and_unload()``. The resulting plain UNet is assigned back
    to ``pipe.unet``, so the caller's pipeline is modified in place.

    Args:
        pipe: Pipeline exposing ``save_pretrained`` (and optionally ``unet``).
        save_path: Directory passed to ``pipe.save_pretrained`` and uploaded.
        push_to_hub: When true, upload ``save_path`` to the Hub after saving.
        hf_token: Hub write token. A missing or empty token triggers a prompt
            on stdin followed by ``login``. Gradio password boxes yield ``""``
            when left blank.
        repo_name: Target repo id (``"user/name"``). If omitted, it is
            prompted for on stdin.

    Errors are printed, not raised, so the caller continues regardless.
    """
    try:
        unet = getattr(pipe, "unet", None)
        if unet is not None and isinstance(unet, PeftModel):
            # NOTE(review): DiffusionPipeline.save_pretrained only serialises
            # components whose class it recognises (diffusers/transformers
            # models). A PeftModel wrapper is skipped, leaving the saved model
            # without a UNet. Folding the adapter in restores a plain UNet.
            pipe.unet = unet.merge_and_unload()

        pipe.save_pretrained(save_path)
        print(f"Merged model saved successfully to: {save_path}")

        if push_to_hub:
            # Treat "" like None: sending an empty token would fail auth.
            if not hf_token:
                hf_token = input("Enter your Hugging Face write token: ")
                login(token=hf_token)

            if not repo_name:
                repo_name = input("Enter the Hugging Face repository name "
                                  "(e.g., your_username/your_model_name): ")
            repo_name = repo_name.strip()

            # Create the repository if it doesn't exist.
            create_repo(repo_name, token=hf_token, exist_ok=True)

            api = HfApi()
            api.upload_folder(
                folder_path=save_path,
                repo_id=repo_name,
                token=hf_token,
                repo_type="model",
            )
            print(f"Model pushed successfully to Hugging Face Hub: {repo_name}")

    except Exception as e:
        print(f"Error saving/pushing the merged model: {e}")

def generate_and_save(base_model_id, lora_id, lora_adapter_name, prompt, lora_scale, save_path, push_to_hub, hf_token):
    """Gradio callback: load the LoRA pipeline, render one image, then save/push.

    Args:
        base_model_id, lora_id, lora_adapter_name: Forwarded to
            ``load_and_merge_lora``.
        prompt: Text prompt for the single generated image.
        lora_scale: Passed (as float) via ``cross_attention_kwargs["scale"]``.
        save_path, push_to_hub, hf_token: Forwarded to ``save_merged_model``.

    Returns:
        An ``(image, status)`` tuple, one value per Gradio output component.
        If the pipeline fails to load, the image is ``None`` and the status
        points at ``errors.txt``.
    """
    pipe = load_and_merge_lora(base_model_id, lora_id, lora_adapter_name)
    if pipe is None:
        # Gradio needs a value for each of the two outputs; a bare None
        # (the old implicit return) cannot populate them.
        return None, "Failed to load the model or LoRA. See errors.txt for details."

    # Private CPU generator seeded with 0, instead of torch.manual_seed(0),
    # which would also reseed the process-wide default RNG.
    generator = torch.Generator().manual_seed(0)
    image = pipe(
        prompt,
        num_inference_steps=30,
        cross_attention_kwargs={"scale": float(lora_scale)},
        generator=generator,
    ).images[0]

    image.save("generated_image.png")
    print("Image saved to: generated_image.png")

    save_merged_model(pipe, save_path, push_to_hub, hf_token)

    return image, "Image generated and model saved/pushed (if selected)."

# Gradio UI: the eight inputs map positionally onto generate_and_save's
# parameters, and the two outputs receive its (image, status) return tuple.
iface = gr.Interface(
    fn=generate_and_save,
    inputs=[
        gr.Textbox(label="Base Model ID (e.g., stabilityai/stable-diffusion-xl-base-1.0)"),
        gr.Textbox(label="LoRA ID (e.g., your_username/your_lora)"),
        gr.Textbox(label="LoRA Adapter Name"),
        gr.Textbox(label="Prompt"),
        gr.Slider(label="LoRA Scale", minimum=0.0, maximum=1.0, value=0.7, step=0.1),
        gr.Textbox(label="Save Path"),
        gr.Checkbox(label="Push to Hugging Face Hub"),
        gr.Textbox(label="Hugging Face Write Token", type="password"),
    ],
    outputs=[
        gr.Image(label="Generated Image"),
        gr.Textbox(label="Status"),
    ],
    title="LoRA Merger and Image Generator",
    description="Merge a LoRA with a base Stable Diffusion model and generate images.",
)

if __name__ == "__main__":
    # Launch only when run as a script, so that importing this module (to
    # reuse its helpers or test them) does not start a blocking web server.
    iface.launch()