File size: 81,017 Bytes
6fa4eec 13af955 6fa4eec 13af955 a4a6c30 6fa4eec 13af955 6fa4eec 13af955 6fa4eec 13af955 6fa4eec a4a6c30 6fa4eec a4a6c30 13af955 a4a6c30 6fa4eec 13af955 6fa4eec a4a6c30 6fa4eec 13af955 6fa4eec 13af955 a4a6c30 6fa4eec a4a6c30 6fa4eec 441d2c3 6fa4eec a4a6c30 6fa4eec 13af955 a4a6c30 6fa4eec a4a6c30 6fa4eec a4a6c30 6fa4eec a4a6c30 441d2c3 13af955 a4a6c30 6fa4eec a4a6c30 6fa4eec a4a6c30 6fa4eec a4a6c30 6fa4eec a4a6c30 6fa4eec a4a6c30 6fa4eec a4a6c30 6fa4eec a4a6c30 6fa4eec a4a6c30 441d2c3 a4a6c30 6fa4eec 13af955 6fa4eec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 
1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 
1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934 1935 1936 1937 1938 1939 1940 1941 1942 1943 1944 1945 1946 1947 1948 1949 1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960 1961 1962 1963 1964 1965 1966 1967 1968 1969 1970 1971 1972 1973 1974 1975 1976 1977 1978 1979 1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992 1993 1994 1995 1996 1997 1998 1999 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2032 2033 2034 2035 2036 2037 2038 2039 2040 2041 2042 2043 2044 2045 2046 2047 2048 2049 2050 2051 2052 2053 2054 2055 2056 2057 2058 2059 2060 2061 2062 |
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# modified from https://github.com/AUTOMATIC1111/stable-diffusion-webui
# Here is the AGPL-3.0 license https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/LICENSE.txt
import copy
import inspect
import os
import os.path
import shutil
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import paddle
import paddle.nn as nn
import PIL
import PIL.Image
from huggingface_hub.file_download import _request_wrapper, hf_raise_for_status
from paddlenlp.transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ppdiffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ppdiffusers.pipelines.pipeline_utils import DiffusionPipeline
from ppdiffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from ppdiffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)
from ppdiffusers.schedulers import KarrasDiffusionSchedulers
from ppdiffusers.utils import (
    PIL_INTERPOLATION,
    PPDIFFUSERS_CACHE,
    logging,
    ppdiffusers_url_download,
    randn_tensor,
    safetensors_load,
    smart_load,
    torch_load,
)
def get_civitai_download_url(display_url, url_prefix="https://civitai.com", timeout=10.0):
    """Resolve a civitai model page URL to a direct download URL.

    URLs that already contain ``api/download`` are returned unchanged, without
    any network access. Otherwise the page is fetched and the first
    ``<a href=...>`` whose markup contains the text "Download" is used. Its
    query string is dropped. Relative links are prefixed with ``url_prefix``;
    absolute links are returned as-is.

    Args:
        display_url: civitai model page URL, or an ``api/download`` URL.
        url_prefix: origin prepended to relative ``href`` values.
        timeout: seconds to wait for the page request. Without a timeout a
            stalled server would block the caller forever.

    Returns:
        The download URL, or ``None`` when no matching link is found.
    """
    if "api/download" in display_url:
        return display_url
    # Imported here rather than at module level (presumably to keep bs4 /
    # requests optional for users who never resolve civitai pages).
    import bs4
    import requests

    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/63.0.3239.132 Safari/537.36 QIHU 360SE"
    }
    response = requests.get(display_url, headers=headers, timeout=timeout)
    soup = bs4.BeautifulSoup(response.text, "lxml")
    for anchor in soup.find_all("a", href=True):
        if "Download" in str(anchor):
            href = anchor["href"].split("?")[0]
            # An absolute link must not get the origin prepended a second time.
            if href.startswith(("http://", "https://")):
                return href
            return url_prefix + href
    return None
def http_file_name(
url: str,
*,
proxies=None,
headers: Optional[Dict[str, str]] = None,
timeout=10.0,
max_retries=0,
):
"""
Get a remote file name.
"""
headers = copy.deepcopy(headers) or {}
r = _request_wrapper(
method="GET",
url=url,
stream=True,
proxies=proxies,
headers=headers,
timeout=timeout,
max_retries=max_retries,
)
hf_raise_for_status(r)
displayed_name = url
content_disposition = r.headers.get("Content-Disposition")
if content_disposition is not None and "filename=" in content_disposition:
# Means file is on CDN
displayed_name = content_disposition.split("filename=")[-1]
return displayed_name
@paddle.no_grad()
def load_lora(
    pipeline,
    state_dict: dict,
    LORA_PREFIX_UNET: str = "lora_unet",
    LORA_PREFIX_TEXT_ENCODER: str = "lora_te",
    ratio: float = 1.0,
):
    """Merge a LoRA ``state_dict`` into ``pipeline``'s weights in place.

    Every ``lora_down`` key is paired with its ``lora_up`` key and optional
    ``alpha`` key. The low-rank delta ``up @ down`` is scaled by
    ``ratio * alpha / rank`` (or ``ratio * 1.0`` without alpha) and added to
    the target layer's ``weight``. The first time a layer is touched, its
    weight is cloned into ``backup_weights``.

    Args:
        pipeline: object exposing ``text_encoder`` and ``unet`` layers.
        state_dict: LoRA key -> paddle tensor. Keys containing "text" target
            the text encoder; all others target the UNet.
        LORA_PREFIX_UNET: prefix stripped from UNet keys before resolving the
            module path.
        LORA_PREFIX_TEXT_ENCODER: prefix stripped from text-encoder keys.
        ratio: global multiplier applied to every LoRA delta.

    Returns:
        The same ``pipeline`` object, modified in place.

    Raises:
        IndexError: when a key's module path cannot be resolved on the model.
    """
    ratio = float(ratio)
    # HF-transformers -> PaddleNLP sub-layer names for the text encoder.
    # Loop-invariant, so it is built once.
    hf_to_ppnlp = {
        "encoder": "transformer",
        "fc1": "linear1",
        "fc2": "linear2",
    }
    # Keys already consumed as part of a (down, up, alpha) triplet. A set keeps
    # the membership test O(1); a list made the whole loop quadratic.
    visited = set()
    for key in state_dict:
        # Only `lora_down` keys drive a merge; their partners are looked up below.
        if ".alpha" in key or ".lora_up" in key or key in visited:
            continue

        if "text" in key:
            tmp_layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            # "mlp" tokens are dropped (presumably the PaddleNLP layer holds
            # linear1/linear2 directly — confirm against the text encoder).
            layer_infos = [hf_to_ppnlp.get(info, info) for info in tmp_layer_infos if info != "mlp"]
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = pipeline.unet

        # Walk the module tree. Names that themselves contain "_" (e.g.
        # "proj_in") are rebuilt by gluing tokens together until a lookup succeeds.
        temp_name = layer_infos.pop(0)
        while True:
            try:
                # "to" is never resolved on its own and is always merged with
                # the next token (e.g. "to_q"). Presumably this is because
                # paddle layers have a `to()` method — confirm.
                if temp_name == "to":
                    raise ValueError()
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                else:
                    break
            except Exception:
                # Lookup failed: append the next token. When no tokens remain,
                # pop() raises IndexError, which propagates to the caller.
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        triplet_keys = [key, key.replace("lora_down", "lora_up"), key.replace("lora_down.weight", "alpha")]
        dtype: paddle.dtype = curr_layer.weight.dtype
        weight_down: paddle.Tensor = state_dict[triplet_keys[0]].cast(dtype)
        weight_up: paddle.Tensor = state_dict[triplet_keys[1]].cast(dtype)
        rank: float = float(weight_down.shape[0])
        if triplet_keys[2] in state_dict:
            alpha: float = state_dict[triplet_keys[2]].cast(dtype).item()
            scale: float = alpha / rank
        else:
            scale = 1.0
        # Keep the pristine weight so the merge can be undone later.
        if not hasattr(curr_layer, "backup_weights"):
            curr_layer.backup_weights = curr_layer.weight.clone()
        if len(weight_down.shape) == 4:
            if weight_down.shape[2:4] == [1, 1]:
                # conv2d 1x1: plain matrix product on squeezed kernels
                curr_layer.weight.copy_(
                    curr_layer.weight
                    + ratio
                    * paddle.matmul(weight_up.squeeze([-1, -2]), weight_down.squeeze([-1, -2])).unsqueeze([-1, -2])
                    * scale,
                    True,
                )
            else:
                # conv2d 3x3: compose the two kernels via a convolution
                curr_layer.weight.copy_(
                    curr_layer.weight
                    + ratio
                    * paddle.nn.functional.conv2d(weight_down.transpose([1, 0, 2, 3]), weight_up).transpose(
                        [1, 0, 2, 3]
                    )
                    * scale,
                    True,
                )
        else:
            # linear: the delta is transposed because the target weight is
            # laid out transposed relative to up @ down.
            curr_layer.weight.copy_(curr_layer.weight + ratio * paddle.matmul(weight_up, weight_down).T * scale, True)
        visited.update(triplet_keys)
    return pipeline
# Module-level logger from ppdiffusers' logging utilities; the pipeline class
# uses it to warn when the safety checker is disabled.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
class WebUIStableDiffusionControlNetPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
controlnet ([`ControlNetModel`]):
Provides additional conditioning to the unet during the denoising process.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
enable_emphasis = True
comma_padding_backtrack = 20
LORA_DIR = os.path.join(PPDIFFUSERS_CACHE, "lora")
TI_DIR = os.path.join(PPDIFFUSERS_CACHE, "textual_inversion")
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    controlnet: ControlNetModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPFeatureExtractor,
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and set up webui-style prompt handling.

    Args:
        vae, text_encoder, tokenizer, unet, controlnet, scheduler,
        safety_checker, feature_extractor: pipeline components, registered on
            the pipeline via ``register_modules``.
        requires_safety_checker: stored in the pipeline config. When True and
            ``safety_checker`` is None, a warning is logged.

    Raises:
        ValueError: if ``safety_checker`` is given without ``feature_extractor``.
    """
    super().__init__()
    # Disabling the checker is allowed, but it is flagged loudly when the
    # config says one is required.
    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. PaddleNLP team, diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )
    # The safety checker consumes feature_extractor output, so it cannot run without one.
    if safety_checker is not None and feature_extractor is None:
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )
    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # Pixel-to-latent downsampling factor: 2 ** (number of VAE blocks - 1).
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
    # custom data
    # Webui-style prompt machinery. `self.sj` exposes `.clip` (used by
    # `_encode_prompt`) and `.embedding_db` (used by the textual-inversion
    # helpers). Both helper classes are defined elsewhere in this file.
    clip_model = FrozenCLIPEmbedder(text_encoder, tokenizer)
    self.sj = StableDiffusionModelHijack(clip_model)
    # Snapshot of the initial scheduler config. `switch_scheduler` always
    # rebuilds from it. The (misspelled) attribute name is read by other
    # methods, so it must stay as is.
    self.orginal_scheduler_config = self.scheduler.config
    # Names accepted by `switch_scheduler`; also echoed in its error message.
    self.supported_scheduler = [
        "pndm",
        "lms",
        "euler",
        "euler-ancestral",
        "dpm-multi",
        "dpm-single",
        "unipc-multi",
        "ddim",
        "ddpm",
        "deis-multi",
        "heun",
        "kdpm2-ancestral",
        "kdpm2",
    ]
    # NOTE(review): not read in this constructor; presumably toggled elsewhere
    # when weights are modified (e.g. LoRA merges) — confirm.
    self.weights_has_changed = False

    # register_state_dict_hook to fix text_encoder, when we save_pretrained text model.
    # The hijack apparently wraps the token embedding, so its weight appears
    # under "...token_embedding.wrapped.weight". The hook renames it back to
    # the plain key in the produced state dict.
    def map_to(state_dict, *args, **kwargs):
        if "text_model.token_embedding.wrapped.weight" in state_dict:
            state_dict["text_model.token_embedding.weight"] = state_dict.pop(
                "text_model.token_embedding.wrapped.weight"
            )
        return state_dict

    self.text_encoder.register_state_dict_hook(map_to)
def add_ti_embedding_dir(self, embeddings_dir=None):
    """Register a textual-inversion embeddings directory and reload all embeddings."""
    embedding_db = self.sj.embedding_db
    embedding_db.add_embedding_dir(embeddings_dir)
    embedding_db.load_textual_inversion_embeddings()
def clear_ti_embedding(self):
    """Forget all textual-inversion directories, then reload the embedding database.

    The reload is called with ``True`` as its single positional argument,
    exactly as before.
    """
    embedding_db = self.sj.embedding_db
    embedding_db.clear_embedding_dirs()
    embedding_db.load_textual_inversion_embeddings(True)
def download_civitai_lora_file(self, url):
    """Make a LoRA file available under ``self.LORA_DIR`` and return its local path.

    An existing local file is copied into ``LORA_DIR``. The directory is
    created if it is missing, and a file that already lives there is not
    copied onto itself. Anything else is treated as a URL: civitai pages are
    resolved via ``get_civitai_download_url`` (falling back to ``url``), and
    the file is fetched with ``ppdiffusers_url_download`` under the name the
    server advertises.
    """
    if os.path.isfile(url):
        # copyfile fails with FileNotFoundError if the destination dir is absent.
        os.makedirs(self.LORA_DIR, exist_ok=True)
        dst = os.path.join(self.LORA_DIR, os.path.basename(url))
        # Copying a file onto itself raises shutil.SameFileError.
        if not (os.path.exists(dst) and os.path.samefile(url, dst)):
            shutil.copyfile(url, dst)
        return dst
    download_url = get_civitai_download_url(url) or url
    file_path = ppdiffusers_url_download(
        download_url, cache_dir=self.LORA_DIR, filename=http_file_name(download_url).strip('"')
    )
    return file_path
def download_civitai_ti_file(self, url):
    """Make a textual-inversion file available under ``self.TI_DIR`` and return its local path.

    An existing local file is copied into ``TI_DIR``. The directory is created
    if it is missing, and a file that already lives there is not copied onto
    itself. Anything else is treated as a URL: civitai pages are resolved via
    ``get_civitai_download_url`` (falling back to ``url``), and the file is
    fetched with ``ppdiffusers_url_download`` under the name the server
    advertises.
    """
    if os.path.isfile(url):
        # copyfile fails with FileNotFoundError if the destination dir is absent.
        os.makedirs(self.TI_DIR, exist_ok=True)
        dst = os.path.join(self.TI_DIR, os.path.basename(url))
        # Copying a file onto itself raises shutil.SameFileError.
        if not (os.path.exists(dst) and os.path.samefile(url, dst)):
            shutil.copyfile(url, dst)
        return dst
    download_url = get_civitai_download_url(url) or url
    file_path = ppdiffusers_url_download(
        download_url, cache_dir=self.TI_DIR, filename=http_file_name(download_url).strip('"')
    )
    return file_path
def change_scheduler(self, scheduler_type="ddim"):
    """Alias that forwards to ``switch_scheduler``."""
    self.switch_scheduler(scheduler_type=scheduler_type)
def switch_scheduler(self, scheduler_type="ddim"):
    """Replace ``self.scheduler`` with a freshly built scheduler of another type.

    Every scheduler is created with ``from_config`` from
    ``self.orginal_scheduler_config`` plus a few per-type overrides, so each
    switch starts from the same base config.

    Args:
        scheduler_type: case-insensitive name, one of ``self.supported_scheduler``.

    Raises:
        ValueError: if ``scheduler_type`` names no known scheduler.
    """
    scheduler_type = scheduler_type.lower()
    from ppdiffusers import (
        DDIMScheduler,
        DDPMScheduler,
        DEISMultistepScheduler,
        DPMSolverMultistepScheduler,
        DPMSolverSinglestepScheduler,
        EulerAncestralDiscreteScheduler,
        EulerDiscreteScheduler,
        HeunDiscreteScheduler,
        KDPM2AncestralDiscreteScheduler,
        KDPM2DiscreteScheduler,
        LMSDiscreteScheduler,
        PNDMScheduler,
        UniPCMultistepScheduler,
    )

    # name -> (scheduler class, extra keyword overrides for from_config)
    registry = {
        "pndm": (PNDMScheduler, {"skip_prk_steps": True}),
        "lms": (LMSDiscreteScheduler, {}),
        "heun": (HeunDiscreteScheduler, {}),
        "euler": (EulerDiscreteScheduler, {}),
        "euler-ancestral": (EulerAncestralDiscreteScheduler, {}),
        "dpm-multi": (DPMSolverMultistepScheduler, {}),
        "dpm-single": (DPMSolverSinglestepScheduler, {}),
        "kdpm2-ancestral": (KDPM2AncestralDiscreteScheduler, {}),
        "kdpm2": (KDPM2DiscreteScheduler, {}),
        "unipc-multi": (UniPCMultistepScheduler, {}),
        "ddim": (DDIMScheduler, {"steps_offset": 1, "clip_sample": False, "set_alpha_to_one": False}),
        "ddpm": (DDPMScheduler, {}),
        "deis-multi": (DEISMultistepScheduler, {}),
    }
    if scheduler_type not in registry:
        raise ValueError(
            f"Scheduler of type {scheduler_type} doesn't exist! Please choose in {self.supported_scheduler}!"
        )
    scheduler_cls, overrides = registry[scheduler_type]
    self.scheduler = scheduler_cls.from_config(self.orginal_scheduler_config, **overrides)
@paddle.no_grad()
def _encode_prompt(
    self,
    prompt: str,
    do_classifier_free_guidance: Union[bool, float] = 7.5,
    negative_prompt: Optional[str] = None,
    num_inference_steps: int = 50,
):
    """Build the conditional and unconditional conditioning for ``prompt``.

    Args:
        prompt: positive prompt, passed to ``get_multicond_learned_conditioning``.
        do_classifier_free_guidance: only its truthiness is used. When truthy,
            the unconditional conditioning is computed as well.
        negative_prompt: negative prompt used for the unconditional branch.
            ``None`` is treated as the empty prompt ``""``.
        num_inference_steps: forwarded to both conditioning helpers.

    Returns:
        ``(c, uc)``. ``uc`` is ``None`` when guidance is off.

    Raises:
        TypeError: if guidance is on and ``negative_prompt`` is neither ``None``
            nor a ``str``.
    """
    if do_classifier_free_guidance:
        # `check_inputs` accepts negative_prompt=None, so handle it here
        # instead of failing an assert (which -O would strip anyway).
        if negative_prompt is None:
            negative_prompt = ""
        elif not isinstance(negative_prompt, str):
            raise TypeError(f"`negative_prompt` has to be of type `str` but is {type(negative_prompt)}")
        uc = get_learned_conditioning(self.sj.clip, [negative_prompt], num_inference_steps)
    else:
        uc = None
    c = get_multicond_learned_conditioning(self.sj.clip, prompt, num_inference_steps)
    return c, uc
def run_safety_checker(self, image, dtype):
    """Run the optional safety checker over decoded images.

    Returns ``(image, has_nsfw_concept)``. Without a registered safety checker
    the images come back untouched and the flag is ``None``. Otherwise the
    feature extractor's ``pixel_values`` are cast to ``dtype`` and passed to
    the checker together with the images.
    """
    if self.safety_checker is None:
        return image, None
    extracted = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pd")
    clip_input = extracted.pixel_values.cast(dtype)
    checked_image, has_nsfw_concept = self.safety_checker(images=image, clip_input=clip_input)
    return checked_image, has_nsfw_concept
def decode_latents(self, latents):
    """Decode latents with the VAE into a float32 numpy array with values in [0, 1].

    Latents are divided by ``vae.config.scaling_factor`` before decoding. The
    decoded tensor is transposed with axes ``[0, 2, 3, 1]`` (channels-last,
    assuming the VAE emits NCHW — confirm). It is always cast to float32,
    which is cheap and also works for bfloat16 inputs.
    """
    unscaled = 1 / self.vae.config.scaling_factor * latents
    decoded = self.vae.decode(unscaled).sample
    # map from [-1, 1] to [0, 1]
    normalized = (decoded / 2 + 0.5).clip(0, 1)
    channels_last = normalized.transpose([0, 2, 3, 1])
    return channels_last.cast("float32").numpy()
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect optional keyword arguments for ``self.scheduler.step``.

    Schedulers have different ``step`` signatures, so ``eta`` and ``generator``
    are only included when ``step`` declares a parameter of that name. ``eta``
    is the η of the DDIM paper (https://arxiv.org/abs/2010.02502), expected in
    [0, 1]; other schedulers ignore it.
    """
    step_params = set(inspect.signature(self.scheduler.step).parameters.keys())
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in step_params}
def check_inputs(
    self,
    prompt,
    image,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    controlnet_conditioning_scale=1.0,
):
    """Validate the arguments of ``__call__`` before any work is done.

    Raises:
        ValueError: if ``height``/``width`` are not multiples of 8, if
            ``callback_steps`` is not a positive int, if ``prompt`` or
            ``negative_prompt`` is not a ``str`` (when given), or if
            ``check_image`` rejects the batch sizes.
        TypeError: if ``self.controlnet`` is not a ``ControlNetModel``, if
            ``check_image`` rejects the image type, or if
            ``controlnet_conditioning_scale`` is not a float/list/tuple.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
    # None is not an int, so this also rejects callback_steps=None.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )
    if prompt is not None and not isinstance(prompt, str):
        raise ValueError(f"`prompt` has to be of type `str` but is {type(prompt)}")
    if negative_prompt is not None and not isinstance(negative_prompt, str):
        raise ValueError(f"`negative_prompt` has to be of type `str` but is {type(negative_prompt)}")
    # Only a single ControlNetModel is supported. Raise explicitly instead of
    # `assert False`, which is stripped under `python -O`.
    if not isinstance(self.controlnet, ControlNetModel):
        raise TypeError(f"`controlnet` must be a `ControlNetModel`, but is {type(self.controlnet)}.")
    # Check `image`
    self.check_image(image, prompt)
    # Check `controlnet_conditioning_scale`
    if not isinstance(controlnet_conditioning_scale, (float, list, tuple)):
        raise TypeError(
            "For single controlnet: `controlnet_conditioning_scale` must be type `float, list(float) or tuple(float)`."
        )
def check_image(self, image, prompt):
    """Validate the ControlNet conditioning ``image`` and its batch size.

    ``image`` must be a PIL image, a paddle tensor, or a non-empty list of
    either (the list type is judged by its first element). When ``prompt`` is
    a ``str`` (batch 1) or a ``list``, an image batch larger than 1 must match
    the prompt batch. When there is no prompt to compare against (e.g.
    ``prompt`` is ``None``), the batch check is skipped.

    Raises:
        TypeError: for unsupported image types, including an empty list.
        ValueError: for mismatched image / prompt batch sizes.
    """
    image_is_pil = isinstance(image, PIL.Image.Image)
    image_is_tensor = isinstance(image, paddle.Tensor)
    # Guard before peeking at image[0]: an empty list used to raise IndexError.
    is_nonempty_list = isinstance(image, list) and len(image) > 0
    image_is_pil_list = is_nonempty_list and isinstance(image[0], PIL.Image.Image)
    image_is_tensor_list = is_nonempty_list and isinstance(image[0], paddle.Tensor)
    if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
        raise TypeError(
            "image must be one of PIL image, paddle tensor, list of PIL images, or list of paddle tensors"
        )
    if image_is_pil:
        image_batch_size = 1
    elif image_is_tensor:
        image_batch_size = image.shape[0]
    else:
        image_batch_size = len(image)
    if isinstance(prompt, str):
        prompt_batch_size = 1
    elif isinstance(prompt, list):
        prompt_batch_size = len(prompt)
    else:
        # No prompt batch to compare against. Previously this path left
        # prompt_batch_size unbound and crashed with UnboundLocalError.
        return
    if image_batch_size != 1 and image_batch_size != prompt_batch_size:
        raise ValueError(
            f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
        )
def prepare_image(self, image, width, height, dtype):
    """Convert the ControlNet conditioning image(s) into a ``dtype`` paddle tensor.

    * PIL image or list of PIL images: each is converted to RGB, resized to
      ``(width, height)`` with Lanczos resampling, stacked, scaled to
      ``[0, 1]`` as float32 and laid out as ``(N, C, H, W)``.
    * List of paddle tensors: concatenated along axis 0.
    * Paddle tensor: passed through unchanged.

    In every case the result is cast to ``dtype``. Requires the module-level
    ``import numpy as np``.
    """
    if not isinstance(image, paddle.Tensor):
        if isinstance(image, PIL.Image.Image):
            image = [image]
        if isinstance(image[0], PIL.Image.Image):
            images = []
            for image_ in image:
                image_ = image_.convert("RGB")
                image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
                image_ = np.array(image_)
                # add a leading batch axis so the frames can be concatenated
                image_ = image_[None, :]
                images.append(image_)
            image = np.concatenate(images, axis=0)
            image = np.array(image).astype(np.float32) / 255.0
            # NHWC -> NCHW
            image = image.transpose(0, 3, 1, 2)
            image = paddle.to_tensor(image)
        elif isinstance(image[0], paddle.Tensor):
            image = paddle.concat(image, axis=0)
    image = image.cast(dtype)
    return image
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None):
    """Return the initial latents, scaled by the scheduler's ``init_noise_sigma``.

    When ``latents`` is ``None``, fresh noise of shape
    ``[batch_size, num_channels_latents, height // vae_scale_factor,
    width // vae_scale_factor]`` is drawn with ``randn_tensor``. Otherwise the
    given latents are used as they are.

    Raises:
        ValueError: if ``generator`` is a list whose length is not ``batch_size``.
    """
    factor = self.vae_scale_factor
    shape = [batch_size, num_channels_latents, height // factor, width // factor]
    generator_count_mismatch = isinstance(generator, list) and len(generator) != batch_size
    if generator_count_mismatch:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )
    initial = randn_tensor(shape, generator=generator, dtype=dtype) if latents is None else latents
    # the scheduler defines the standard deviation of the initial noise
    return initial * self.scheduler.init_noise_sigma
def _default_height_width(self, height, width, image):
    """Fill in missing ``height`` / ``width`` from the conditioning image.

    Nested lists are unwrapped to their first element. PIL images supply
    ``.height`` / ``.width``. Paddle tensors are treated as NCHW, the layout
    ``prepare_image`` produces, so height is ``shape[2]`` and width is
    ``shape[3]``. Derived values are rounded down to a multiple of 8.
    Explicitly passed values are returned unchanged.
    """
    while isinstance(image, list):
        image = image[0]
    if height is None:
        if isinstance(image, PIL.Image.Image):
            height = image.height
        elif isinstance(image, paddle.Tensor):
            # NCHW: axis 2 is height (axes were previously swapped).
            height = image.shape[2]
        height = (height // 8) * 8  # round down to nearest multiple of 8
    if width is None:
        if isinstance(image, PIL.Image.Image):
            width = image.width
        elif isinstance(image, paddle.Tensor):
            # NCHW: axis 3 is width.
            width = image.shape[3]
        width = (width // 8) * 8  # round down to nearest multiple of 8
    return height, width
@paddle.no_grad()
def __call__(
    self,
    prompt: str = None,
    image: PIL.Image.Image = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: str = None,
    eta: float = 0.0,
    generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
    latents: Optional[paddle.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
    callback_steps: Optional[int] = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: int = 1,
    controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
    enable_lora: bool = True,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str`, *optional*):
            The prompt to guide the image generation. `<lora:name[:ratio]>` tags are stripped from the
            prompt; when `enable_lora` is True, the named `*.safetensors` files in `LORA_DIR` are loaded.
        image (`paddle.Tensor`, `PIL.Image.Image`, or a list of them):
            The ControlNet input condition. A `paddle.Tensor` is passed to the ControlNet as is (cast to
            the ControlNet dtype). PIL images are converted to RGB, resized to `width` x `height` and
            scaled to [0, 1].
        height (`int`, *optional*):
            The height in pixels of the generated image. Defaults to the height of `image`, rounded down
            to a multiple of 8.
        width (`int`, *optional*):
            The width in pixels of the generated image. Defaults to the width of `image`, rounded down to
            a multiple of 8.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled when `guidance_scale > 1`.
            Higher guidance scale encourages images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str`, *optional*):
            The prompt not to guide the image generation. Ignored when not using guidance
            (i.e. when `guidance_scale <= 1`).
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`paddle.Generator` or `List[paddle.Generator]`, *optional*):
            One or a list of paddle generator(s) to make generation deterministic.
        latents (`paddle.Tensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            `"pil"` returns PIL images. `"latent"` returns the final latents without decoding or safety
            checking. Any other value returns the decoded images as produced by `decode_latents`
            (presumably numpy arrays) after the safety checker.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            Called as `callback(i, t, latents)` on every step `i` with `i % callback_steps == 0`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. With the default, it is
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that, if specified, is passed along to the `AttnProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
        clip_skip (`int`, *optional*, defaults to 1):
            Assigned to `CLIP_stop_at_last_layers` of the prompt encoder. Values <= 1 use the
            last_hidden_state of the text encoder.
        controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
            The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
            to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
            corresponding scale as a list.
        enable_lora (`bool`, *optional*, defaults to `True`):
            Whether to load the LoRA weights referenced in the prompt. Patched weights are restored from
            their backups when the call finishes, even on error.

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
        When returning a tuple, the first element is a list with the generated images, and the second element is a
        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
        (nsfw) content, according to the `safety_checker`.

    Raises:
        ValueError: any exception raised during generation is re-raised as `ValueError`, chained to the
            original exception so its traceback is preserved.
    """
    self.add_ti_embedding_dir(self.TI_DIR)
    try:
        # 0. Default height and width from the conditioning image
        height, width = self._default_height_width(height, width, image)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            image,
            height,
            width,
            callback_steps,
            negative_prompt,
            controlnet_conditioning_scale,
        )
        # Only a single prompt per call is supported.
        batch_size = 1

        # 2. Prepare the ControlNet conditioning image
        image = self.prepare_image(
            image=image,
            width=width,
            height=height,
            dtype=self.controlnet.dtype,
        )

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # Strip `<name:...>` extra-network tags from the prompt and load any requested LoRA weights.
        prompts, extra_network_data = parse_prompts([prompt])
        if enable_lora and self.LORA_DIR is not None:
            if os.path.exists(self.LORA_DIR):
                lora_mapping = {p.stem: p.absolute() for p in Path(self.LORA_DIR).glob("*.safetensors")}
                for params in extra_network_data["lora"]:
                    assert len(params.items) > 0
                    name = params.items[0]
                    if name in lora_mapping:
                        ratio = float(params.items[1]) if len(params.items) > 1 else 1.0
                        lora_state_dict = smart_load(lora_mapping[name], map_location=paddle.get_device())
                        # Flag before patching so the `finally` block restores the weights even if
                        # load_lora fails part-way through.
                        self.weights_has_changed = True
                        load_lora(self, state_dict=lora_state_dict, ratio=ratio)
                        del lora_state_dict
                    else:
                        print(f"We can't find lora weight: {name}! Please make sure that exists!")
            else:
                if len(extra_network_data["lora"]) > 0:
                    print(f"{self.LORA_DIR} not exists, so we cant load loras!")
        self.sj.clip.CLIP_stop_at_last_layers = clip_skip

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self._encode_prompt(
            prompts,
            do_classifier_free_guidance,
            negative_prompt,
            num_inference_steps=num_inference_steps,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size,
            num_channels_latents,
            height,
            width,
            self.unet.dtype,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # The ControlNet input for a batched (uncond + cond) pass is the same on every step,
        # so it is built once here instead of inside the loop.
        batched_image = paddle.concat([image, image]) if do_classifier_free_guidance else None

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                step = i // self.scheduler.order
                do_batch = False
                # Pick the conditioning that applies at this step.
                conds_list, cond_tensor = reconstruct_multicond_batch(prompt_embeds, step)
                # Weight of the first composable sub-prompt; it scales the guidance term below.
                try:
                    weight = conds_list[0][0][1]
                except (IndexError, TypeError):
                    weight = 1.0
                if do_classifier_free_guidance:
                    uncond_tensor = reconstruct_cond_batch(negative_prompt_embeds, step)
                    # Uncond and cond can share one forward pass only if their token lengths match.
                    do_batch = cond_tensor.shape[1] == uncond_tensor.shape[1]

                # expand the latents if we are doing classifier free guidance in a single batch
                latent_model_input = paddle.concat([latents] * 2) if do_batch else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                if do_batch:
                    encoder_hidden_states = paddle.concat([uncond_tensor, cond_tensor])
                    down_block_res_samples, mid_block_res_sample = self.controlnet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=encoder_hidden_states,
                        controlnet_cond=batched_image,
                        conditioning_scale=controlnet_conditioning_scale,
                        return_dict=False,
                    )
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=encoder_hidden_states,
                        cross_attention_kwargs=cross_attention_kwargs,
                        down_block_additional_residuals=down_block_res_samples,
                        mid_block_additional_residual=mid_block_res_sample,
                    ).sample
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + weight * guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )
                else:
                    # Conditional pass (and, with guidance, a separate unconditional pass).
                    down_block_res_samples, mid_block_res_sample = self.controlnet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=cond_tensor,
                        controlnet_cond=image,
                        conditioning_scale=controlnet_conditioning_scale,
                        return_dict=False,
                    )
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=cond_tensor,
                        cross_attention_kwargs=cross_attention_kwargs,
                        down_block_additional_residuals=down_block_res_samples,
                        mid_block_additional_residual=mid_block_res_sample,
                    ).sample
                    if do_classifier_free_guidance:
                        down_block_res_samples, mid_block_res_sample = self.controlnet(
                            latent_model_input,
                            t,
                            encoder_hidden_states=uncond_tensor,
                            controlnet_cond=image,
                            conditioning_scale=controlnet_conditioning_scale,
                            return_dict=False,
                        )
                        noise_pred_uncond = self.unet(
                            latent_model_input,
                            t,
                            encoder_hidden_states=uncond_tensor,
                            cross_attention_kwargs=cross_attention_kwargs,
                            down_block_additional_residuals=down_block_res_samples,
                            mid_block_additional_residual=mid_block_res_sample,
                        ).sample
                        noise_pred = noise_pred_uncond + weight * guidance_scale * (noise_pred - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # update the progress bar and call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        elif output_type == "pil":
            # 8. Post-processing
            image = self.decode_latents(latents)
            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, self.unet.dtype)
            # 10. Convert to PIL
            image = self.numpy_to_pil(image)
        else:
            # 8. Post-processing
            image = self.decode_latents(latents)
            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, self.unet.dtype)

        if not return_dict:
            return (image, has_nsfw_concept)
        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
    except Exception as e:
        # Callers expect ValueError; chain the original so its type and traceback are not lost.
        raise ValueError(e) from e
    finally:
        # Undo LoRA patching so the next call starts from the base weights.
        if enable_lora and self.weights_has_changed:
            for model in (self.text_encoder, self.unet):
                for sub_layer in model.sublayers(include_self=True):
                    if hasattr(sub_layer, "backup_weights"):
                        sub_layer.weight.copy_(sub_layer.backup_weights, True)
            self.weights_has_changed = False
# clip.py
import math
from collections import namedtuple
class PromptChunk:
    """Token ids, per-token weight multipliers (e.g. 1.4 for ``(word:1.4)``) and textual-inversion
    embedding fixes for one chunk of a prompt.

    A short prompt fits in one PromptChunk; longer prompts need several. A finalized chunk normally
    holds 77 tokens: 75 prompt tokens plus one start token and one end token.
    """

    def __init__(self):
        # tokens:      token ids of the chunk
        # multipliers: one weight per entry in ``tokens``
        # fixes:       PromptChunkFix markers for embedding vectors spliced into the chunk
        self.tokens, self.multipliers, self.fixes = [], [], []
# Marker saying that a textual-inversion embedding's vectors have to be placed at ``offset`` in a
# prompt chunk. These objects are collected in PromptChunk.fixes, copied into
# FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and then applied during text encoding
# (presumably by a hijacked embedding layer).
PromptChunkFix = namedtuple("PromptChunkFix", "offset embedding")
class FrozenCLIPEmbedder(nn.Layer):
    """Encode text with a CLIP text encoder whose weights are (optionally) frozen.

    Args:
        text_encoder: CLIP text model called with ``input_ids``; its output must provide
            ``last_hidden_state``, ``pooler_output`` and, when requested, ``hidden_states``.
        tokenizer: tokenizer for ``text_encoder``. Inputs are padded/truncated to its ``model_max_length``.
        freeze: if True, put the encoder in eval mode and stop gradients for all parameters.
        layer: output to return. ``"last"`` is the final hidden state, ``"pooled"`` is the pooled
            output with a length-1 token axis, and ``"hidden"`` is ``hidden_states[layer_idx]``.
        layer_idx: index into ``hidden_states``; required when ``layer == "hidden"``.
    """

    LAYERS = ["last", "pooled", "hidden"]

    def __init__(self, text_encoder, tokenizer, freeze=True, layer="last", layer_idx=None):
        super().__init__()
        assert layer in self.LAYERS
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        """Put the text encoder in eval mode and disable gradients for every parameter."""
        self.text_encoder.eval()
        for param in self.parameters():
            # In Paddle, ``stop_gradient=True`` disables gradient tracking (PyTorch's
            # ``requires_grad=False``). The previous ``False`` left every parameter trainable.
            param.stop_gradient = True

    def forward(self, text):
        """Tokenize ``text`` (padded to ``model_max_length``) and return the selected encoder output."""
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            return_tensors="pd",
        )
        tokens = batch_encoding["input_ids"]
        outputs = self.text_encoder(input_ids=tokens, output_hidden_states=self.layer == "hidden", return_dict=True)
        if self.layer == "last":
            z = outputs.last_hidden_state
        elif self.layer == "pooled":
            # add a length-1 token axis so the result has the same rank as the other modes
            z = outputs.pooler_output[:, None, :]
        else:
            z = outputs.hidden_states[self.layer_idx]
        return z

    def encode(self, text):
        """Alias for calling the layer on ``text``."""
        return self(text)
class FrozenCLIPEmbedderWithCustomWordsBase(nn.Layer):
    """A paddle layer wrapping a FrozenCLIPEmbedder-like module. It adds unlimited prompt length (the
    prompt is split into 75-token chunks that are encoded separately and concatenated along the token
    axis) and per-token weights.

    Subclasses implement ``tokenize``, ``encode_with_text_encoder`` and ``encode_embedding_init_text``
    and must set ``id_start``, ``id_end``, ``id_pad`` and ``comma_token``. This class reads those
    attributes but never assigns them.
    """

    def __init__(self, wrapped, hijack):
        super().__init__()
        self.wrapped = wrapped
        """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
        depending on model."""
        # Provides ``embedding_db`` (textual-inversion lookup) and receives ``fixes`` and ``comments``.
        self.hijack = hijack
        # Prompt tokens per chunk; a finalized chunk adds a start and an end token (77 in total).
        self.chunk_length = 75

    def empty_chunk(self):
        """Create and return a padding-only PromptChunk.

        It holds the start token followed by ``chunk_length + 1`` end tokens (77 tokens with the default
        ``chunk_length``), all with multiplier 1.0.
        """
        chunk = PromptChunk()
        chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
        chunk.multipliers = [1.0] * (self.chunk_length + 2)
        return chunk

    def get_target_prompt_token_count(self, token_count):
        """Return ``token_count`` rounded up to a multiple of ``chunk_length`` (at least one chunk).

        This is the maximum number of tokens a prompt of that length can have before it needs one more
        PromptChunk.
        """
        return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length

    def tokenize(self, texts):
        """Convert a batch of texts into a batch of token id lists. Implemented by subclasses."""
        raise NotImplementedError

    def encode_with_text_encoder(self, tokens):
        """
        Convert a batch of token ids into a single tensor with the numeric representation of those
        tokens. Implemented by subclasses.

        All token lists are assumed to have the same length, usually 77. For an input of B lists with
        T tokens each, the expected output shape is (B, T, C), where C depends on the model
        (e.g. 768 or 1024).
        NOTE(review): ``forward`` stores the textual-inversion fixes in ``self.hijack.fixes`` before this
        call. They are presumably applied (and cleared) by a hijacked embedding layer; this is not
        visible here.
        """
        raise NotImplementedError

    def encode_embedding_init_text(self, init_text, nvpt):
        """Convert text into a tensor with its tokens' embeddings, taken before the transformer layers.

        ``nvpt`` is the maximum length in tokens; if the text produces fewer tokens, only that many are
        returned. Implemented by subclasses.
        """
        raise NotImplementedError

    def tokenize_line(self, line):
        """
        Split a single prompt into as many PromptChunk objects as needed to represent it.

        Returns the list of chunks and the total number of tokens in the prompt.
        """
        # Emphasis parsing is a class-level switch on the pipeline. Without it, the whole line is a
        # single fragment with weight 1.0.
        if WebUIStableDiffusionControlNetPipeline.enable_emphasis:
            parsed = parse_prompt_attention(line)
        else:
            parsed = [[line, 1.0]]
        # One token-id list per (text, weight) fragment.
        tokenized = self.tokenize([text for text, _ in parsed])
        chunks = []
        chunk = PromptChunk()
        token_count = 0
        # Index in chunk.tokens of the most recent comma token, or -1 if the current chunk has none.
        last_comma = -1

        def next_chunk(is_last=False):
            """Finalize the current chunk and start a new, empty one.

            The chunk is padded with end tokens to ``chunk_length``, wrapped in start/end tokens and
            appended to the results. If is_last is true, the padding <end-of-text> tokens do not add to
            token_count."""
            nonlocal token_count
            nonlocal last_comma
            nonlocal chunk
            if is_last:
                token_count += len(chunk.tokens)
            else:
                token_count += self.chunk_length
            to_add = self.chunk_length - len(chunk.tokens)
            if to_add > 0:
                chunk.tokens += [self.id_end] * to_add
                chunk.multipliers += [1.0] * to_add
            chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
            chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
            last_comma = -1
            chunks.append(chunk)
            chunk = PromptChunk()

        for tokens, (text, weight) in zip(tokenized, parsed):
            # A "BREAK" fragment marked with weight -1 forces a new chunk.
            if text == "BREAK" and weight == -1:
                next_chunk()
                continue
            position = 0
            while position < len(tokens):
                token = tokens[position]
                if token == self.comma_token:
                    last_comma = len(chunk.tokens)
                # This branch runs when the current chunk has used all of its 75 tokens and the current
                # token is not a comma. comma_padding_backtrack is a setting: if a comma is within that
                # many tokens of the chunk end, the text after the comma is moved into the next chunk.
                elif (
                    WebUIStableDiffusionControlNetPipeline.comma_padding_backtrack != 0
                    and len(chunk.tokens) == self.chunk_length
                    and last_comma != -1
                    and len(chunk.tokens) - last_comma
                    <= WebUIStableDiffusionControlNetPipeline.comma_padding_backtrack
                ):
                    break_location = last_comma + 1
                    reloc_tokens = chunk.tokens[break_location:]
                    reloc_mults = chunk.multipliers[break_location:]
                    chunk.tokens = chunk.tokens[:break_location]
                    chunk.multipliers = chunk.multipliers[:break_location]
                    next_chunk()
                    chunk.tokens = reloc_tokens
                    chunk.multipliers = reloc_mults
                if len(chunk.tokens) == self.chunk_length:
                    next_chunk()
                # Check whether a textual-inversion embedding starts at this token position.
                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(
                    tokens, position
                )
                if embedding is None:
                    chunk.tokens.append(token)
                    chunk.multipliers.append(weight)
                    position += 1
                    continue
                # The embedding needs emb_len vector slots; start a new chunk if it would not fit.
                emb_len = int(embedding.vec.shape[0])
                if len(chunk.tokens) + emb_len > self.chunk_length:
                    next_chunk()
                # Record where the vectors go and reserve placeholder token ids (0) for them.
                chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))
                chunk.tokens += [0] * emb_len
                chunk.multipliers += [weight] * emb_len
                position += embedding_length_in_tokens
        # Flush the partly filled chunk. An empty prompt still produces one padding-only chunk.
        if len(chunk.tokens) > 0 or len(chunks) == 0:
            next_chunk(is_last=True)
        return chunks, token_count

    def process_texts(self, texts):
        """
        Call tokenize_line() on each text, caching results per distinct text.

        Returns the list of chunk lists and the maximum length, in tokens, over all texts.
        """
        token_count = 0
        cache = {}
        batch_chunks = []
        for line in texts:
            if line in cache:
                chunks = cache[line]
            else:
                chunks, current_token_count = self.tokenize_line(line)
                token_count = max(current_token_count, token_count)
                cache[line] = chunks
            batch_chunks.append(chunks)
        return batch_chunks, token_count

    def forward(self, texts):
        """
        Encode a list of texts through the text encoder and return a tensor of shape (B, T, C).

        B is the number of texts. T is their length in tokens including padding, a multiple of 77.
        C is the size of each token's representation (768 for SD1, 1024 for SD2). An example output
        shape is (2, 77, 768).
        Chunk i of every text is encoded together. Texts with fewer chunks are padded with empty chunks,
        and the per-chunk results are concatenated along axis 1.
        """
        batch_chunks, token_count = self.process_texts(texts)
        used_embeddings = {}
        chunk_count = max([len(x) for x in batch_chunks])
        zs = []
        for i in range(chunk_count):
            batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
            tokens = [x.tokens for x in batch_chunk]
            multipliers = [x.multipliers for x in batch_chunk]
            # Textual-inversion fixes for this chunk position, one list per text.
            self.hijack.fixes = [x.fixes for x in batch_chunk]
            for fixes in self.hijack.fixes:
                for position, embedding in fixes:
                    used_embeddings[embedding.name] = embedding
            z = self.process_tokens(tokens, multipliers)
            zs.append(z)
        if len(used_embeddings) > 0:
            embeddings_list = ", ".join(
                [f"{name} [{embedding.checksum()}]" for name, embedding in used_embeddings.items()]
            )
            self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
        return paddle.concat(zs, axis=1)

    def process_tokens(self, remade_batch_tokens, batch_multipliers):
        """
        Encode one prompt chunk per text and apply the per-token multipliers.

        remade_batch_tokens is a list with one token list per text, usually 77 tokens each.
        batch_multipliers has the same structure and holds one multiplier per token. Each token's output
        vector is scaled by its multiplier, then the whole tensor is rescaled so its overall mean is
        unchanged.
        """
        tokens = paddle.to_tensor(remade_batch_tokens)
        # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
        if self.id_end != self.id_pad:
            for batch_pos in range(len(remade_batch_tokens)):
                index = remade_batch_tokens[batch_pos].index(self.id_end)
                tokens[batch_pos, index + 1 : tokens.shape[1]] = self.id_pad
        z = self.encode_with_text_encoder(tokens)
        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
        # NOTE(review): paddle.to_tensor gives the default float dtype here; if z is float16 an explicit
        # cast to z.dtype may be needed — confirm.
        batch_multipliers = paddle.to_tensor(batch_multipliers)
        original_mean = z.mean()
        z = z * batch_multipliers.reshape(
            batch_multipliers.shape
            + [
                1,
            ]
        ).expand(z.shape)
        new_mean = z.mean()
        z = z * (original_mean / new_mean)
        return z
class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
    """CLIP specialization of FrozenCLIPEmbedderWithCustomWordsBase.

    Uses the wrapped embedder's tokenizer and text encoder for tokenization, text encoding (with an
    optional "CLIP skip") and embedding lookup for textual-inversion init text.
    """

    def __init__(self, wrapped, hijack, CLIP_stop_at_last_layers=-1):
        super().__init__(wrapped, hijack)
        self.CLIP_stop_at_last_layers = CLIP_stop_at_last_layers
        self.tokenizer = wrapped.tokenizer

        vocabulary = self.tokenizer.get_vocab()
        self.comma_token = vocabulary.get(",</w>", None)

        # Vocabulary pieces containing brackets get an emphasis multiplier, computed left to right:
        # each "(" or "]" multiplies by 1.1 and each ")" or "[" divides by 1.1.
        self.token_mults = {}
        for piece, token_id in vocabulary.items():
            if not any(bracket in piece for bracket in "()[]"):
                continue
            factor = 1.0
            for ch in piece:
                if ch in "[)":
                    factor /= 1.1
                elif ch in "](":
                    factor *= 1.1
            if factor != 1.0:
                self.token_mults[token_id] = factor

        wrapped_tokenizer = self.wrapped.tokenizer
        self.id_start = wrapped_tokenizer.bos_token_id
        self.id_end = wrapped_tokenizer.eos_token_id
        # CLIP pads with the end-of-text token.
        self.id_pad = self.id_end

    def tokenize(self, texts):
        """Tokenize ``texts`` without truncation or special tokens; return the ``input_ids`` lists."""
        return self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

    def encode_with_text_encoder(self, tokens):
        """Run the text encoder on ``tokens``.

        With ``CLIP_stop_at_last_layers > 1``, the hidden state that many layers from the end is taken
        and passed through the final layer norm ``ln_final``. Otherwise ``last_hidden_state`` is returned.
        """
        use_hidden_states = self.CLIP_stop_at_last_layers > 1
        encoder = self.wrapped.text_encoder
        outputs = encoder(input_ids=tokens, output_hidden_states=use_hidden_states, return_dict=True)
        if not use_hidden_states:
            return outputs.last_hidden_state
        hidden = outputs.hidden_states[-self.CLIP_stop_at_last_layers]
        return encoder.text_model.ln_final(hidden)

    def encode_embedding_init_text(self, init_text, nvpt):
        """Return the token embeddings of ``init_text`` (at most ``nvpt`` tokens), without the batch axis."""
        text_model = self.wrapped.text_encoder.text_model
        encoded = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pd", add_special_tokens=False)
        token_ids = encoded["input_ids"]
        return text_model.token_embedding.wrapped(token_ids).squeeze(0)
# extra_networks.py
import re
from collections import defaultdict
class ExtraNetworkParams:
    """Arguments of one ``<type:arg1:arg2...>`` extra-network tag found in a prompt."""

    def __init__(self, items=None):
        # Colon-separated tag arguments, e.g. ["my_lora", "0.8"]; any falsy input becomes a new [].
        self.items = items if items else []


# Matches extra-network tags such as ``<lora:my_lora:0.8>``. Group 1 is the network type and
# group 2 the raw, colon-separated argument string.
re_extra_net = re.compile(r"<(\w+):([^>]+)>")


def parse_prompt(prompt):
    """Remove extra-network tags from ``prompt`` and collect their arguments.

    Returns ``(prompt_without_tags, params)``, where ``params`` is a ``defaultdict(list)`` mapping each
    network type (e.g. ``"lora"``) to its ExtraNetworkParams in order of appearance.
    """
    collected = defaultdict(list)

    def _collect(match):
        net_type, raw_args = match.groups()
        collected[net_type].append(ExtraNetworkParams(items=raw_args.split(":")))
        return ""

    stripped = re_extra_net.sub(_collect, prompt)
    return stripped, collected
def parse_prompts(prompts):
    """Apply parse_prompt to every prompt.

    Returns ``(stripped_prompts, extra_data)``. ``extra_data`` is the extra-network data parsed from
    the FIRST prompt only (None if ``prompts`` is empty); tags in later prompts are removed from the
    text but not reported.
    """
    stripped_prompts = []
    first_extra_data = None
    for prompt in prompts:
        stripped, extra = parse_prompt(prompt)
        stripped_prompts.append(stripped)
        if first_extra_data is None:
            first_extra_data = extra
    return stripped_prompts, first_extra_data
# image_embeddings.py
import base64
import json
import zlib
import numpy as np
from PIL import Image
class EmbeddingDecoder(json.JSONDecoder):
    """JSON decoder that turns objects with a ``"TORCHTENSOR"`` key into paddle tensors."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, object_hook=self.object_hook, **kwargs)

    def object_hook(self, d):
        """Return ``paddle.to_tensor(np.array(d["TORCHTENSOR"]))`` if that key exists, else ``d``."""
        if "TORCHTENSOR" not in d:
            return d
        return paddle.to_tensor(np.array(d["TORCHTENSOR"]))
def embedding_from_b64(data):
    """Base64-decode ``data`` and parse the resulting JSON with EmbeddingDecoder."""
    return json.loads(base64.b64decode(data), cls=EmbeddingDecoder)
def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
    """Yield an endless pseudo-random sequence from a linear congruential generator.

    Each value is the next LCG state reduced modulo 255 (not 256). The image-embedding format depends
    on this exact sequence, so the constants and the modulus must not change.
    """
    while True:
        seed = (a * seed + c) % m
        yield seed % 255


def xor_block(block):
    """XOR every element of ``block`` (as uint8) with the low nibble of a fixed LCG key stream.

    The key stream restarts from the same seed on every call, so applying this twice to a uint8 array
    returns the original data. The result has the same shape as ``block`` and dtype uint8.
    """
    g = lcg()
    # ``np.product`` was deprecated in NumPy 1.25 and removed in 2.0. ``block.size`` is the same
    # element count and is also correct for 0-d arrays.
    randblock = np.array([next(g) for _ in range(block.size)], dtype=np.uint8).reshape(block.shape)
    return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)
def crop_black(img, tol=0):
    """Crop the dark border around an image with a trailing channel axis.

    A pixel counts as lit when every channel (axis 2) is greater than ``tol``. The result is the
    bounding box of all lit pixels; if none are lit, the image is returned uncropped.
    """
    lit = np.all(img > tol, axis=2)
    lit_cols = lit.any(axis=0)
    lit_rows = lit.any(axis=1)
    n_rows, n_cols = lit.shape[0], lit.shape[1]

    left = lit_cols.argmax()
    right = n_cols - lit_cols[::-1].argmax()
    top = lit_rows.argmax()
    bottom = n_rows - lit_rows[::-1].argmax()
    return img[top:bottom, left:right]
def extract_image_data_embed(image):
    """Recover JSON embedding data hidden in the low nibbles of an image's pixels.

    The image is converted to RGB and cropped to its non-black area. The data is assumed to be split
    into a lower-nibble block (left of the all-zero columns) and an upper-nibble block (right of them).
    Both blocks are un-XORed with xor_block, recombined into bytes, zlib-decompressed and parsed with
    EmbeddingDecoder. Returns None (after printing a message) when fewer than two all-zero columns
    are found.
    """
    channels = 3
    pixels = np.array(image.convert("RGB").getdata()).reshape(image.size[1], image.size[0], channels)
    nibbles = crop_black(pixels.astype(np.uint8)) & 0x0F

    separator_cols = np.where(np.sum(nibbles, axis=(0, 2)) == 0)[0]
    if separator_cols.shape[0] < 2:
        print("No Image data blocks found.")
        return None

    low_half = xor_block(nibbles[:, : separator_cols.min(), :].astype(np.uint8))
    high_half = xor_block(nibbles[:, separator_cols.max() + 1 :, :].astype(np.uint8))

    packed = ((high_half << 4) | (low_half)).flatten().tobytes()
    return json.loads(zlib.decompress(packed), cls=EmbeddingDecoder)
# prompt_parser.py
import re
from collections import namedtuple
from typing import List
import lark
# Lark grammar for prompt-editing syntax: "(text)" / "(text:weight)" / "[text]" emphasis,
# "[from:to:when]" scheduled edits and "[a|b|...]" per-step alternation. Stray brackets, parentheses
# and colons are allowed at the top level.
#
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100):
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
# [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
# [75, 'fantasy landscape with a lake and an oak in background masterful']
# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
schedule_parser = lark.Lark(
    r"""
!start: (prompt | /[][():]/+)*
prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
!emphasized: "(" prompt ")"
| "(" prompt ":" prompt ")"
| "[" prompt "]"
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
alternate: "[" prompt ("|" prompt)+ "]"
WHITESPACE: /\s+/
plain: /([^\\\[\]():|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
"""
)
def get_learned_conditioning_prompt_schedules(prompts, steps):
    """
    Expand prompt-editing syntax into per-step prompt schedules.

    For each prompt, returns a list of ``[end_step, text]`` pairs sorted by ``end_step``. ``text`` is
    how the prompt reads at sampling steps up to and including ``end_step``:

    * ``[a:b:N]`` switches from ``a`` to ``b`` after step N; an N below 1 is a fraction of ``steps``;
    * ``[a|b]`` alternates between the options on every step;
    * a prompt the grammar cannot parse is used verbatim for all ``steps``.

    Identical prompts are parsed only once.

    >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
    >>> g("test")
    [[10, 'test']]
    >>> g("a [b:3]")
    [[3, 'a '], [10, 'a b']]
    >>> g("a [b: 3]")
    [[3, 'a '], [10, 'a b']]
    >>> g("a [[[b]]:2]")
    [[2, 'a '], [10, 'a [[b]]']]
    >>> g("[(a:2):3]")
    [[3, ''], [10, '(a:2)']]
    >>> g("a [b : c : 1] d")
    [[1, 'a b d'], [10, 'a c d']]
    >>> g("a[b:[c:d:2]:1]e")
    [[1, 'abe'], [2, 'ace'], [10, 'ade']]
    >>> g("a [unbalanced")
    [[10, 'a [unbalanced']]
    >>> g("a [b:.5] c")
    [[5, 'a c'], [10, 'a b c']]
    >>> g("a [{b|d{:.5] c") # not handling this right now
    [[5, 'a c'], [10, 'a {b|d{ c']]
    >>> g("((a][:b:c [d:3]")
    [[3, '((a][:b:c '], [10, '((a][:b:c d']]
    >>> g("[a|(b:1.1)]")
    [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
    >>> g("[a (b)|c]")
    [[1, 'a (b)'], [2, 'c'], [3, 'a (b)'], [4, 'c'], [5, 'a (b)'], [6, 'c'], [7, 'a (b)'], [8, 'c'], [9, 'a (b)'], [10, 'c']]
    """

    def collect_steps(steps, tree):
        # Every step at which the prompt text may change; ``steps`` itself is always included.
        boundaries = [steps]

        class CollectSteps(lark.Visitor):
            def scheduled(self, tree):
                # The last child is the NUMBER token. Convert it in place to an absolute step (values
                # below 1 are fractions of ``steps``, capped at ``steps``) so at_step() can compare
                # against it.
                tree.children[-1] = float(tree.children[-1])
                if tree.children[-1] < 1:
                    tree.children[-1] *= steps
                tree.children[-1] = min(steps, int(tree.children[-1]))
                boundaries.append(tree.children[-1])

            def alternate(self, tree):
                # Alternation changes the text on every step.
                boundaries.extend(range(1, steps + 1))

        CollectSteps().visit(tree)
        return sorted(set(boundaries))

    def at_step(step, tree):
        # Render the prompt text in effect at ``step``. Rules yield (nested) iterables of strings;
        # start() flattens them into the final string.
        class AtStep(lark.Transformer):
            def scheduled(self, args):
                # ``before`` is None when the "before:" part is omitted; the third child is the
                # optional whitespace before the number.
                before, after, _, when = args
                yield before or () if step <= when else after

            def alternate(self, args):
                # Yield the chosen option whole so flatten() renders every part of it. The previous
                # ``next(option)`` kept only its first part (e.g. "[a (b)|c]" lost "(b)") and crashed
                # on an empty option.
                option = args[(step - 1) % len(args)]
                yield () if option is None else option

            def start(self, args):
                def flatten(x):
                    # Strings, including lark Tokens (a str subclass), are leaves; everything else
                    # is an iterable of further parts.
                    if isinstance(x, str):
                        yield x
                    else:
                        for gen in x:
                            yield from flatten(gen)

                return "".join(flatten(args))

            def plain(self, args):
                yield args[0].value

            def __default__(self, data, children, meta):
                for child in children:
                    yield child

        return AtStep().transform(tree)

    def get_schedule(prompt):
        try:
            tree = schedule_parser.parse(prompt)
        except lark.exceptions.LarkError:
            # The grammar cannot parse this prompt: use it verbatim for every step.
            return [[steps, prompt]]
        return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]

    promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
    return [promptdict[prompt] for prompt in prompts]
# One entry of a prompt schedule: ``cond`` applies at sampling steps up to and including ``end_at_step``.
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", "end_at_step cond")
def get_learned_conditioning(model, prompts, steps):
    """Convert prompts into schedules of encoded conditionings.

    Each prompt is expanded with get_learned_conditioning_prompt_schedules. All texts of one prompt's
    schedule are encoded in a single ``model(texts)`` call, and item ``i`` of the result becomes the
    ``cond`` of the i-th ScheduledPromptConditioning. Its ``end_at_step`` is the last sampling step at
    which that conditioning applies. A repeated prompt reuses the first prompt's schedule list.

    Example with steps=20: ``['a red crown', 'a [blue:green:5] jeweled crown']`` becomes
    ``[[SPC(20, <cond 'a red crown'>)],
      [SPC(5, <cond 'a blue jeweled crown'>), SPC(20, <cond 'a green jeweled crown'>)]]``
    where SPC is ScheduledPromptConditioning.
    """
    prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
    memo = {}
    results = []
    for prompt, schedule in zip(prompts, prompt_schedules):
        cond_schedule = memo.get(prompt, None)
        if cond_schedule is None:
            conds = model([text for _, text in schedule])
            cond_schedule = [
                ScheduledPromptConditioning(end_at_step, conds[idx])
                for idx, (end_at_step, _) in enumerate(schedule)
            ]
            memo[prompt] = cond_schedule
        results.append(cond_schedule)
    return results
# Splits a prompt into sub-prompts at the whole word "AND".
re_AND = re.compile(r"\bAND\b")
# Splits a sub-prompt into its text and an optional trailing ":<weight>" (e.g. "a cat :1.5").
# re.DOTALL lets the text span several lines. Without it, any sub-prompt containing a newline failed to
# match, and its ":<weight>" suffix was kept in the text with the weight ignored.
re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$", re.DOTALL)


def get_multicond_prompt_list(prompts):
    """Split prompts into weighted, de-duplicated sub-prompts.

    Returns ``(res_indexes, prompt_flat_list, prompt_indexes)``:
      * ``res_indexes[i]`` lists ``(index, weight)`` pairs for the AND-separated sub-prompts of
        ``prompts[i]``; the weight defaults to 1.0;
      * ``prompt_flat_list`` contains each distinct sub-prompt text once, in first-seen order;
      * ``prompt_indexes`` maps each text to its position in ``prompt_flat_list``.
    """
    res_indexes = []
    prompt_flat_list = []
    prompt_indexes = {}
    for prompt in prompts:
        indexes = []
        for subprompt in re_AND.split(prompt):
            match = re_weight.search(subprompt)
            text, weight = match.groups() if match is not None else (subprompt, 1.0)
            weight = float(weight) if weight is not None else 1.0
            index = prompt_indexes.get(text)
            if index is None:
                index = len(prompt_flat_list)
                prompt_flat_list.append(text)
                prompt_indexes[text] = index
            indexes.append((index, weight))
        res_indexes.append(indexes)
    return res_indexes, prompt_flat_list, prompt_indexes
class ComposableScheduledPromptConditioning:
    """A sub-prompt's conditioning schedule together with its weight in an AND-composed prompt."""

    def __init__(self, schedules, weight=1.0):
        # Conditionings for this sub-prompt, one per schedule boundary.
        self.schedules: List[ScheduledPromptConditioning] = schedules
        # Weight from a trailing ":<weight>" on the sub-prompt; 1.0 when absent.
        self.weight: float = weight
class MulticondLearnedConditioning:
    """Batch of AND-composed conditionings: one list of composable sub-prompts per prompt."""

    def __init__(self, shape, batch):
        # the shape field is needed to send this object to DDIM/PLMS
        self.shape: tuple = shape
        # batch[i] holds the ComposableScheduledPromptConditioning objects of prompt i
        self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
    """Like get_learned_conditioning, but for prompts composed with the ``AND`` separator.

    Each prompt is split on ``AND``. Every distinct sub-prompt is encoded once
    via get_learned_conditioning, and the schedules are then regrouped per
    prompt together with each sub-prompt's weight.

    Returns:
        A MulticondLearnedConditioning with ``shape == (len(prompts),)``. Its
        ``batch[k]`` is the list of ComposableScheduledPromptConditioning for
        ``prompts[k]``.

    https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
    """
    res_indexes, prompt_flat_list, _ = get_multicond_prompt_list(prompts)
    learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)

    batch = [
        [ComposableScheduledPromptConditioning(learned_conditioning[index], weight) for index, weight in pairs]
        for pairs in res_indexes
    ]
    return MulticondLearnedConditioning(shape=(len(prompts),), batch=batch)
def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
    """Stack, for every batch item, the conditioning tensor active at ``current_step``.

    For each schedule, the first entry whose ``end_at_step >= current_step`` is
    used. If ``current_step`` is past every end step, the first entry is used.
    The output has shape ``[len(c)] + c[0][0].cond.shape`` and the dtype of
    ``c[0][0].cond``.
    """
    template = c[0][0].cond
    out = paddle.zeros([len(c)] + template.shape, dtype=template.dtype)

    for row, schedule in enumerate(c):
        chosen = next(
            (pos for pos, (end_at, _cond) in enumerate(schedule) if current_step <= end_at),
            0,
        )
        out[row] = schedule[chosen].cond

    return out
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
    """Collect the conditioning of every sub-prompt that is active at ``current_step``.

    For each sub-prompt, the first schedule entry with ``end_at_step >= current_step``
    is chosen. If ``current_step`` is past every end step, the first entry is used.

    Returns:
        ``(conds_list, stacked)``:

        - ``conds_list[b]`` lists ``(row, weight)`` pairs, where ``row`` indexes
          into ``stacked`` for the sub-prompts of batch item ``b``.
        - ``stacked`` holds all chosen tensors, cast to the dtype of the very
          first conditioning.
    """
    template = c.batch[0][0].schedules[0].cond
    tensors = []
    conds_list = []

    for composable_prompts in c.batch:
        entry = []
        for composable_prompt in composable_prompts:
            schedules = composable_prompt.schedules
            chosen = next(
                (pos for pos, (end_at, _cond) in enumerate(schedules) if current_step <= end_at),
                0,
            )
            entry.append((len(tensors), composable_prompt.weight))
            tensors.append(schedules[chosen].cond)
        conds_list.append(entry)

    # Prompts with wildly different lengths above the token limit yield tensors of
    # different shapes, which paddle.stack cannot combine. Pad the shorter ones by
    # repeating their last row up to the longest length.
    token_count = max(t.shape[0] for t in tensors)
    for pos, t in enumerate(tensors):
        missing = token_count - t.shape[0]
        if missing != 0:
            tail = t[-1:]
            tensors[pos] = paddle.concat([t, tail.tile([missing, 1])], axis=0)

    return conds_list, paddle.stack(tensors).cast(dtype=template.dtype)
# Tokenizer for the attention markup handled by parse_prompt_attention.
# The alternatives are tried in order:
#   \(  \)  \[  \]  \\   - backslash-escaped literal bracket or backslash
#   \                    - a lone backslash
#   (  [                 - opening brackets
#   :<number>)           - explicit weight closing a round bracket (number is group 1)
#   )  ]                 - closing brackets
#   [^\\()\[\]:]+        - a run of plain text
#   :                    - a lone colon that is not part of a weight
# Compiled with re.X, so whitespace and newlines inside the pattern are ignored.
# NOTE(review): [.\d]+ also accepts strings such as "." or "1.2.3", which are not
# valid floats.
re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)
# The BREAK keyword as a whole word, together with any surrounding whitespace;
# parse_prompt_attention splits plain-text runs on it.
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
def parse_prompt_attention(text):
    r"""
    Parse a string with attention tokens into a list of ``[text, weight]`` pairs.

    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - multiplies attention to abc by 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      BREAK - (whole word, surrounding whitespace dropped) emits a ["BREAK", -1] marker entry
      anything else - just text

    Brackets still open at the end of the text are closed with their default
    multiplier. Adjacent entries with equal weight are merged. An empty input
    yields [["", 1.0]].

    This docstring is a raw string so the backslash examples above are not
    treated as (invalid) escape sequences.

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention(r'\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a BREAK b')
    [['a', 1.0], ['BREAK', -1], ['b', 1.0]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """
    res = []
    round_brackets = []  # positions in res where each still-open "(" started
    square_brackets = []  # positions in res where each still-open "[" started

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        # Scale every entry appended since the matching opening bracket.
        # NOTE(review): this also scales ["BREAK", -1] markers inside brackets
        # (e.g. -1 -> -1.1); confirm that consumers tolerate this.
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        token = m.group(0)
        weight = m.group(1)  # set only for the ":<number>)" alternative

        if token.startswith("\\"):
            # escaped character (or a lone backslash, which yields "")
            res.append([token[1:], 1.0])
        elif token == "(":
            round_brackets.append(len(res))
        elif token == "[":
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            # NOTE(review): float() raises ValueError for weights such as "."
            # that re_attention's [.\d]+ still accepts.
            multiply_range(round_brackets.pop(), float(weight))
        elif token == ")" and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif token == "]" and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            # Plain text (or an unmatched bracket / weight): insert BREAK markers where present.
            parts = re.split(re_break, token)
            for i, part in enumerate(parts):
                if i > 0:
                    res.append(["BREAK", -1])
                res.append([part, 1.0])

    # Close any brackets left open with their default multipliers.
    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res
# sd_hijack.py
class StableDiffusionModelHijack:
    """Patch a CLIP text encoder so prompts can use textual-inversion embeddings.

    The constructor:

    - replaces ``clip_model.text_encoder.text_model.token_embedding`` with an
      EmbeddingsWithFixes wrapper bound to this hijack;
    - wraps ``clip_model`` in FrozenCLIPEmbedderWithCustomWords and stores the
      wrapper as ``self.clip``;
    - creates an EmbeddingDatabase over that wrapper, registered with
      ``embeddings_dir`` (ignored when None).
    """

    fixes = None
    # Class-level default kept for backward compatibility only. Each instance
    # gets its own list in __init__, because a mutable class attribute is
    # shared by every instance.
    comments = []
    layers = None
    circular_enabled = False

    def __init__(self, clip_model, embeddings_dir=None, CLIP_stop_at_last_layers=-1):
        # Give this hijack its own comment list. Otherwise an append made before
        # clear_comments() would mutate the class-level list and leak into
        # every other instance.
        self.comments = []

        model_embeddings = clip_model.text_encoder.text_model
        model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
        clip_model = FrozenCLIPEmbedderWithCustomWords(
            clip_model, self, CLIP_stop_at_last_layers=CLIP_stop_at_last_layers
        )
        self.embedding_db = EmbeddingDatabase(clip_model)
        self.embedding_db.add_embedding_dir(embeddings_dir)

        # NOTE: self.clip is the FrozenCLIPEmbedderWithCustomWords wrapper, not the
        # clip_model object originally passed in.
        self.clip = clip_model

        def flatten(el):
            # el followed by all of its descendants (pre-order, via children()).
            flattened = [flatten(children) for children in el.children()]
            res = [el]
            for c in flattened:
                res += c
            return res

        self.layers = flatten(clip_model)

    def clear_comments(self):
        """Reset this hijack's accumulated comments."""
        self.comments = []

    def get_prompt_lengths(self, text):
        """Return ``(token_count, target_token_count)`` for ``text``.

        Both values come from the wrapped embedder: ``process_texts`` and
        ``get_target_prompt_token_count``.
        """
        _, token_count = self.clip.process_texts([text])
        return token_count, self.clip.get_target_prompt_token_count(token_count)
class EmbeddingsWithFixes(nn.Layer):
    """Token-embedding wrapper that splices textual-inversion vectors into its output.

    ``embeddings`` is the owning hijack object. Before a forward call, its
    ``fixes`` attribute may hold one list of ``(offset, embedding)`` pairs per
    batch row. Each pair overwrites the rows that follow position ``offset``
    with ``embedding.vec``, truncated so the sequence length is unchanged. The
    fixes are consumed (reset to None) on every forward call.
    """

    def __init__(self, wrapped, embeddings):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings

    def forward(self, input_ids):
        # Take the pending fixes and clear them, so they apply to this call only.
        batch_fixes = self.embeddings.fixes
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        if batch_fixes is None or not any(len(row_fixes) for row_fixes in batch_fixes):
            return inputs_embeds

        patched_rows = []
        for row_fixes, row in zip(batch_fixes, inputs_embeds):
            for offset, embedding in row_fixes:
                vectors = embedding.vec.cast(self.wrapped.dtype)
                start = offset + 1
                count = min(row.shape[0] - start, vectors.shape[0])
                row = paddle.concat([row[0:start], vectors[0:count], row[start + count :]])
            patched_rows.append(row)

        return paddle.stack(patched_rows)
# textual_inversion.py
import os
import sys
import traceback
class Embedding:
    """A single textual-inversion embedding plus its metadata.

    Args:
        vec: the embedding vectors.
        name: the name the embedding is registered under.
        step: optional training step recorded with the embedding.
    """

    def __init__(self, vec, name, step=None):
        self.vec = vec
        self.name = name
        self.step = step
        self.shape = None  # vector width; filled in by the loader
        self.vectors = 0  # number of vectors; filled in by the loader
        self.cached_checksum = None
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        self.optimizer_state_dict = None
        self.filename = None

    def save(self, filename):
        """Write the embedding and its metadata to ``filename`` with paddle.save."""
        paddle.save(
            {
                "string_to_token": {"*": 265},
                "string_to_param": {"*": self.vec},
                "name": self.name,
                "step": self.step,
                "sd_checkpoint": self.sd_checkpoint,
                "sd_checkpoint_name": self.sd_checkpoint_name,
            },
            filename,
        )

    def checksum(self):
        """Return a 4-hex-digit checksum of ``vec``, computed once and then cached."""
        if self.cached_checksum is None:
            digest = 0
            # Each element is scaled by 100 and truncated with int() before hashing.
            for value in self.vec.flatten() * 100:
                digest = ((digest * 281) ^ (int(value) * 997)) & 0xFFFFFFFF
            self.cached_checksum = f"{digest & 0xffff:04x}"
        return self.cached_checksum
class DirWithTextualInversionEmbeddings:
    """Track the modification time of one embeddings directory to detect pending reloads."""

    def __init__(self, path):
        self.path = path
        self.mtime = None  # mtime recorded by the last update(); None until then

    def has_changed(self):
        """Return True if the directory exists and is newer than the last update().

        Always returns a bool. A missing directory counts as unchanged, and a
        directory that was never recorded counts as changed.

        NOTE(review): only the top-level directory's mtime is compared. On
        common filesystems that mtime changes when entries directly inside it
        change, not when files in nested subdirectories do.
        """
        if not os.path.isdir(self.path):
            return False

        mt = os.path.getmtime(self.path)
        return self.mtime is None or mt > self.mtime

    def update(self):
        """Record the directory's current mtime; do nothing if it does not exist."""
        if not os.path.isdir(self.path):
            return

        self.mtime = os.path.getmtime(self.path)
class EmbeddingDatabase:
    """Registry of textual-inversion embeddings loaded from one or more directories.

    - ``word_embeddings`` maps each embedding name to its Embedding.
    - ``ids_lookup`` indexes embeddings by the first token id of their tokenized
      name, so prompt token sequences can be matched (see find_embedding_at_position).
    - Embeddings whose vector width differs from ``expected_shape`` are stored in
      ``skipped_embeddings`` instead of being registered.
    """

    def __init__(self, clip):
        self.clip = clip
        self.ids_lookup = {}  # first token id -> [(token ids, embedding)], longest ids first
        self.word_embeddings = {}  # embedding name -> embedding
        self.skipped_embeddings = {}  # name -> embedding rejected for a width mismatch
        self.expected_shape = -1  # -1 accepts any width; set on each reload
        self.embedding_dirs = {}  # path -> DirWithTextualInversionEmbeddings
        self.previously_displayed_embeddings = ()

    def add_embedding_dir(self, path):
        """Start tracking ``path``; None and already-tracked paths are ignored."""
        if path is not None and path not in self.embedding_dirs:
            self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)

    def clear_embedding_dirs(self):
        self.embedding_dirs.clear()

    def register_embedding(self, embedding, model):
        """Index ``embedding`` by the token ids of its name and return it.

        Candidates sharing a first token id are kept sorted longest-first, so
        that find_embedding_at_position prefers the longest match.
        """
        self.word_embeddings[embedding.name] = embedding
        ids = model.tokenize([embedding.name])[0]
        first_id = ids[0]
        candidates = self.ids_lookup.get(first_id, []) + [(ids, embedding)]
        self.ids_lookup[first_id] = sorted(candidates, key=lambda x: len(x[0]), reverse=True)
        return embedding

    def get_expected_shape(self):
        # Encodes "," with the clip wrapper and returns vec.shape[1]; presumably
        # this is the token-embedding width (vec assumed (tokens, dim)) — TODO confirm.
        vec = self.clip.encode_embedding_init_text(",", 1)
        return vec.shape[1]

    def load_from_file(self, path, filename):
        """Load one embedding file and register it (or record it as skipped).

        Supported formats:

        - images (.png/.webp/.jxl/.avif) carrying embedded data; "*.preview.*"
          images are ignored;
        - .bin/.pt files;
        - .safetensors files.

        Other extensions are ignored.

        Raises:
            ValueError: the file holds more than one term.
            Exception: the data is neither a textual-inversion embedding nor a
                diffusers concept.
        """
        name, ext = os.path.splitext(filename)
        ext = ext.upper()

        if ext in [".PNG", ".WEBP", ".JXL", ".AVIF"]:
            _, second_ext = os.path.splitext(name)
            if second_ext.upper() == ".PREVIEW":
                return

            # Use a context manager so the image file is closed once the
            # embedded data has been read, instead of leaking the open handle.
            with Image.open(path) as embed_image:
                if hasattr(embed_image, "text") and "sd-ti-embedding" in embed_image.text:
                    data = embedding_from_b64(embed_image.text["sd-ti-embedding"])
                    name = data.get("name", name)
                else:
                    data = extract_image_data_embed(embed_image)
                    if not data:
                        # if data is None, means this is not an embedding, just a preview image
                        return
                    name = data.get("name", name)
        elif ext in [".BIN", ".PT"]:
            data = torch_load(path)
        elif ext in [".SAFETENSORS"]:
            data = safetensors_load(path)
        else:
            return

        if "string_to_param" in data:
            # textual inversion embeddings
            param_dict = data["string_to_param"]
            if hasattr(param_dict, "_parameters"):
                param_dict = getattr(param_dict, "_parameters")
            # Explicit raise rather than assert, so the check survives `python -O`.
            if len(param_dict) != 1:
                raise ValueError("embedding file has multiple terms in it")
            emb = next(iter(param_dict.items()))[1]
        elif isinstance(data, dict) and len(data) > 0 and isinstance(next(iter(data.values())), paddle.Tensor):
            # diffuser concepts: a single {token: tensor} mapping
            if len(data) != 1:
                raise ValueError("embedding file has multiple terms in it")
            emb = next(iter(data.values()))
            if len(emb.shape) == 1:
                emb = emb.unsqueeze(0)
        else:
            raise Exception(
                f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept."
            )

        with paddle.no_grad():
            # Convert whatever tensor/array type was loaded into a paddle tensor via numpy.
            if hasattr(emb, "detach"):
                emb = emb.detach()
            if hasattr(emb, "cpu"):
                emb = emb.cpu()
            if hasattr(emb, "numpy"):
                emb = emb.numpy()
            emb = paddle.to_tensor(emb)
        vec = emb.detach().cast(paddle.float32)

        embedding = Embedding(vec, name)
        embedding.step = data.get("step", None)
        embedding.sd_checkpoint = data.get("sd_checkpoint", None)
        embedding.sd_checkpoint_name = data.get("sd_checkpoint_name", None)
        embedding.vectors = vec.shape[0]
        embedding.shape = vec.shape[-1]
        embedding.filename = path

        if self.expected_shape == -1 or self.expected_shape == embedding.shape:
            self.register_embedding(embedding, self.clip)
        else:
            self.skipped_embeddings[name] = embedding

    def load_from_dir(self, embdir):
        """Best-effort load of every non-empty file under ``embdir.path`` (recursive).

        Errors are printed to stderr and do not stop the walk.
        """
        if not os.path.isdir(embdir.path):
            return

        # NOTE(review): followlinks=True can revisit directories if symlinks form a cycle.
        for root, _dirs, fns in os.walk(embdir.path, followlinks=True):
            for fn in fns:
                try:
                    fullfn = os.path.join(root, fn)

                    if os.stat(fullfn).st_size == 0:
                        continue

                    self.load_from_file(fullfn, fn)
                except Exception:
                    print(f"Error loading embedding {fn}:", file=sys.stderr)
                    print(traceback.format_exc(), file=sys.stderr)
                    continue

    def load_textual_inversion_embeddings(self, force_reload=False):
        """Reload all tracked directories when any has changed, or when ``force_reload`` is set.

        A reload clears all registrations, refreshes ``expected_shape``, reloads
        every directory and records its mtime. The loaded and skipped names are
        printed whenever they differ from the previous report.
        """
        if not force_reload and not any(embdir.has_changed() for embdir in self.embedding_dirs.values()):
            return

        self.ids_lookup.clear()
        self.word_embeddings.clear()
        self.skipped_embeddings.clear()
        self.expected_shape = self.get_expected_shape()

        for embdir in self.embedding_dirs.values():
            self.load_from_dir(embdir)
            embdir.update()

        displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
        if self.previously_displayed_embeddings != displayed_embeddings:
            self.previously_displayed_embeddings = displayed_embeddings
            print(
                f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}"
            )
            if len(self.skipped_embeddings) > 0:
                print(
                    f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}"
                )

    def find_embedding_at_position(self, tokens, offset):
        """Return ``(embedding, n_tokens)`` for the longest embedding starting at ``tokens[offset]``.

        Returns ``(None, None)`` when no registered embedding matches there.
        """
        token = tokens[offset]
        possible_matches = self.ids_lookup.get(token, None)

        if possible_matches is None:
            return None, None

        for ids, embedding in possible_matches:
            if tokens[offset : offset + len(ids)] == ids:
                return embedding, len(ids)

        return None, None
|