# shaokun: WIP (a4d7b31)
from typing import List
import torch
from safetensors import safe_open
from diffusers import StableDiffusionPipeline
from .lora import (
monkeypatch_or_replace_safeloras,
apply_learned_embed_in_clip,
set_lora_diag,
parse_safeloras_embeds,
)
def lora_join(lora_safetenors: list):
    """Join several LoRA files into one LoRA whose rank is the sum of theirs.

    Every element of ``lora_safetenors`` must provide ``metadata()``,
    ``keys()`` and ``get_tensor(key)`` (e.g. a ``safetensors.safe_open``
    handle or a ``DummySafeTensorObject``).

    * For each tensor key starting with ``text_encoder`` or ``unet`` the
      per-model tensors are concatenated along their rank axis: dim 0 for
      keys ending in ``down``, dim 1 for every other key. A model that does
      not contain the key contributes an all-zero block of its own rank, so
      LoRAs patching different sets of modules (or embedding-only files) can
      still be joined. The matching ``<prefix>:rank`` metadata entry is set
      to the summed rank.
    * Metadata entries whose value is ``"<embed>"`` (learned token
      embeddings) are dropped and re-added under the names
      ``<s{model_idx}-{token_idx}>``, tokens numbered in sorted order. Each
      rename is printed.
    * All other metadata entries are merged; on key collisions the later
      model wins.

    Args:
        lora_safetenors: LoRA objects to join, in order.

    Returns:
        ``(total_tensor, total_metadata, ranklist, token_size_list)`` where
        ``ranklist[i]`` is model ``i``'s rank (0 when it has no ``*rank``
        metadata) and ``token_size_list[i]`` is how many embeddings model
        ``i`` contributed.

    Raises:
        AssertionError: if one model declares more than one distinct rank, or
            a joined tensor's rank axis does not add up to the summed rank.
    """
    # safe_open(...).metadata() is None for files saved without metadata.
    metadatas = [dict(safelora.metadata() or {}) for safelora in lora_safetenors]

    ranklist = []
    merged_metadata = {}
    for metadata in metadatas:
        ranks = {int(v) for k, v in metadata.items() if k.endswith("rank")}
        assert len(ranks) <= 1, "Rank should be the same per model"
        ranklist.append(ranks.pop() if ranks else 0)
        merged_metadata.update(metadata)
    total_rank = sum(ranklist)

    # Token entries are re-added below under collision-free names.
    total_metadata = {k: v for k, v in merged_metadata.items() if v != "<embed>"}

    total_tensor = {}
    key_sets = [set(safelora.keys()) for safelora in lora_safetenors]
    # Sorted so the joined dict has a deterministic key order.
    for key in sorted(set().union(*key_sets)):
        if not key.startswith(("text_encoder", "unet")):
            continue
        rank_dim = 0 if key.endswith("down") else 1

        tensors = [
            safelora.get_tensor(key) if key in keys else None
            for safelora, keys in zip(lora_safetenors, key_sets)
        ]
        reference = next(t for t in tensors if t is not None)
        parts = []
        for tensor, rank in zip(tensors, ranklist):
            if tensor is None:
                # This model does not patch the module: a zero block of its
                # own rank contributes nothing (up_i @ down_i == 0) while
                # keeping the rank slots aligned with ``ranklist``.
                shape = list(reference.shape)
                shape[rank_dim] = rank
                tensor = reference.new_zeros(shape)
            parts.append(tensor)

        joined = torch.cat(parts, dim=rank_dim)
        assert joined.shape[rank_dim] == total_rank
        total_tensor[key] = joined
        # e.g. "unet:0:down" -> "unet:0:rank"
        rank_key = ":".join(key.split(":")[:-1]) + ":rank"
        total_metadata[rank_key] = str(total_rank)

    token_size_list = []
    for idx, (safelora, metadata) in enumerate(zip(lora_safetenors, metadatas)):
        tokens = sorted(k for k, v in metadata.items() if v == "<embed>")
        for jdx, token in enumerate(tokens):
            new_token = f"<s{idx}-{jdx}>"
            total_tensor[new_token] = safelora.get_tensor(token)
            total_metadata[new_token] = "<embed>"
            print(f"Embedding {token} replaced to {new_token}")
        token_size_list.append(len(tokens))

    return total_tensor, total_metadata, ranklist, token_size_list
class DummySafeTensorObject:
    """Dict-backed object offering the ``keys`` / ``metadata`` / ``get_tensor``
    trio of methods, so in-memory tensors can be passed where a
    ``safe_open`` handle is expected.

    Args:
        tensor: mapping from tensor name to tensor.
        metadata: object returned unchanged by :meth:`metadata`.
    """

    def __init__(self, tensor: dict, metadata):
        self.tensor = tensor
        self._metadata = metadata

    def keys(self):
        """Return a live view of the stored tensor names."""
        stored = self.tensor
        return stored.keys()

    def metadata(self):
        """Return the metadata object supplied at construction (not copied)."""
        meta = self._metadata
        return meta

    def get_tensor(self, key):
        """Return the tensor stored under ``key``; ``KeyError`` if absent."""
        stored = self.tensor
        return stored[key]
class LoRAManager:
    """Load several LoRA safetensors files, join them and patch a pipeline.

    On construction every path in ``lora_paths_list`` is opened with
    ``safe_open``, the files are merged by ``lora_join``, the merged LoRA is
    applied to ``pipe`` with ``monkeypatch_or_replace_safeloras``, and any
    learned embeddings are registered with the pipeline's text encoder and
    tokenizer under the names ``<s{model_idx}-{token_idx}>``.
    """

    def __init__(self, lora_paths_list: List[str], pipe: StableDiffusionPipeline):
        self.lora_paths_list = lora_paths_list
        self.pipe = pipe
        self._setup()

    def _setup(self):
        """Open, join and apply all LoRA files to ``self.pipe``.

        Sets ``_lora_safetenors``, ``ranklist``, ``token_size_list`` and
        ``total_safelora`` on the instance. The ``safe_open`` handles are kept
        on the instance and are not closed here.
        """
        self._lora_safetenors = [
            safe_open(path, framework="pt", device="cpu")
            for path in self.lora_paths_list
        ]
        joined = lora_join(self._lora_safetenors)
        total_tensor, total_metadata, self.ranklist, self.token_size_list = joined
        self.total_safelora = DummySafeTensorObject(total_tensor, total_metadata)

        monkeypatch_or_replace_safeloras(self.pipe, self.total_safelora)
        apply_learned_embed_in_clip(
            parse_safeloras_embeds(self.total_safelora),
            self.pipe.text_encoder,
            self.pipe.tokenizer,
            token=None,
            idempotent=True,
        )

    def tune(self, scales):
        """Scale each joined model's LoRA contribution on the UNet.

        ``scales[i]`` is repeated ``self.ranklist[i]`` times to form the
        diagonal handed to ``set_lora_diag``.

        NOTE(review): only ``self.pipe.unet`` is passed to ``set_lora_diag``;
        the text-encoder LoRA is not rescaled here. Confirm that is intended.

        Raises:
            AssertionError: if ``len(scales) != len(self.ranklist)``.
        """
        assert len(scales) == len(
            self.ranklist
        ), "Scale list should be the same length as ranklist"
        diags = [
            scale
            for scale, rank in zip(scales, self.ranklist)
            for _ in range(rank)
        ]
        set_lora_diag(self.pipe.unet, torch.tensor(diags))

    def prompt(self, prompt):
        """Expand ``<N>`` placeholders into model ``N``'s embedding tokens.

        ``<N>`` (1-based model index) becomes ``<s{N-1}-0><s{N-1}-1>...`` with
        one token per embedding of that model; a model without embeddings
        expands to the empty string. ``None`` is returned unchanged.
        """
        if prompt is None:
            return prompt
        for idx, tok_size in enumerate(self.token_size_list):
            placeholder = f"<{idx + 1}>"
            expansion = "".join(f"<s{idx}-{jdx}>" for jdx in range(tok_size))
            prompt = prompt.replace(placeholder, expansion)
        # TODO : Rescale LoRA + Text inputs based on prompt scale params
        return prompt