Spaces:
Sleeping
Sleeping
File size: 3,350 Bytes
293db3f 66f79e7 ce6e822 66f79e7 2c282ba ce6e822 2c282ba 66f79e7 2c282ba 293db3f 66f79e7 293db3f 66f79e7 293db3f 66f79e7 db718be 40c1fd3 66f79e7 db718be 293db3f 66f79e7 293db3f 66f79e7 2c282ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import gradio as gr
import os
import shutil
from main import fine_tune_model
from diffusers import StableDiffusionPipeline, DDIMScheduler
import torch
# Base Stable Diffusion checkpoint; passed to fine_tune_model as the model to fine-tune.
MODEL_NAME: str = "runwayml/stable-diffusion-v1-5"
# Directory given to fine_tune_model as its output location and later loaded
# by StableDiffusionPipeline.from_pretrained in generate_images.
OUTPUT_DIR: str = "/home/user/app/stable_diffusion_weights/custom_model"
def fine_tune(instance_prompt, image1, image2=None):
    """Fine-tune MODEL_NAME on one or two uploaded instance images.

    Used as the handler of the "Fine-Tune Model" button; the returned string
    is displayed in the Output textbox.

    Args:
        instance_prompt: Prompt passed to ``fine_tune_model`` alongside the
            images. Must be non-blank.
        image1: Required image object with a ``save(path)`` method (a PIL
            image, per the ``gr.Image(type="pil")`` input).
        image2: Optional second image; skipped when ``None``.

    Returns:
        A human-readable status message: a validation error, a failure
        description, or "Model fine-tuning complete.".
    """
    instance_data_dir = "/home/user/app/instance_images"

    # Validate before touching the filesystem, so an incomplete request does
    # not wipe the instance images from a previous run.
    if instance_prompt is None or not instance_prompt.strip():
        return "Please enter an instance prompt."
    if image1 is None:
        return "Please upload Image 1 before fine-tuning."

    try:
        # Start from an empty directory so images from an earlier run are
        # not mixed into this training set.
        if os.path.exists(instance_data_dir):
            shutil.rmtree(instance_data_dir)
        os.makedirs(instance_data_dir, exist_ok=True)
        image1.save(os.path.join(instance_data_dir, "instance_0.png"))
        if image2 is not None:
            image2.save(os.path.join(instance_data_dir, "instance_1.png"))
        fine_tune_model(instance_data_dir, instance_prompt, MODEL_NAME, OUTPUT_DIR)
    except Exception as e:  # UI boundary: report the failure instead of crashing the app
        return f"Fine-tuning failed: {type(e).__name__}: {e}"
    return "Model fine-tuning complete."
def generate_images(prompt, num_samples, height, width, num_inference_steps, guidance_scale, seed=1337):
    """Generate images from the fine-tuned model stored in OUTPUT_DIR.

    Used as the handler of the "Generate Images" button; the result feeds a
    ``gr.Gallery``. Failures are raised as ``gr.Error`` so Gradio can show
    them as an error message. Returning a plain string to a Gallery output
    cannot be rendered.

    Args:
        prompt: Text prompt for the pipeline.
        num_samples: Number of images to generate (coerced to int, must be >= 1).
        height: Output height in pixels (coerced to int, must be >= 1).
        width: Output width in pixels (coerced to int, must be >= 1).
        num_inference_steps: Denoising steps (coerced to int, must be >= 1).
        guidance_scale: Classifier-free guidance scale (coerced to float).
        seed: Seed for the CUDA generator. Defaults to 1337, the value
            previously hard-coded, so existing callers get identical output.

    Returns:
        The ``.images`` list produced by the pipeline call.

    Raises:
        gr.Error: if OUTPUT_DIR is missing, a setting is invalid, or loading
            or generation fails.
    """
    if not os.path.exists(OUTPUT_DIR):
        raise gr.Error("The model path does not exist. Fine-tune a model first.")

    # gr.Number / gr.Slider may deliver floats (or None when cleared), but the
    # pipeline needs ints for batch size, spatial size and step count.
    try:
        num_samples = int(num_samples)
        height = int(height)
        width = int(width)
        num_inference_steps = int(num_inference_steps)
        guidance_scale = float(guidance_scale)
        seed = int(seed)
    except (TypeError, ValueError) as e:
        raise gr.Error(f"Invalid generation settings: {e}") from e
    if min(num_samples, height, width, num_inference_steps) < 1:
        raise gr.Error("Number of samples, height, width and steps must all be at least 1.")

    try:
        pipe = StableDiffusionPipeline.from_pretrained(
            OUTPUT_DIR, safety_checker=None, torch_dtype=torch.float16
        ).to("cuda")
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        generator = torch.Generator(device="cuda").manual_seed(seed)
        with torch.autocast("cuda"), torch.inference_mode():
            images = pipe(
                prompt,
                height=height,
                width=width,
                num_images_per_prompt=num_samples,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator,
            ).images
    except Exception as e:  # UI boundary: surface the reason in the Gradio UI
        raise gr.Error(f"Image generation failed: {type(e).__name__}: {e}") from e
    return images
def gradio_app():
    """Build the two-tab Gradio UI (fine-tuning and image generation) and launch it."""
    with gr.Blocks() as demo:
        with gr.Tab("Fine-Tune Model"):
            with gr.Row():
                with gr.Column():
                    instance_prompt = gr.Textbox(label="Instance Prompt")
                    image1 = gr.Image(label="Upload Image 1", type="pil")
                    image2 = gr.Image(label="Upload Image 2 (Optional)", type="pil")
                    fine_tune_button = gr.Button("Fine-Tune Model")
                    output_text = gr.Textbox(label="Output")
            fine_tune_button.click(fine_tune, inputs=[instance_prompt, image1, image2], outputs=output_text)
        with gr.Tab("Generate Images"):
            with gr.Row():
                with gr.Column():
                    prompt = gr.Textbox(label="Prompt")
                    # precision=0 makes these fields return ints, which the
                    # pipeline requires for batch size and image dimensions.
                    num_samples = gr.Number(label="Number of Samples", value=1, precision=0)
                    guidance_scale = gr.Number(label="Guidance Scale", value=7.5)
                    height = gr.Number(label="Height", value=512, precision=0)
                    width = gr.Number(label="Width", value=512, precision=0)
                    # step=1 keeps the step count integral.
                    num_inference_steps = gr.Slider(label="Steps", value=50, minimum=1, maximum=100, step=1)
                    generate_button = gr.Button("Generate Images")
                with gr.Column():
                    gallery = gr.Gallery(label="Generated Images")
            generate_button.click(
                generate_images,
                inputs=[prompt, num_samples, height, width, num_inference_steps, guidance_scale],
                outputs=gallery,
            )
    # Fine-tuning and generation are long-running; the queue avoids request
    # timeouts for long events (it is already the default on Gradio 4.x).
    demo.queue()
    demo.launch()
# Launch the UI only when run as a script, not when imported.
if __name__ == "__main__":
    gradio_app()