brightpay's picture
Update app.py
7d0642a verified
raw
history blame
856 Bytes
import os

import torch
from diffusers import StableDiffusionPipeline
from safetensors.torch import load_file
def load_pipeline(base_model_path, lora_path, device=None):
    """Build a Stable Diffusion pipeline with LoRA weights applied.

    Args:
        base_model_path: Local directory or Hugging Face Hub id of the base
            model, passed straight to ``StableDiffusionPipeline.from_pretrained``.
        lora_path: Path to a ``.safetensors`` file holding the LoRA weights.
        device: Torch device to place the pipeline on. ``None`` (the default)
            selects ``"cuda"`` when a GPU is available and ``"cpu"`` otherwise.

    Returns:
        The loaded ``StableDiffusionPipeline`` with the LoRA applied, moved
        to ``device``.
    """
    # Load the base SD-v1.5 pipeline.
    pipeline = StableDiffusionPipeline.from_pretrained(base_model_path)

    # Load the LoRA weights from disk.
    lora_state_dict = load_file(lora_path)
    # ``load_lora_weights`` accepts diffusers-format as well as Kohya/A1111-
    # format LoRA state dicts (including text-encoder LoRA layers), whereas
    # ``unet.load_attn_procs`` only handles the diffusers attention-processor
    # format and is deprecated for LoRA loading in newer diffusers releases.
    # Keep the old call only as a fallback for diffusers versions that
    # predate ``load_lora_weights``.
    if hasattr(pipeline, "load_lora_weights"):
        pipeline.load_lora_weights(lora_state_dict)
    else:
        pipeline.unet.load_attn_procs(lora_state_dict)

    # Without an explicit move the pipeline stays on the CPU, where each
    # generation is very slow.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    pipeline.to(device)
    return pipeline
# Paths to the base model and LoRA weights. The defaults are placeholders;
# set BASE_MODEL_PATH / LORA_PATH in the environment to point at real files
# without editing the source (e.g. in a deployment's settings).
base_model_path = os.environ.get("BASE_MODEL_PATH", "path/to/sd-v1-5")
lora_path = os.environ.get("LORA_PATH", "path/to/Floor_Plan_LoRA.safetensors")

# Load the pipeline once at module import; `predict` reuses this instance
# for every request.
pipeline = load_pipeline(base_model_path, lora_path)
def predict(prompt):
    """Run the module-level ``pipeline`` on *prompt* and return one image.

    Args:
        prompt: The text prompt entered in the Gradio textbox.

    Returns:
        The first entry of the pipeline output's ``images`` list.
    """
    output = pipeline(prompt)
    generated = output.images
    return generated[0]
# Create the Gradio interface: one text prompt in, one generated image out.
import gradio as gr

interface = gr.Interface(fn=predict, inputs="text", outputs="image")

# `launch()` starts a blocking web server. Guard it so that importing this
# module (from tests, another app, or tooling) does not start the server;
# running the file as a script (`python app.py`) still launches it.
if __name__ == "__main__":
    interface.launch()