brightpay's picture
Update app.py
641d496 verified
raw
history blame contribute delete
883 Bytes
from huggingface_hub import hf_hub_download
from diffusers import StableDiffusionPipeline
from safetensors.torch import load_file
import torch
def load_pipeline(base_model_path, lora_repo_id, lora_filename):
    """Build a Stable Diffusion pipeline and apply LoRA weights to its UNet.

    Args:
        base_model_path: Hub repo id or local path handed to
            ``StableDiffusionPipeline.from_pretrained``.
        lora_repo_id: Hub repo id that hosts the LoRA weights file.
        lora_filename: Name of the safetensors file inside ``lora_repo_id``.

    Returns:
        The ``StableDiffusionPipeline`` whose UNet has had the LoRA state
        dict loaded via ``load_attn_procs``.
    """
    sd_pipe = StableDiffusionPipeline.from_pretrained(base_model_path)

    # hf_hub_download returns a local filesystem path to the requested file.
    weights_file = hf_hub_download(repo_id=lora_repo_id, filename=lora_filename)

    # NOTE(review): load_attn_procs is the older diffusers LoRA entry point;
    # newer releases steer users to pipeline.load_lora_weights(...). Confirm
    # which format this LoRA file uses before migrating.
    sd_pipe.unet.load_attn_procs(load_file(weights_file))
    return sd_pipe
# Define parameters
# The original "runwayml/stable-diffusion-v1-5" repository was removed from the
# Hugging Face Hub, so from_pretrained() on it fails and the app cannot start.
# The same SD v1.5 weights are published under this maintained repo id.
base_model_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
# Hub repo and file name of the floor-plan LoRA applied on top of the base model.
lora_repo_id = "maria26/Floor_Plan_LoRA"
lora_filename = "model.safetensors"
# Load the pipeline once at import time; predict() reads this module-level
# `pipeline` on every call instead of rebuilding it.
pipeline = load_pipeline(base_model_path, lora_repo_id, lora_filename)
def predict(prompt):
    """Generate an image for ``prompt`` using the module-level ``pipeline``.

    Args:
        prompt: Text prompt forwarded unchanged to the pipeline call.

    Returns:
        The first element of the pipeline output's ``images`` sequence.
    """
    output = pipeline(prompt)
    first_image = output.images[0]
    return first_image