from __future__ import annotations

import gc
import pathlib

import gradio as gr
import PIL.Image
import spaces
import torch
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import ModelCard

from blora_utils import BLOCKS, filter_lora, scale_lora


class InferencePipeline:
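    """Wraps a Stable Diffusion XL pipeline and loads content/style B-LoRA weights into its UNet."""
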
    def __init__(self, hf_token: str | None = None):
        self.hf_token = hf_token
        self.base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
                self.base_model_id,
                torch_dtype=torch.float16,
                use_auth_token=self.hf_token)
        self.content_lora_model_id = None
        self.style_lora_model_id = None

    def clear(self) -> None:
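        """Release the loaded pipeline and free GPU memory."""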
        self.content_lora_model_id = None
        self.style_lora_model_id = None
        del self.pipe
        self.pipe = None
        torch.cuda.empty_cache()
        gc.collect()

    def load_b_lora_to_unet(self, content_lora_model_id: str, style_lora_model_id: str, content_alpha: float,
                            style_alpha: float) -> None:
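        """Filter each B-LoRA state dict to its content or style blocks, scale it by its alpha, merge, and load into the UNet."""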
        try:
            # Get Content B-LoRA SD
            if content_lora_model_id and content_lora_model_id != 'None':
                content_B_LoRA_sd, _ = self.pipe.lora_state_dict(content_lora_model_id, use_auth_token=self.hf_token)
                content_B_LoRA = filter_lora(content_B_LoRA_sd, BLOCKS['content'])
                content_B_LoRA = scale_lora(content_B_LoRA, content_alpha)
            else:
                content_B_LoRA = {}

            # Get Style B-LoRA SD
            if style_lora_model_id and style_lora_model_id != 'None':
                style_B_LoRA_sd, _ = self.pipe.lora_state_dict(style_lora_model_id, use_auth_token=self.hf_token)
                style_B_LoRA = filter_lora(style_B_LoRA_sd, BLOCKS['style'])
                style_B_LoRA = scale_lora(style_B_LoRA, style_alpha)
            else:
                style_B_LoRA = {}

            # Merge B-LoRAs SD
            res_lora = {**content_B_LoRA, **style_B_LoRA}

            # Load
            self.pipe.load_lora_into_unet(res_lora, None, self.pipe.unet)
        except Exception as e:
            raise type(e)(f'Failed to load B-LoRA weights into the UNet: {e}') from e

    @staticmethod
    def check_if_model_is_local(lora_model_id: str) -> bool:
        return pathlib.Path(lora_model_id).exists()

    @staticmethod
    def get_model_card(model_id: str,
                       hf_token: str | None = None) -> ModelCard:
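        """Load the model card from a local README.md when the model is on disk, otherwise from the Hub."""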
        if InferencePipeline.check_if_model_is_local(model_id):
            card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
        else:
            card_path = model_id
        return ModelCard.load(card_path, token=hf_token)

    @staticmethod
    def get_base_model_info(lora_model_id: str,
                            hf_token: str | None = None) -> str:
        card = InferencePipeline.get_model_card(lora_model_id, hf_token)
        return card.data.base_model

    def load_pipe(self, content_lora_model_id: str, style_lora_model_id: str, content_alpha: float,
                  style_alpha: float) -> None:
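        """Skip reloading when the requested content/style LoRA pair is already active; otherwise swap in the new weights."""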
        if content_lora_model_id == self.content_lora_model_id and style_lora_model_id == self.style_lora_model_id:
            return
        self.pipe.unload_lora_weights()

        self.load_b_lora_to_unet(content_lora_model_id, style_lora_model_id, content_alpha, style_alpha)

        self.content_lora_model_id = content_lora_model_id
        self.style_lora_model_id = style_lora_model_id

    @spaces.GPU
    def inference(self,
            prompt: str,
            seed: int,
            n_steps: int,
            guidance_scale: float,
            num_images_per_prompt: int = 1
    ) -> list[PIL.Image.Image]:
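        """Run the SDXL pipeline on the GPU and return the generated images."""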
        if not torch.cuda.is_available():
            raise gr.Error('CUDA is not available.')
        self.pipe.to("cuda")
        generator = torch.Generator(device="cuda").manual_seed(seed)
        out = self.pipe(
            prompt,
            num_inference_steps=n_steps,
            guidance_scale=guidance_scale,
            generator=generator,
            num_images_per_prompt=num_images_per_prompt,
        )  # type: ignore
        return out.images

    def run(
            self,
            content_lora_model_id: str,
            style_lora_model_id: str,
            prompt: str,
            content_alpha: float,
            style_alpha: float,
            seed: int,
            n_steps: int,
            guidance_scale: float,
            num_images_per_prompt: int = 1
    ) -> list[PIL.Image.Image]:
        
        self.load_pipe(content_lora_model_id, style_lora_model_id, content_alpha, style_alpha)
        
        return self.inference(
            prompt=prompt,
            seed=seed,
            n_steps=n_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
        )
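

# Minimal usage sketch, not part of the Space UI: the repo ids and prompt below are
# placeholders, and a CUDA device plus the referenced B-LoRA checkpoints are assumed.
if __name__ == '__main__':
    pipeline = InferencePipeline()
    images = pipeline.run(
        content_lora_model_id='<content-b-lora-repo-id>',  # placeholder repo id
        style_lora_model_id='<style-b-lora-repo-id>',  # placeholder repo id
        prompt='A dog in watercolor style',  # example prompt
        content_alpha=1.0,
        style_alpha=1.0,
        seed=0,
        n_steps=50,
        guidance_scale=7.5,
    )
    images[0].save('out.png')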