# Spaces: Running on Zero
# ref:
# - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
# - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
import os | |
import math | |
from typing import Optional, List, Type, Set, Literal | |
import torch | |
import torch.nn as nn | |
from diffusers import UNet2DConditionModel | |
from safetensors.torch import save_file | |
# Diffusers module classes whose Linear-like children receive adapters.
# The upstream author also considered "Transformer2DModel" (the wrapper around
# attn1/attn2) but settled on "Attention".
UNET_TARGET_REPLACE_MODULE_TRANSFORMER = ["Attention"]

# Convolution-bearing block classes (LoCon-style targets; upstream note: "locon, 3clier").
UNET_TARGET_REPLACE_MODULE_CONV = ["ResnetBlock2D", "Downsample2D", "Upsample2D", "DownBlock2D", "UpBlock2D"]

# Prefix of every adapter name (dots are later replaced by underscores).
LORA_PREFIX_UNET = "lora_unet"

# Alias (the very same list object) used by LoRAw2w as its target filter.
DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER

# Accepted ``train_method`` values; LoRAw2w.create_modules decides which module
# names each value keeps.
TRAINING_METHODS = Literal[
    "noxattn",              # skip module names containing "attn2" or "time_embed"
    "innoxattn",            # skip module names containing "attn2"
    "selfattn",             # ESD-u: keep only module names containing "attn1"
    "xattn",                # ESD-x label; skips module names containing "to_k"
    "full",                 # keep every module
    "xattn-strict",         # as "xattn", and also drops children named "*to_k*" / "*out*"
    "noxattn-hspace",       # as "noxattn", restricted to mid-block modules
    "noxattn-hspace-last",  # as "noxattn", only mid-block ".1" modules' "conv2" children
]
# Disabled upstream options: "xlayer", "outxattn", "outsattn", "inxattn",
# "inmidsattn", "selflayer".
class LoRAModule(nn.Module):
    """
    Weight-space LoRA adapter that patches the ``forward`` of an existing Linear.

    The original Linear module is not replaced. ``apply_to`` swaps its
    ``forward`` so that it returns ``org_forward(x) + delta(x)``. The adapter's
    down/up vectors are regenerated on every call from the shared coefficient
    tensor ``proj`` and this module's slice of ``v`` / ``mean`` / ``std``.
    """

    def __init__(
        self,
        lora_name,
        proj,
        v,
        mean,
        std,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
    ):
        """
        Args:
            lora_name: name under which this adapter is registered.
            proj: coefficient tensor, stored as ``proj.bfloat16()``. It is
                assumed to be (1, k), so each generated vector is a single
                row — TODO confirm against the caller.
            v: basis rows for this module. Rows ``[:in_dim]`` build the down
                vector and the remaining rows build the up vector.
            mean, std: de-standardization statistics, split the same way as ``v``.
            org_module: module to patch. It must expose ``in_features`` and
                ``out_features`` (i.e. be Linear-like).
            multiplier: adapter strength, applied at call time.
            lora_dim: rank. It is only used to compute ``scale = alpha / lora_dim``.
            alpha: scaling numerator. ``None`` or ``0`` means ``alpha = lora_dim``
                (no scaling). A tensor is converted to a Python number.
        """
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim
        self.in_dim = org_module.in_features
        self.out_dim = org_module.out_features

        # Every input to the generated weights is kept in bfloat16.
        self.proj = proj.bfloat16()
        self.mean1 = mean[0:self.in_dim].bfloat16()
        self.mean2 = mean[self.in_dim:].bfloat16()
        self.std1 = std[0:self.in_dim].bfloat16()
        self.std2 = std[self.in_dim:].bfloat16()
        self.v1 = v[0:self.in_dim].bfloat16()
        self.v2 = v[self.in_dim:].bfloat16()

        if isinstance(alpha, torch.Tensor):
            # .item() also works for CUDA / bfloat16 tensors and Parameters,
            # whereas .numpy() raises for CUDA and bfloat16 tensors.
            alpha = alpha.item()
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.multiplier = multiplier
        self.org_module = org_module

    def apply_to(self):
        """Route ``org_module.forward`` through this adapter.

        The original bound ``forward`` is kept as ``org_forward``. The module
        reference is then dropped so it is no longer registered as a child of
        this adapter. This can therefore only be called once.
        """
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def forward(self, x):
        """Return the original output plus the scaled rank-1 weight-space delta."""
        # Rebuild the down/up vectors from the shared coefficients (de-standardized).
        down = (self.proj @ self.v1.T) * self.std1 + self.mean1
        up = (self.proj @ self.v2.T) * self.std2 + self.mean2
        # Same evaluation order as before: ((x @ down^T) @ up) * multiplier * scale.
        return self.org_forward(x) + (x @ down.T) @ up * self.multiplier * self.scale
class LoRAw2w(nn.Module):
    """Attach weight-space ("w2w") LoRA adapters to a diffusers UNet.

    Every adapter shares the trainable coefficient tensor ``proj``. Each adapter
    reads its own contiguous window of the flattened ``v`` / ``mean`` / ``std``
    buffers. Windows are handed out in the order produced by ``create_modules``,
    so that order presumably has to match the layout those arrays were built
    with — changing the traversal would misalign every adapter.
    """

    def __init__(
        self,
        proj,
        mean,
        std,
        v,
        unet: UNet2DConditionModel,
        rank: int = 4,
        multiplier: float = 1.0,
        alpha: float = 1.0,
        train_method: TRAINING_METHODS = "full",
    ) -> None:
        """
        Args:
            proj: shared coefficients, wrapped in an ``nn.Parameter``.
            mean, std, v: flattened statistics and basis. Each is copied into a
                registered buffer.
            unet: model whose target modules are patched in place.
            rank: forwarded to each adapter as ``lora_dim``.
            multiplier: initial adapter strength.
            alpha: scaling numerator forwarded to every adapter.
            train_method: selects which modules receive adapters (see
                ``create_modules``).
        """
        super().__init__()
        self.lora_scale = 1
        self.multiplier = multiplier
        self.lora_dim = rank
        self.alpha = alpha
        self.proj = torch.nn.Parameter(proj)
        self.register_buffer("mean", self._copy_as_tensor(mean))
        self.register_buffer("std", self._copy_as_tensor(std))
        self.register_buffer("v", self._copy_as_tensor(v))
        self.module = LoRAModule
        self.unet_loras = self.create_modules(
            LORA_PREFIX_UNET,
            unet,
            DEFAULT_TARGET_REPLACE,
            self.lora_dim,
            self.multiplier,
            train_method=train_method,
        )

        # Validate every name first, then patch, so a failure leaves the UNet untouched.
        self.lora_names = set()
        for lora in self.unet_loras:
            assert (
                lora.lora_name not in self.lora_names
            ), f"duplicated lora name: {lora.lora_name}. {self.lora_names}"
            self.lora_names.add(lora.lora_name)

        for lora in self.unet_loras:
            lora.apply_to()
            self.add_module(lora.lora_name, lora)

        torch.cuda.empty_cache()

    @staticmethod
    def _copy_as_tensor(value):
        """Return a detached copy of ``value`` as a tensor.

        This matches ``torch.tensor(value)``, but for tensor inputs it avoids that
        call's "copy construct from a tensor" warning.
        """
        if isinstance(value, torch.Tensor):
            return value.detach().clone()
        return torch.tensor(value)

    def reset(self):
        """Give every adapter a new Parameter built from ``self.proj`` in bfloat16.

        NOTE(review): each adapter receives its own Parameter object, so the
        adapters no longer hold the same Parameter object as ``self.proj``. The
        underlying storage may still be shared when ``self.proj`` is already
        bfloat16. Confirm this is the intended sharing.
        """
        for lora in self.unet_loras:
            lora.proj = torch.nn.Parameter(self.proj.bfloat16())

    def create_modules(
        self,
        prefix: str,
        root_module: nn.Module,
        target_replace_modules: List[str],
        rank: int,
        multiplier: float,
        train_method: TRAINING_METHODS,
    ) -> list:
        """Create one adapter per Linear-like child of every target module.

        The modules of ``root_module`` are visited in ``named_modules`` order,
        rearranged as ``[0, upstart) + [midstart, end) + [upstart, midstart)``.
        Here ``midstart`` / ``upstart`` are the first indices whose names contain
        "mid_block" / "up_block" (0 if absent). For a diffusers UNet this
        presumably yields down -> mid -> up; verify for other models. Each
        created adapter consumes the next ``in_features + out_features`` entries
        of ``v`` / ``mean`` / ``std``.

        Args:
            prefix: prepended to each adapter name.
            root_module: model to scan.
            target_replace_modules: class names whose children are adapted.
            rank: forwarded to each adapter as ``lora_dim``.
            multiplier: forwarded to each adapter.
            train_method: module-name filter (see ``TRAINING_METHODS``).

        Returns:
            The adapters, deduplicated by name. The window counter still
            advances for a duplicate.

        Raises:
            NotImplementedError: for an unknown ``train_method``.
        """
        # Named module_names (not ``nn``) so torch.nn is not shadowed.
        module_names = []
        modules = []
        for module_name, module in root_module.named_modules():
            module_names.append(module_name)
            modules.append(module)

        midstart = next((i for i, module_name in enumerate(module_names) if "mid_block" in module_name), 0)
        upstart = next((i for i, module_name in enumerate(module_names) if "up_block" in module_name), 0)
        modules = modules[:upstart] + modules[midstart:] + modules[upstart:midstart]
        module_names = module_names[:upstart] + module_names[midstart:] + module_names[upstart:midstart]

        counter = 0
        loras = []
        seen_names = set()
        for name, module in zip(module_names, modules):
            if train_method in ("noxattn", "noxattn-hspace", "noxattn-hspace-last"):
                # Everything except cross attention and the time embedding.
                if "attn2" in name or "time_embed" in name:
                    continue
            elif train_method == "innoxattn":
                if "attn2" in name:
                    continue
            elif train_method == "selfattn":
                if "attn1" not in name:
                    continue
            elif train_method in ("xattn", "xattn-strict"):
                # NOTE(review): this tests the *module* name. With "Attention"
                # targets it never contains "to_k", so this filter is effectively a
                # no-op. It is kept as-is because it determines the weight layout.
                if "to_k" in name:
                    continue
            elif train_method == "full":
                pass
            else:
                raise NotImplementedError(
                    f"train_method: {train_method} is not implemented."
                )

            if module.__class__.__name__ not in target_replace_modules:
                continue

            for child_name, child_module in module.named_modules():
                # NOTE(review): Conv2d has no ``in_features``, so a conv child
                # reaching the lines below would raise AttributeError.
                if child_module.__class__.__name__ not in ["Linear", "Conv2d", "LoRACompatibleLinear", "LoRACompatibleConv"]:
                    continue
                if train_method == "xattn-strict" and ("out" in child_name or "to_k" in child_name):
                    continue
                if train_method == "noxattn-hspace" and "mid_block" not in name:
                    continue
                if train_method == "noxattn-hspace-last" and (
                    "mid_block" not in name or ".1" not in name or "conv2" not in child_name
                ):
                    continue

                lora_name = (prefix + "." + name + "." + child_name).replace(".", "_")
                in_dim = child_module.in_features
                out_dim = child_module.out_features
                combined_dim = in_dim + out_dim
                end = counter + combined_dim
                lora = self.module(
                    lora_name,
                    self.proj,
                    self.v[counter:end],
                    self.mean[counter:end],
                    self.std[counter:end],
                    child_module,
                    multiplier,
                    rank,
                    self.alpha,
                )
                counter = end
                if lora_name not in seen_names:
                    loras.append(lora)
                    seen_names.add(lora_name)
        return loras

    def prepare_optimizer_params(self):
        """Return optimizer param groups covering every adapter's parameters.

        Parameters shared between adapters (e.g. a common ``proj`` Parameter)
        are listed once. torch.optim warns about duplicated parameters in a group.
        """
        all_params = []
        if self.unet_loras:
            params = []
            seen = set()
            for lora in self.unet_loras:
                for param in lora.parameters():
                    if id(param) not in seen:
                        seen.add(id(param))
                        params.append(param)
            all_params.append({"params": params})
        return all_params

    def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
        """Save ``state_dict()`` to ``file``.

        A ``.safetensors`` extension uses safetensors; any other extension uses
        ``torch.save``. When ``dtype`` is given, every tensor is copied to CPU in
        that dtype.
        """
        state_dict = self.state_dict()
        if dtype is not None:
            for key in list(state_dict.keys()):
                state_dict[key] = state_dict[key].detach().clone().to("cpu").to(dtype)

        if os.path.splitext(file)[1] == ".safetensors":
            if dtype is None:
                # safetensors rejects tensors that share storage (e.g. a Parameter
                # registered under several adapters) and non-contiguous tensors.
                state_dict = {
                    key: tensor.detach().clone(memory_format=torch.contiguous_format)
                    for key, tensor in state_dict.items()
                }
            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    def set_lora_slider(self, scale):
        """Set the strength applied to all adapters on entering the context."""
        self.lora_scale = scale

    def __enter__(self):
        """Enable every adapter at ``lora_scale`` and return self."""
        for lora in self.unet_loras:
            lora.multiplier = 1.0 * self.lora_scale
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Disable every adapter. Exceptions are not suppressed."""
        for lora in self.unet_loras:
            lora.multiplier = 0