# Page-header residue from the hosting site, kept here as a comment:
# author: zhangyang-0123 | commit message: "add flux demo" | commit: 7fb327a
import os
import pickle
from copy import deepcopy
from typing import Optional
import torch
from diffusers.models.activations import GEGLU, GELU
from cross_attn_hook import CrossAttentionExtractionHook
from ffn_hooker import FeedForwardHooker
from norm_attn_hook import NormHooker
# Dummy module that stands in for a layer replaced by a skip connection.
class SkipConnection(torch.nn.Module):
    """Pass-through module: returns its first input and ignores everything else."""

    def __init__(self):
        super().__init__()

    def forward(*call_args, **call_kwargs):
        # call_args[0] is the module instance itself (no explicit ``self``
        # parameter is declared), so the first real input sits at index 1.
        first_input = call_args[1]
        return first_input
def calculate_mask_sparsity(hooker, threshold: Optional[float] = None):
    """Count the total and the active mask entries held by ``hooker.lambs``.

    Args:
        hooker: object exposing an iterable ``lambs`` of tensors and,
            optionally, a ``binary`` attribute.
        threshold: cut-off used for non-binary masks; an entry counts as
            active when it is ``>= threshold``. Must be ``None`` when the
            hooker is binary and must be given otherwise.

    Returns:
        A tuple ``(total, active, active / total)``. ``total`` sums
        ``lamb.size(0)`` over all lambs. When there are no entries at all the
        ratio is reported as ``0.0`` instead of dividing by zero.

    Raises:
        AssertionError: if ``threshold`` does not match the mask type.
    """
    # Hookers without a ``binary`` attribute (the original comment mentions
    # ff hooks) are treated as non-binary.
    binary = getattr(hooker, "binary", None)

    total_num_lambs = 0
    num_activate_lambs = 0
    for lamb in hooker.lambs:
        total_num_lambs += lamb.size(0)
        if binary:
            assert threshold is None, "threshold should be None for binary mask"
            num_activate_lambs += lamb.sum().item()
        else:
            assert (
                threshold is not None
            ), "threshold must be provided for non-binary mask"
            num_activate_lambs += (lamb >= threshold).sum().item()

    # An empty hooker would otherwise raise ZeroDivisionError.
    active_ratio = num_activate_lambs / total_num_lambs if total_num_lambs else 0.0
    return total_num_lambs, num_activate_lambs, active_ratio
def create_pipeline(
    pipe,
    model_id,
    device,
    torch_dtype,
    save_pt=None,
    lambda_threshold: float = 1,
    binary=True,
    epsilon=0.0,
    masking="binary",
    attn_name="attn",
    return_hooker=False,
    scope=None,
    ratio=None,
):
    """
    Move the pipeline to ``device``, freeze it and optionally load saved masks.

    Args:
        pipe: pipeline object; it is called once with a dummy prompt when
            masks are loaded. Must expose ``vae`` and either ``unet`` or
            ``transformer``.
        model_id: forwarded to ``CrossAttentionExtractionHook`` as ``model_name``.
        device: target device for the pipeline and for the loaded masks.
        torch_dtype: forwarded to every hooker as ``dtype``.
        save_pt: path of a saved mask. Masks are only loaded when the path
            contains ``"ff.pt"`` or ``"attn.pt"``; the sibling checkpoint
            paths are derived with ``get_save_pts``.
        lambda_threshold: forwarded to each hooker's ``load`` as ``threshold``.
        binary, epsilon, masking: forwarded to every hooker.
        attn_name: forwarded to ``CrossAttentionExtractionHook``.
        return_hooker: when true, return ``(pipe, hookers)`` instead of ``pipe``.
        scope: when ``"local"`` or ``"global"``, every hooker is binarized
            with this scope.
        ratio: binarization ratio. A single float is used for every hooker;
            otherwise an indexable ``(attn, ff[, norm])``.

    Returns:
        ``pipe``, or ``(pipe, hookers)`` when ``return_hooker`` is true.
        ``hookers`` is an empty list when no mask was loaded.

    Raises:
        ValueError: if a norm hooker is active, ``ratio`` is a sequence and
            it has fewer than three entries.
    """
    pipe.to(device)
    pipe.vae.requires_grad_(False)
    if hasattr(pipe, "unet"):
        pipe.unet.requires_grad_(False)
    else:
        pipe.transformer.requires_grad_(False)

    # Defined up front so ``return_hooker=True`` also works when no mask is
    # loaded (previously this path hit an unbound local).
    hookers = []

    # TODO should merge all the hooks checkpoint into one
    if save_pt and ("ff.pt" in save_pt or "attn.pt" in save_pt):
        save_pts = get_save_pts(save_pt)

        cross_attn_hooker = CrossAttentionExtractionHook(
            pipe,
            model_name=model_id,
            regex=".*",
            dtype=torch_dtype,
            head_num_filter=1,
            masking=masking,  # need to change to binary during inference
            dst=save_pts["attn"],
            epsilon=epsilon,
            attn_name=attn_name,
            binary=binary,
        )
        cross_attn_hooker.add_hooks(init_value=1)

        ff_hooker = FeedForwardHooker(
            pipe,
            regex=".*",
            dtype=torch_dtype,
            masking=masking,
            dst=save_pts["ff"],
            epsilon=epsilon,
            binary=binary,
        )
        ff_hooker.add_hooks(init_value=1)

        # The norm mask is optional: only hook it when its checkpoint exists.
        norm_hooker = None
        if os.path.exists(save_pts["norm"]):
            norm_hooker = NormHooker(
                pipe,
                regex=".*",
                dtype=torch_dtype,
                masking=masking,
                dst=save_pts["norm"],
                epsilon=epsilon,
                binary=binary,
            )
            norm_hooker.add_hooks(init_value=1)

        # One dummy step before loading the masks.
        # NOTE(review): presumably this lets the hooks initialise their
        # state before ``load`` overwrites it - confirm against the hookers.
        _ = pipe("abc", num_inference_steps=1)

        cross_attn_hooker.load(device=device, threshold=lambda_threshold)
        ff_hooker.load(device=device, threshold=lambda_threshold)
        if norm_hooker:
            norm_hooker.load(device=device, threshold=lambda_threshold)

        if scope == "local" or scope == "global":
            if isinstance(ratio, float):
                # A single float applies to every hooker, including the norm
                # hooker (which previously failed on ``len(ratio)``).
                attn_hooker_ratio = ff_hooker_ratio = norm_hooker_ratio = ratio
            else:
                attn_hooker_ratio, ff_hooker_ratio = ratio[0], ratio[1]
                if norm_hooker:
                    if len(ratio) < 3:
                        raise ValueError("Need to provide ratio for norm layer")
                    norm_hooker_ratio = ratio[2]

            cross_attn_hooker.binarize(scope, attn_hooker_ratio)
            ff_hooker.binarize(scope, ff_hooker_ratio)
            if norm_hooker:
                norm_hooker.binarize(scope, norm_hooker_ratio)

        hookers = [cross_attn_hooker, ff_hooker]
        if norm_hooker:
            hookers.append(norm_hooker)

    if return_hooker:
        return pipe, hookers
    return pipe
def linear_layer_pruning(module, lamb):
    """Physically remove masked-out heads from an attention module.

    Args:
        module: attention module exposing ``to_q``/``to_k``/``to_v`` linear
            layers, ``heads``, ``inner_dim`` and ``cross_attention_dim``;
            ``to_out``, ``query_dim`` and ``inner_kv_dim`` are updated when
            present. The module is modified in place.
        lamb: per-head mask; heads with a non-zero entry are kept.
            Assumed to be 1-D with one entry per head - TODO confirm.

    Returns:
        The pruned ``module``, or a ``SkipConnection`` when no head is kept.
    """
    # ``flatten`` yields a 1-D index tensor whether zero, one or many heads
    # survive (a plain ``squeeze`` collapses the single-head case to 0-d).
    heads_to_keep = torch.nonzero(lamb).flatten()
    # Count the heads that are actually kept so that ``module.heads`` always
    # matches the rows retained below, even if ``lamb`` is not strictly 0/1.
    new_heads = heads_to_keep.numel()
    if new_heads == 0:
        return SkipConnection()

    kept_head_ids = heads_to_keep.tolist()
    for module_to_remove in (module.to_k, module.to_q, module.to_v):
        # Per-head width of this projection.
        inner_dim = module_to_remove.out_features // module.heads
        # Boolean row mask: True for every output row of a kept head.
        rows_to_keep = torch.zeros(
            module_to_remove.out_features,
            dtype=torch.bool,
            device=module_to_remove.weight.device,
        )
        for idx in kept_head_ids:
            rows_to_keep[idx * inner_dim : (idx + 1) * inner_dim] = True

        # Overwrite the projection with its row-masked version.
        module_to_remove.weight.data = module_to_remove.weight.data[rows_to_keep, :]
        if module_to_remove.bias is not None:
            module_to_remove.bias.data = module_to_remove.bias.data[rows_to_keep]
        module_to_remove.out_features = int(rows_to_keep.sum().item())

    # Also shrink the input side (columns) of the output projection when the
    # module has one (original note: for FLUXSingleAttnProcessor2_0). This
    # reuses the row mask of the last projection processed above (to_v).
    if getattr(module, "to_out", None) is not None:
        module.to_out[0].weight.data = module.to_out[0].weight.data[:, rows_to_keep]
        module.to_out[0].in_features = int(rows_to_keep.sum().item())

    # Rescale the bookkeeping attributes to the new head count.
    module.inner_dim = module.inner_dim // module.heads * new_heads
    try:
        module.query_dim = module.query_dim // module.heads * new_heads
        module.inner_kv_dim = module.inner_kv_dim // module.heads * new_heads
    except (AttributeError, TypeError):
        # Best effort: these attributes may be absent (or unset) on the module.
        pass
    module.cross_attention_dim = module.cross_attention_dim // module.heads * new_heads
    module.heads = new_heads
    return module
def ffn_linear_layer_pruning(module, lamb):
    """Physically remove masked-out hidden units from a feed-forward module.

    Args:
        module: feed-forward module whose ``net[0]`` is a ``GELU`` or
            ``GEGLU`` activation wrapper (with a ``proj`` linear layer) and
            whose ``net[2]`` is the output linear layer. Modified in place.
        lamb: mask over hidden units; units with a non-zero entry are kept.
            Assumed to be 1-D - TODO confirm.

    Returns:
        The pruned ``module``, or a ``SkipConnection`` when no unit is kept.
    """
    # ``flatten`` keeps the index tensor 1-D; ``squeeze`` would collapse the
    # single-survivor case to a 0-d tensor, on which ``len()`` raises.
    lambda_to_keep = torch.nonzero(lamb).flatten()
    num_lambda = lambda_to_keep.numel()
    if num_lambda == 0:
        return SkipConnection()

    activation = module.net[0]
    if isinstance(activation, GELU):
        proj = activation.proj
        # Keep only the selected output rows of the pre-activation projection.
        proj.weight.data = proj.weight.data[lambda_to_keep, :]
        proj.out_features = num_lambda
        if proj.bias is not None:
            proj.bias.data = proj.bias.data[lambda_to_keep]
        # Rebuild the wrapper around the pruned projection, carrying over the
        # original ``approximate`` mode so the activation itself is unchanged.
        update_act = GELU(
            proj.in_features,
            num_lambda,
            approximate=getattr(activation, "approximate", "none"),
        )
        update_act.proj = proj
        module.net[0] = update_act
    elif isinstance(activation, GEGLU):
        proj = activation.proj
        # The projection output is handled as two equal halves; the same
        # indices are selected from each half.
        output_feature = proj.out_features
        half = output_feature // 2
        proj.weight.data = torch.cat(
            [
                proj.weight.data[:half, :][lambda_to_keep, :],
                proj.weight.data[half:][lambda_to_keep, :],
            ],
            dim=0,
        )
        proj.out_features = num_lambda * 2
        if proj.bias is not None:
            proj.bias.data = torch.cat(
                [
                    proj.bias.data[:half][lambda_to_keep],
                    proj.bias.data[half:][lambda_to_keep],
                ]
            )
        update_act = GEGLU(proj.in_features, num_lambda * 2)
        update_act.proj = proj
        module.net[0] = update_act
    # NOTE(review): any other activation type leaves net[0] untouched while
    # net[2] below is still pruned - confirm only GELU/GEGLU reach this point.

    # Keep only the matching input columns of the projection after the activation.
    module.net[2].weight.data = module.net[2].weight.data[:, lambda_to_keep]
    module.net[2].in_features = num_lambda
    return module
# Linear layer that computes only the kept output features and zero-fills the rest.
class SparsityLinear(torch.nn.Module):
    """Linear layer with a reduced set of outputs scattered into a full-width result.

    Args:
        in_features: input width of the wrapped linear layer.
        out_features: width of the tensor returned by ``forward``.
        lambda_to_keep: indices (along the last output dimension) that
            receive the computed values; all other positions are zero.
        num_lambda: number of outputs actually computed, i.e. the output
            width of the wrapped linear layer.
    """

    def __init__(self, in_features, out_features, lambda_to_keep, num_lambda):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, num_lambda)
        self.out_features = out_features
        self.lambda_to_keep = lambda_to_keep

    def forward(self, x):
        """Return a ``(..., out_features)`` tensor holding the kept outputs.

        Any number of leading dimensions is supported; for 2-D input the
        result is the same as indexing with ``[:, lambda_to_keep]``.
        """
        x = self.linear(x)
        output = torch.zeros(
            (*x.shape[:-1], self.out_features), device=x.device, dtype=x.dtype
        )
        # Scatter the computed features back to their original positions.
        output[..., self.lambda_to_keep] = x
        return output
def norm_layer_pruning(module, lamb):
    """
    Pruning the layer normalization layer for FLUX model

    Replaces ``module.linear`` with a ``SparsityLinear`` that computes only
    the outputs selected by ``lamb`` and zero-fills the others, so the
    output width stays the same.

    Args:
        module: module exposing a ``linear`` layer with weight and bias.
            Modified in place.
        lamb: mask over the outputs of ``module.linear``; non-zero entries
            are kept. Assumed to be 1-D - TODO confirm.

    Returns:
        The pruned ``module``, or a ``SkipConnection`` when nothing is kept.
    """
    # ``flatten`` keeps the index tensor 1-D; ``squeeze`` would collapse the
    # single-survivor case to a 0-d tensor, on which ``len()`` raises.
    lambda_to_keep = torch.nonzero(lamb).flatten()
    num_lambda = lambda_to_keep.numel()
    if num_lambda == 0:
        return SkipConnection()

    # get num_features
    in_features = module.linear.in_features
    out_features = module.linear.out_features

    linear = SparsityLinear(in_features, out_features, lambda_to_keep, num_lambda)
    # Copy over only the kept rows of the original weight and bias.
    linear.linear.weight.data = module.linear.weight.data[lambda_to_keep]
    linear.linear.bias.data = module.linear.bias.data[lambda_to_keep]
    module.linear = linear
    return module
def get_save_pts(save_pt):
    """Derive the ff / attn / norm checkpoint paths from one of them.

    The file name (last ``os.sep`` component) of ``save_pt`` is rewritten by
    substring replacement to obtain the sibling checkpoints in the same
    directory.

    Args:
        save_pt: path containing ``"ff.pt"``; any other path is treated as
            the attn checkpoint.

    Returns:
        A dict with keys ``"ff"``, ``"attn"`` and ``"norm"``.
    """
    *parent_parts, file_name = save_pt.split(os.sep)

    def _sibling(name):
        # Re-attach a rewritten file name to the original directory parts.
        return os.sep.join([*parent_parts, name])

    if "ff.pt" in save_pt:
        attn_name = file_name.replace("ff", "attn")
        norm_name = attn_name.replace("attn", "norm")
        return {
            "ff": save_pt,
            "attn": _sibling(attn_name),
            "norm": _sibling(norm_name),
        }

    ff_name = file_name.replace("attn", "ff")
    # Build the norm path from the split components; joining the plain path
    # string would interleave a separator between every character.
    norm_name = ff_name.replace("ff", "norm")
    return {
        "ff": _sibling(ff_name),
        "attn": save_pt,
        "norm": _sibling(norm_name),
    }
def save_img(pipe, g_cpu, steps, prompt, save_path):
    """Run ``pipe`` on ``prompt`` and save the first resulting image.

    Args:
        pipe: callable pipeline; its result is indexed with ``["images"]``.
        g_cpu: passed to the pipeline as ``generator``.
        steps: passed to the pipeline as ``num_inference_steps``.
        prompt: prompt passed as the first positional argument.
        save_path: destination handed to the image's ``save`` method.
    """
    result = pipe(prompt, generator=g_cpu, num_inference_steps=steps)
    first_image = result["images"][0]
    first_image.save(save_path)