junnyu committed on
Commit
13af955
1 Parent(s): eecac3d

Update pipeline.py

Files changed (1)
  1. pipeline.py +114 -18
pipeline.py CHANGED
@@ -16,7 +16,11 @@
 # modified from https://github.com/AUTOMATIC1111/stable-diffusion-webui
 # Here is the AGPL-3.0 license https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/LICENSE.txt
 
+import copy
 import inspect
+import os
+import os.path
+import shutil
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Union
 
@@ -24,6 +28,7 @@ import paddle
 import paddle.nn as nn
 import PIL
 import PIL.Image
+from huggingface_hub.file_download import _request_wrapper, hf_raise_for_status
 
 from paddlenlp.transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from ppdiffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
@@ -35,7 +40,9 @@ from ppdiffusers.pipelines.stable_diffusion.safety_checker import (
 from ppdiffusers.schedulers import KarrasDiffusionSchedulers
 from ppdiffusers.utils import (
     PIL_INTERPOLATION,
+    PPDIFFUSERS_CACHE,
     logging,
+    ppdiffusers_url_download,
     randn_tensor,
     safetensors_load,
     smart_load,
@@ -43,6 +50,55 @@ from ppdiffusers.utils import (
 )
 
 
+def get_civitai_download_url(display_url, url_prefix="https://civitai.com"):
+    if "api/download" in display_url:
+        return display_url
+    import bs4
+    import requests
+
+    headers = {
+        "User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/63.0.3239.132 Safari/537.36 QIHU 360SE"
+    }
+    r = requests.get(display_url, headers=headers)
+    soup = bs4.BeautifulSoup(r.text, "lxml")
+    download_url = None
+    for a in soup.find_all("a", href=True):
+        if "Download" in str(a):
+            download_url = url_prefix + a["href"].split("?")[0]
+            break
+    return download_url
+
+
+def http_file_name(
+    url: str,
+    *,
+    proxies=None,
+    headers: Optional[Dict[str, str]] = None,
+    timeout=10.0,
+    max_retries=0,
+):
+    """
+    Get a remote file name.
+    """
+    headers = copy.deepcopy(headers) or {}
+    r = _request_wrapper(
+        method="GET",
+        url=url,
+        stream=True,
+        proxies=proxies,
+        headers=headers,
+        timeout=timeout,
+        max_retries=max_retries,
+    )
+    hf_raise_for_status(r)
+    displayed_name = url
+    content_disposition = r.headers.get("Content-Disposition")
+    if content_disposition is not None and "filename=" in content_disposition:
+        # Means file is on CDN
+        displayed_name = content_disposition.split("filename=")[-1]
+    return displayed_name
+
+
 @paddle.no_grad()
 def load_lora(
     pipeline,
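The two new helpers chain naturally: `get_civitai_download_url` scrapes a Civitai model page for its `api/download` link, and `http_file_name` recovers the served filename from the `Content-Disposition` header. A minimal usage sketch (the page URL is a hypothetical placeholder; network access plus the `requests`, `bs4`, and `lxml` packages are assumed):

```python
# Hypothetical Civitai model page; replace with a real one.
page_url = "https://civitai.com/models/00000/some-lora"

download_url = get_civitai_download_url(page_url)
if download_url is not None:
    # Content-Disposition values are usually quoted, hence the .strip('"')
    # applied at the call sites below.
    print(http_file_name(download_url).strip('"'))
```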
@@ -168,6 +224,9 @@ class WebUIStableDiffusionControlNetPipeline(DiffusionPipeline):
     enable_emphasis = True
     comma_padding_backtrack = 20
 
+    LORA_DIR = os.path.join(PPDIFFUSERS_CACHE, "lora")
+    TI_DIR = os.path.join(PPDIFFUSERS_CACHE, "textual_inversion")
+
     def __init__(
         self,
         vae: AutoencoderKL,
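Both class-level cache directories hang off the ppdiffusers cache root. A quick sketch to see where downloads will land (the `~/.cache/ppdiffusers` default in the comment is an assumption; `PPDIFFUSERS_CACHE` decides the actual root):

```python
import os

from ppdiffusers.utils import PPDIFFUSERS_CACHE

# Typically resolves to something like ~/.cache/ppdiffusers/lora and
# ~/.cache/ppdiffusers/textual_inversion, depending on the cache root.
print(os.path.join(PPDIFFUSERS_CACHE, "lora"))
print(os.path.join(PPDIFFUSERS_CACHE, "textual_inversion"))
```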
@@ -232,7 +291,17 @@ class WebUIStableDiffusionControlNetPipeline(DiffusionPipeline):
         ]
         self.weights_has_changed = False
 
-    def add_ti_embedding_dir(self, embeddings_dir):
+        # register_state_dict_hook to fix text_encoder, when we save_pretrained text model.
+        def map_to(state_dict, *args, **kwargs):
+            if "text_model.token_embedding.wrapped.weight" in state_dict:
+                state_dict["text_model.token_embedding.weight"] = state_dict.pop(
+                    "text_model.token_embedding.wrapped.weight"
+                )
+            return state_dict
+
+        self.text_encoder.register_state_dict_hook(map_to)
+
+    def add_ti_embedding_dir(self, embeddings_dir=None):
         self.sj.embedding_db.add_embedding_dir(embeddings_dir)
         self.sj.embedding_db.load_textual_inversion_embeddings()
 
@@ -240,6 +309,30 @@ class WebUIStableDiffusionControlNetPipeline(DiffusionPipeline):
         self.sj.embedding_db.clear_embedding_dirs()
         self.sj.embedding_db.load_textual_inversion_embeddings(True)
 
+    def download_civitai_lora_file(self, url):
+        if os.path.isfile(url):
+            dst = os.path.join(self.LORA_DIR, os.path.basename(url))
+            shutil.copyfile(url, dst)
+            return dst
+
+        download_url = get_civitai_download_url(url) or url
+        file_path = ppdiffusers_url_download(
+            download_url, cache_dir=self.LORA_DIR, filename=http_file_name(download_url).strip('"')
+        )
+        return file_path
+
+    def download_civitai_ti_file(self, url):
+        if os.path.isfile(url):
+            dst = os.path.join(self.TI_DIR, os.path.basename(url))
+            shutil.copyfile(url, dst)
+            return dst
+
+        download_url = get_civitai_download_url(url) or url
+        file_path = ppdiffusers_url_download(
+            download_url, cache_dir=self.TI_DIR, filename=http_file_name(download_url).strip('"')
+        )
+        return file_path
+
     def change_scheduler(self, scheduler_type="ddim"):
         self.switch_scheduler(scheduler_type)
 
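With `download_civitai_lora_file` in place, a typical flow is to fetch the weights once and then activate them from the prompt. A usage sketch, assuming an instantiated `pipe` of this pipeline class, a prepared `control_image`, and a hypothetical Civitai URL (a local file path also works and is simply copied into the cache):

```python
# Hypothetical URL; real ones come from civitai.com model pages.
pipe.download_civitai_lora_file("https://civitai.com/models/00000/some-lora")

# The LoRA is activated per call via the webui prompt syntax, using the
# downloaded file's stem as the name and 0.8 as the merge ratio.
image = pipe(
    prompt="masterpiece, best quality, <lora:some-lora:0.8>",
    image=control_image,
).images[0]
```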
 
@@ -507,7 +600,6 @@ class WebUIStableDiffusionControlNetPipeline(DiffusionPipeline):
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         clip_skip: int = 1,
         controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
-        lora_dir: str = "./loras",
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -571,8 +663,6 @@ class WebUIStableDiffusionControlNetPipeline(DiffusionPipeline):
                 The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
                 to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
                 corresponding scale as a list.
-            lora_dir (`str`, *optional*):
-                Path to lora which we want to load.
         Examples:
 
         Returns:
@@ -582,6 +672,8 @@ class WebUIStableDiffusionControlNetPipeline(DiffusionPipeline):
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
+        self.add_ti_embedding_dir(self.TI_DIR)
+
         try:
             # 0. Default height and width to unet
             height, width = self._default_height_width(height, width, image)
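Registering `TI_DIR` at the top of every call means textual-inversion files placed there (e.g. via `download_civitai_ti_file`) are picked up without re-instantiating the pipeline. In the webui convention this file ports, the embedding is then triggered by its filename stem in the prompt; a sketch under the same `pipe`/`control_image` assumptions as above:

```python
# Hypothetical URL; the file lands in pipe.TI_DIR.
pipe.download_civitai_ti_file("https://civitai.com/models/00000/some-embedding")

# "some-embedding" (the file stem) now works as a trigger token.
image = pipe(prompt="a photo of some-embedding", image=control_image).images[0]
```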
@@ -613,19 +705,23 @@ class WebUIStableDiffusionControlNetPipeline(DiffusionPipeline):
 
         prompts, extra_network_data = parse_prompts([prompt])
 
-        if lora_dir is not None and os.path.exists(lora_dir):
-            lora_mapping = {p.stem: p.absolute() for p in Path(lora_dir).glob("*.safetensors")}
-            for params in extra_network_data["lora"]:
-                assert len(params.items) > 0
-                name = params.items[0]
-                if name in lora_mapping:
-                    ratio = float(params.items[1]) if len(params.items) > 1 else 1.0
-                    lora_state_dict = smart_load(lora_mapping[name], map_location=paddle.get_device())
-                    self.weights_has_changed = True
-                    load_lora(self, state_dict=lora_state_dict, ratio=ratio)
-                    del lora_state_dict
-                else:
-                    print(f"We can't find lora weight: {name}! Please make sure that exists!")
+        if self.LORA_DIR is not None:
+            if os.path.exists(self.LORA_DIR):
+                lora_mapping = {p.stem: p.absolute() for p in Path(self.LORA_DIR).glob("*.safetensors")}
+                for params in extra_network_data["lora"]:
+                    assert len(params.items) > 0
+                    name = params.items[0]
+                    if name in lora_mapping:
+                        ratio = float(params.items[1]) if len(params.items) > 1 else 1.0
+                        lora_state_dict = smart_load(lora_mapping[name], map_location=paddle.get_device())
+                        self.weights_has_changed = True
+                        load_lora(self, state_dict=lora_state_dict, ratio=ratio)
+                        del lora_state_dict
+                    else:
+                        print(f"We can't find lora weight: {name}! Please make sure that exists!")
+            else:
+                if len(extra_network_data["lora"]) > 0:
+                    print(f"{self.LORA_DIR} not exists, so we cant load loras!")
 
         self.sj.clip.CLIP_stop_at_last_layers = clip_skip
         # 3. Encode input prompt
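For context, `parse_prompts` strips webui-style `<lora:...>` tags out of the prompt and returns them in `extra_network_data`; each entry's `params.items` is the colon-separated payload of one tag. A minimal sketch of the name/ratio split the loop performs (the tag syntax is the AUTOMATIC1111 convention this file ports; the values are hypothetical):

```python
# "<lora:some-lora:0.8>" in a prompt yields items like:
items = ["some-lora", "0.8"]

name = items[0]                                      # stem of the .safetensors file
ratio = float(items[1]) if len(items) > 1 else 1.0   # merge ratio defaults to 1.0
print(name, ratio)  # some-lora 0.8
```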
@@ -1808,7 +1904,7 @@ class EmbeddingDatabase:
         self.previously_displayed_embeddings = ()
 
     def add_embedding_dir(self, path):
-        if path is not None:
+        if path is not None and path not in self.embedding_dirs:
             self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
 
     def clear_embedding_dirs(self):
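The extra `path not in self.embedding_dirs` guard makes directory registration idempotent, which matters now that `__call__` registers `TI_DIR` on every invocation. The guard pattern in isolation (a standalone sketch, with `object()` standing in for `DirWithTextualInversionEmbeddings`):

```python
embedding_dirs = {}

def add_embedding_dir(path):
    if path is not None and path not in embedding_dirs:
        embedding_dirs[path] = object()  # stand-in for DirWithTextualInversionEmbeddings

add_embedding_dir("~/.cache/ppdiffusers/textual_inversion")
add_embedding_dir("~/.cache/ppdiffusers/textual_inversion")  # no-op on repeat
print(len(embedding_dirs))  # 1
```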
 