File size: 883 Bytes
92ef194
7d0642a
 
 
 
641d496
7d0642a
 
 
92ef194
 
 
7d0642a
 
 
8766af7
7d0642a
bc883e3
92ef194
 
 
 
8766af7
7d0642a
92ef194
8766af7
7d0642a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from huggingface_hub import hf_hub_download
from diffusers import StableDiffusionPipeline
from safetensors.torch import load_file
import torch

def load_pipeline(base_model_path, lora_repo_id, lora_filename):
    """Build a Stable Diffusion pipeline with LoRA weights loaded into its UNet.

    Args:
        base_model_path: Identifier or path handed to
            ``StableDiffusionPipeline.from_pretrained``.
        lora_repo_id: Hub repository that hosts the LoRA file.
        lora_filename: Name of the LoRA file inside that repository.

    Returns:
        The pipeline, after ``unet.load_attn_procs`` has been called with
        the tensors read from the downloaded LoRA file.
    """
    # The base pipeline is instantiated before the LoRA file is fetched.
    sd_pipe = StableDiffusionPipeline.from_pretrained(base_model_path)

    # Download the safetensors file and read its tensors in one expression.
    lora_weights = load_file(
        hf_hub_download(repo_id=lora_repo_id, filename=lora_filename)
    )
    sd_pipe.unet.load_attn_procs(lora_weights)

    return sd_pipe

# Model configuration: the base checkpoint and the LoRA file loaded on top.
base_model_path = "runwayml/stable-diffusion-v1-5"
lora_repo_id = "maria26/Floor_Plan_LoRA"
lora_filename = "model.safetensors"

# Build the shared pipeline at import time; predict() reads this global.
pipeline = load_pipeline(
    base_model_path=base_model_path,
    lora_repo_id=lora_repo_id,
    lora_filename=lora_filename,
)

def predict(prompt):
    """Run the module-level pipeline on ``prompt`` and return the first image."""
    return pipeline(prompt).images[0]