"""Gradio front-end for fine-tuning Stable Diffusion and generating images.

The "Fine-Tune Model" tab saves one or two uploaded images to disk and hands
them to ``main.fine_tune_model``; the "Generate Images" tab loads the weights
found in ``OUTPUT_DIR`` and samples images from them on a CUDA device.
"""
import logging
import os
import shutil

import gradio as gr
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline

from main import fine_tune_model

logger = logging.getLogger(__name__)

MODEL_NAME = "runwayml/stable-diffusion-v1-5"
OUTPUT_DIR = "/home/user/app/stable_diffusion_weights/custom_model"
# Uploaded instance images are written here before fine-tuning; the directory
# is deleted and recreated on every fine-tuning run.
INSTANCE_DATA_DIR = "/home/user/app/instance_images"
# Seed used for generation unless the caller passes one; a fixed seed makes
# repeated runs with the same settings reproducible.
DEFAULT_SEED = 1337


def fine_tune(instance_prompt, image1, image2=None):
    """Fine-tune the base model on one or two uploaded images.

    Args:
        instance_prompt: Text prompt passed to ``fine_tune_model`` with the images.
        image1: Required PIL image; Gradio passes ``None`` when nothing was uploaded.
        image2: Optional second PIL image.

    Returns:
        A status message for the UI: a success message, a validation message,
        or the text of the exception raised while fine-tuning.
    """
    # Validate before touching the filesystem so an incomplete submission does
    # not wipe the images saved by a previous run.
    if image1 is None:
        return "Please upload at least one image."
    if not instance_prompt or not instance_prompt.strip():
        return "Please enter an instance prompt."

    try:
        if os.path.exists(INSTANCE_DATA_DIR):
            shutil.rmtree(INSTANCE_DATA_DIR)
        os.makedirs(INSTANCE_DATA_DIR, exist_ok=True)

        images = [image1] if image2 is None else [image1, image2]
        for index, image in enumerate(images):
            image.save(os.path.join(INSTANCE_DATA_DIR, f"instance_{index}.png"))

        fine_tune_model(INSTANCE_DATA_DIR, instance_prompt, MODEL_NAME, OUTPUT_DIR)
        return "Model fine-tuning complete."
    except Exception as e:  # UI boundary: report instead of crashing the app
        logger.exception("Fine-tuning failed")
        return str(e)


def generate_images(
    prompt,
    num_samples,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed=DEFAULT_SEED,
):
    """Generate images with the fine-tuned model stored in ``OUTPUT_DIR``.

    Gradio ``Number``/``Slider`` components can deliver floats, so the sample
    count, image size and step count are coerced to ``int`` before they reach
    the pipeline.

    Args:
        prompt: Text prompt for generation.
        num_samples: Number of images to generate for the prompt (>= 1).
        height: Output image height in pixels.
        width: Output image width in pixels.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance scale.
        seed: Seed for the CUDA random generator (defaults to ``DEFAULT_SEED``).

    Returns:
        A list of PIL images on success, or an error message string on failure.
    """
    try:
        if not os.path.exists(OUTPUT_DIR):
            return "The model path does not exist."

        num_samples = int(num_samples)
        height = int(height)
        width = int(width)
        num_inference_steps = int(num_inference_steps)
        guidance_scale = float(guidance_scale)
        if num_samples < 1:
            return "Number of samples must be at least 1."

        pipe = StableDiffusionPipeline.from_pretrained(
            OUTPUT_DIR, safety_checker=None, torch_dtype=torch.float16
        ).to("cuda")
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        generator = torch.Generator(device="cuda").manual_seed(int(seed))

        with torch.autocast("cuda"), torch.inference_mode():
            images = pipe(
                prompt,
                height=height,
                width=width,
                num_images_per_prompt=num_samples,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator,
            ).images
        return images
    except Exception as e:  # UI boundary: report instead of crashing the app
        logger.exception("Image generation failed")
        return str(e)


def _generate_for_gallery(*args):
    """Adapt ``generate_images`` to a ``gr.Gallery`` output.

    ``generate_images`` reports failures as a plain string, which a Gallery
    cannot render; raise it as ``gr.Error`` so Gradio shows it to the user.
    """
    result = generate_images(*args)
    if isinstance(result, str):
        raise gr.Error(result)
    return result


def gradio_app():
    """Build the two-tab Gradio interface and launch it."""
    with gr.Blocks() as demo:
        with gr.Tab("Fine-Tune Model"):
            with gr.Row():
                with gr.Column():
                    instance_prompt = gr.Textbox(label="Instance Prompt")
                    image1 = gr.Image(label="Upload Image 1", type="pil")
                    image2 = gr.Image(label="Upload Image 2 (Optional)", type="pil")
                    fine_tune_button = gr.Button("Fine-Tune Model")
                    output_text = gr.Textbox(label="Output")
            fine_tune_button.click(
                fine_tune,
                inputs=[instance_prompt, image1, image2],
                outputs=output_text,
            )

        with gr.Tab("Generate Images"):
            with gr.Row():
                with gr.Column():
                    prompt = gr.Textbox(label="Prompt")
                    # precision=0 makes Gradio return ints instead of floats,
                    # which the diffusion pipeline requires for these values.
                    num_samples = gr.Number(
                        label="Number of Samples", value=1, precision=0
                    )
                    guidance_scale = gr.Number(label="Guidance Scale", value=7.5)
                    height = gr.Number(label="Height", value=512, precision=0)
                    width = gr.Number(label="Width", value=512, precision=0)
                    num_inference_steps = gr.Slider(
                        label="Steps", value=50, minimum=1, maximum=100, step=1
                    )
                    generate_button = gr.Button("Generate Images")
                with gr.Column():
                    gallery = gr.Gallery(label="Generated Images")
            generate_button.click(
                _generate_for_gallery,
                inputs=[
                    prompt,
                    num_samples,
                    height,
                    width,
                    num_inference_steps,
                    guidance_scale,
                ],
                outputs=gallery,
            )

    demo.launch()


if __name__ == "__main__":
    gradio_app()