from __future__ import annotations
import gc
import pathlib
import spaces
import gradio as gr
import PIL.Image
import torch
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import ModelCard
from blora_utils import BLOCKS, filter_lora, scale_lora
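
# `filter_lora` and `scale_lora` come from the accompanying blora_utils
# module. A rough sketch of their assumed behaviour (not necessarily the
# canonical implementation): `filter_lora(sd, blocks)` keeps only the entries
# of the LoRA state dict whose keys mention one of the named UNet blocks,
# e.g. {k: v for k, v in sd.items() if any(b in k for b in blocks)}, and
# `scale_lora(sd, alpha)` multiplies every tensor by a scalar strength,
# e.g. {k: v * alpha for k, v in sd.items()}.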


class InferencePipeline:
    def __init__(self, hf_token: str | None = None):
        self.hf_token = hf_token
        self.base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            self.base_model_id,
            torch_dtype=torch.float16,
            use_auth_token=self.hf_token)
        self.content_lora_model_id = None
        self.style_lora_model_id = None

    def clear(self) -> None:
        self.content_lora_model_id = None
        self.style_lora_model_id = None
        del self.pipe
        self.pipe = None
        torch.cuda.empty_cache()
        gc.collect()

    def load_b_lora_to_unet(self,
                            content_lora_model_id: str,
                            style_lora_model_id: str,
                            content_alpha: float,
                            style_alpha: float) -> None:
        try:
            # Get the content B-LoRA state dict, keep only the content blocks
            # and scale it by the requested strength.
            if content_lora_model_id and content_lora_model_id != 'None':
                content_B_LoRA_sd, _ = self.pipe.lora_state_dict(content_lora_model_id,
                                                                 use_auth_token=self.hf_token)
                content_B_LoRA = filter_lora(content_B_LoRA_sd, BLOCKS['content'])
                content_B_LoRA = scale_lora(content_B_LoRA, content_alpha)
            else:
                content_B_LoRA = {}

            # Same for the style B-LoRA, restricted to the style blocks.
            if style_lora_model_id and style_lora_model_id != 'None':
                style_B_LoRA_sd, _ = self.pipe.lora_state_dict(style_lora_model_id,
                                                               use_auth_token=self.hf_token)
                style_B_LoRA = filter_lora(style_B_LoRA_sd, BLOCKS['style'])
                style_B_LoRA = scale_lora(style_B_LoRA, style_alpha)
            else:
                style_B_LoRA = {}

            # Merge the two state dicts and load them into the UNet.
            res_lora = {**content_B_LoRA, **style_B_LoRA}
            self.pipe.load_lora_into_unet(res_lora, None, self.pipe.unet)
        except Exception as e:
            raise type(e)(f'failed to load_b_lora_to_unet: {e}') from e

    @staticmethod
    def check_if_model_is_local(lora_model_id: str) -> bool:
        return pathlib.Path(lora_model_id).exists()

    @staticmethod
    def get_model_card(model_id: str, hf_token: str | None = None) -> ModelCard:
        if InferencePipeline.check_if_model_is_local(model_id):
            card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
        else:
            card_path = model_id
        return ModelCard.load(card_path, token=hf_token)

    @staticmethod
    def get_base_model_info(lora_model_id: str, hf_token: str | None = None) -> str:
        card = InferencePipeline.get_model_card(lora_model_id, hf_token)
        return card.data.base_model

    def load_pipe(self,
                  content_lora_model_id: str,
                  style_lora_model_id: str,
                  content_alpha: float,
                  style_alpha: float) -> None:
        # Only the model ids are compared here, so changing the alphas alone
        # does not trigger a reload.
        if (content_lora_model_id == self.content_lora_model_id
                and style_lora_model_id == self.style_lora_model_id):
            return
        self.pipe.unload_lora_weights()
        self.load_b_lora_to_unet(content_lora_model_id, style_lora_model_id, content_alpha, style_alpha)
        self.content_lora_model_id = content_lora_model_id
        self.style_lora_model_id = style_lora_model_id

    @spaces.GPU
    def inference(self,
                  prompt: str,
                  seed: int,
                  n_steps: int,
                  guidance_scale: float,
                  num_images_per_prompt: int = 1) -> list[PIL.Image.Image]:
        if not torch.cuda.is_available():
            raise gr.Error('CUDA is not available.')
        self.pipe.to("cuda")
        generator = torch.Generator(device="cuda").manual_seed(seed)
        out = self.pipe(
            prompt,
            num_inference_steps=n_steps,
            guidance_scale=guidance_scale,
            generator=generator,
            num_images_per_prompt=num_images_per_prompt,
        )  # type: ignore
        return out.images

    def run(self,
            content_lora_model_id: str,
            style_lora_model_id: str,
            prompt: str,
            content_alpha: float,
            style_alpha: float,
            seed: int,
            n_steps: int,
            guidance_scale: float,
            num_images_per_prompt: int = 1) -> list[PIL.Image.Image]:
        self.load_pipe(content_lora_model_id, style_lora_model_id, content_alpha, style_alpha)
        return self.inference(
            prompt=prompt,
            seed=seed,
            n_steps=n_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
        )
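

# Minimal usage sketch (assumes a GPU-backed Space, since `inference` is
# decorated with @spaces.GPU, and uses hypothetical B-LoRA repo ids purely
# for illustration; not part of the original app):
if __name__ == '__main__':
    pipe = InferencePipeline()
    images = pipe.run(
        content_lora_model_id='user/content-b-lora',  # hypothetical repo id
        style_lora_model_id='user/style-b-lora',  # hypothetical repo id
        prompt='A photo of a dog in watercolor style',
        content_alpha=1.0,
        style_alpha=1.0,
        seed=0,
        n_steps=50,
        guidance_scale=7.5,
    )
    images[0].save('out.png')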