# Spaces: Running on Zero
import numpy as np | |
import math | |
import types | |
import torch | |
import torch.nn as nn | |
import numpy as np | |
import cv2 | |
import re | |
import torch.nn.functional as F | |
from einops import rearrange | |
from einops.layers.torch import Rearrange | |
from PIL import Image | |
def extract_first_sentence(text):
    """Return ``text`` up to and including its first period, stripped.

    When ``text`` contains no ``'.'`` the whole string is returned with the
    surrounding whitespace removed.
    """
    head, period, _ = text.partition('.')
    # ``partition`` yields an empty separator when no period is present.
    sentence = head + period if period else text
    return sentence.strip()
import re | |
def remove_duplicate_keywords(text, keywords):
    """Blank out every repeated occurrence of each keyword in ``text``.

    The text is split into word tokens and the punctuation marks
    ``. , ; ! ?``. Keywords are matched case-insensitively against whole
    tokens; the first match of each keyword is kept and later matches are
    replaced by an empty string. Tokens are then joined with single spaces,
    so punctuation ends up space-separated and a removed word leaves two
    adjacent spaces.
    """
    tokens = re.findall(r'\b\w+\b|[.,;!?]', text)
    targets = {keyword.lower() for keyword in keywords}
    already_seen = set()
    for position, token in enumerate(tokens):
        folded = token.lower()
        if folded not in targets:
            continue
        if folded in already_seen:
            tokens[position] = ""
        else:
            already_seen.add(folded)
    return " ".join(tokens)
def process_text_with_markers(text, parsing_mask_list):
    """Extract the clauses of ``text`` that mention facial parts and tag them.

    Steps performed:

    1. Repeated facial keywords are removed via ``remove_duplicate_keywords``.
    2. For every parsing-mask key present in ``parsing_mask_list`` the matching
       text feature ("face", "ears", "eyes", "nose", "mouth") gets a
       ``<|feature|>`` marker inserted after its first whole-word occurrence.
    3. If a feature never occurs in the text, all mask entries mapped to that
       feature are deleted from ``parsing_mask_list``.
    4. For each marker, the clause around it (delimited by ``,``, ``.`` or
       ``;``) is cut out of the text and collected, in the fixed feature order.
    5. All feature-specific markers are replaced by ``<|facial|>``.

    Args:
        text: Caption text to process.
        parsing_mask_list: Mapping keyed by parsing part name (for example
            ``"Face"`` or ``"Left_Eye"``). It must support ``in`` and ``del``
            by key. NOTE: it is modified in place.

    Returns:
        A ``(align_marked_text, align_parsing_mask_list)`` tuple: the collected
        clauses joined by spaces (each followed by a comma), and the same
        mapping object that was passed in, after pruning.
    """
    keywords = ["face", "ears", "eyes", "nose", "mouth"]
    text = remove_duplicate_keywords(text, keywords)
    key_parsing_mask_markers = [
        "Face", "Left_Ear", "Right_Ear", "Left_Eye", "Right_Eye",
        "Nose", "Upper_Lip", "Lower_Lip",
    ]
    mapping = {
        "Face": "face",
        "Left_Ear": "ears",
        "Right_Ear": "ears",
        "Left_Eye": "eyes",
        "Right_Eye": "eyes",
        "Nose": "nose",
        "Upper_Lip": "mouth",
        "Lower_Lip": "mouth",
    }

    # Collect each text feature once, in the order of key_parsing_mask_markers.
    facial_features_align = []
    markers_align = []
    for key in key_parsing_mask_markers:
        if key in parsing_mask_list:
            mapped_key = mapping.get(key, key.lower())
            if mapped_key not in facial_features_align:
                facial_features_align.append(mapped_key)
                markers_align.append("<|" + mapped_key + "|>")

    text_marked = text
    # Same object as the argument: pruning below is visible to the caller.
    align_parsing_mask_list = parsing_mask_list
    for feature, marker in zip(facial_features_align[::-1], markers_align[::-1]):
        pattern = rf'\b{feature}\b'
        text_marked_new = re.sub(pattern, f'{feature} {marker}', text_marked, count=1)
        if text_marked == text_marked_new:
            # The feature is not mentioned in the text: drop its masks.
            for key, value in mapping.items():
                if value == feature and key in align_parsing_mask_list:
                    del align_parsing_mask_list[key]
        text_marked = text_marked_new

    text_marked = text_marked.replace('\n', '')

    clause_delimiters = (",", ".", ";")
    ordered_text = []
    for marker in markers_align:
        start_idx = text_marked.find(marker)
        if start_idx == -1:
            # Marker was never inserted. Without this guard the -1 index would
            # be used as a slice start and could pick up the last character of
            # the text as a bogus clause.
            continue
        end_idx = start_idx + len(marker)

        # Grow the span outwards to the enclosing clause boundaries.
        while start_idx > 0 and text_marked[start_idx - 1] not in clause_delimiters:
            start_idx -= 1
        while end_idx < len(text_marked) and text_marked[end_idx] not in clause_delimiters:
            end_idx += 1

        context = text_marked[start_idx:end_idx].strip()
        ordered_text.append(context + ",")
        # Remove the consumed clause so later markers search the remainder.
        text_marked = text_marked[:start_idx] + text_marked[end_idx:]

    align_marked_text = " ".join(ordered_text)
    replace_list = ["<|face|>", "<|ears|>", "<|nose|>", "<|eyes|>", "<|mouth|>"]
    for item in replace_list:
        align_marked_text = align_marked_text.replace(item, "<|facial|>")

    return align_marked_text, align_parsing_mask_list
def tokenize_and_mask_noun_phrases_ends(text, image_token_id, facial_token_id, tokenizer):
    """Tokenize ``text``, strip marker tokens and flag the positions before them.

    Every token equal to ``image_token_id`` or ``facial_token_id`` is dropped
    from the id sequence. For a facial marker the mask entry at index
    ``(number of kept tokens so far) - 1`` is set; for an image marker the
    index is additionally shifted by the number of image markers already
    seen. NOTE: if a marker is the very first token the computed index is
    ``-1``, which addresses the last entry of the (unpadded) mask.

    The ids and both masks are truncated or padded to
    ``tokenizer.model_max_length`` (ids with ``tokenizer.pad_token_id``,
    masks with ``False``).

    Returns:
        ``(clean_input_ids, image_mask, facial_mask)``, each a tensor with a
        leading dimension of size 1 (``long``, ``bool``, ``bool``).
    """
    input_ids = tokenizer.encode(text)
    raw_length = len(input_ids)
    image_noun_phrase_end_mask = [False] * raw_length
    facial_noun_phrase_end_mask = [False] * raw_length

    clean_input_ids = []
    image_num = 0
    for token_id in input_ids:
        if token_id == image_token_id:
            image_noun_phrase_end_mask[len(clean_input_ids) + image_num - 1] = True
            image_num += 1
        elif token_id == facial_token_id:
            facial_noun_phrase_end_mask[len(clean_input_ids) - 1] = True
        else:
            clean_input_ids.append(token_id)

    max_len = tokenizer.model_max_length

    def fit_mask(mask):
        # Cut or right-pad a boolean mask to exactly max_len entries.
        if len(mask) > max_len:
            return mask[:max_len]
        return mask + [False] * (max_len - len(mask))

    if len(clean_input_ids) > max_len:
        clean_input_ids = clean_input_ids[:max_len]
    else:
        padding = [tokenizer.pad_token_id] * (max_len - len(clean_input_ids))
        clean_input_ids = clean_input_ids + padding

    ids_tensor = torch.tensor(clean_input_ids, dtype=torch.long)
    image_tensor = torch.tensor(fit_mask(image_noun_phrase_end_mask), dtype=torch.bool)
    facial_tensor = torch.tensor(fit_mask(facial_noun_phrase_end_mask), dtype=torch.bool)
    return ids_tensor.unsqueeze(0), image_tensor.unsqueeze(0), facial_tensor.unsqueeze(0)
def prepare_image_token_idx(image_token_mask, facial_token_mask, max_num_objects=2, max_num_facials=5):
    """Convert boolean token masks into padded index tensors plus validity masks.

    For each mask the column indices of its ``True`` entries are collected
    (``torch.nonzero(..., as_tuple=True)[1]``). If fewer than the requested
    maximum are found, the indices are right-padded with ``0`` and the
    validity mask with ``False``. Longer index lists are NOT truncated.

    Returns:
        ``(image_token_idx, image_token_idx_mask, facial_token_idx,
        facial_token_idx_mask)``, each with a leading dimension of size 1.
    """

    def indices_and_validity(token_mask, capacity):
        # Column indices of the set entries, padded up to ``capacity``.
        idx = torch.nonzero(token_mask, as_tuple=True)[1]
        valid = torch.ones_like(idx, dtype=torch.bool)
        missing = capacity - len(idx)
        if missing > 0:
            idx = torch.cat([idx, torch.zeros(missing, dtype=torch.long)])
            valid = torch.cat([valid, torch.zeros(missing, dtype=torch.bool)])
        return idx.unsqueeze(0), valid.unsqueeze(0)

    image_token_idx, image_token_idx_mask = indices_and_validity(image_token_mask, max_num_objects)
    facial_token_idx, facial_token_idx_mask = indices_and_validity(facial_token_mask, max_num_facials)
    return image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask
def get_object_localization_loss_for_one_layer(
    cross_attention_scores,
    object_segmaps,
    object_token_idx,
    object_token_idx_mask,
    loss_fn,
):
    """Compute the localization loss of one attention layer.

    ``cross_attention_scores`` is unpacked as
    ``(batch * heads, num_noise_latents, num_text_tokens)`` and
    ``object_segmaps`` as ``(batch, max_num_objects, H, W)``. The segmentation
    maps are resized (bilinear, antialiased) to the square latent grid implied
    by ``num_noise_latents`` and compared, via ``loss_fn``, with the attention
    columns selected by ``object_token_idx``.

    ``loss_fn(attn, segmaps)`` receives two tensors of shape
    ``(batch, heads, num_noise_latents, max_num_objects)``; its result is
    multiplied by ``object_token_idx_mask`` viewed as
    ``(batch, 1, max_num_objects)``, summed over the object axis, divided by
    the number of valid objects per sample (plus ``1e-5``) and averaged.
    """
    batch_times_heads, num_noise_latents, num_text_tokens = cross_attention_scores.shape
    b, max_num_objects, _, _ = object_segmaps.shape
    num_heads = batch_times_heads // b
    side = int(num_noise_latents**0.5)
    expanded_shape = (b, num_heads, num_noise_latents, max_num_objects)

    # Segmentation maps -> (b, heads, latents, objects), matching the scores.
    resized = F.interpolate(object_segmaps, size=(side, side), mode="bilinear", antialias=True)
    flat_segmaps = resized.view(b, max_num_objects, -1)
    target = flat_segmaps.permute(0, 2, 1).unsqueeze(1).expand(*expanded_shape)

    # Pick the attention column of every object token for each head.
    per_head_scores = cross_attention_scores.view(b, num_heads, num_noise_latents, num_text_tokens)
    gather_index = object_token_idx.view(b, 1, 1, max_num_objects).expand(*expanded_shape)
    object_token_attn_prob = torch.gather(per_head_scores, dim=3, index=gather_index)

    per_object_loss = loss_fn(object_token_attn_prob, target)
    per_object_loss = per_object_loss * object_token_idx_mask.view(b, 1, max_num_objects)
    object_token_cnt = object_token_idx_mask.sum(dim=1).view(b, 1) + 1e-5
    return (per_object_loss.sum(dim=2) / object_token_cnt).mean()
def get_object_localization_loss(
    cross_attention_scores,
    object_segmaps,
    image_token_idx,
    image_token_idx_mask,
    loss_fn,
):
    """Average the per-layer localization loss over all stored layers.

    ``cross_attention_scores`` is a mapping whose values are the attention
    tensors of the individual layers; each one is passed to
    ``get_object_localization_loss_for_one_layer`` and the results are
    averaged. An empty mapping raises ``ZeroDivisionError``.
    """
    num_layers = len(cross_attention_scores)
    total = 0
    for layer_scores in cross_attention_scores.values():
        total = total + get_object_localization_loss_for_one_layer(
            layer_scores, object_segmaps, image_token_idx, image_token_idx_mask, loss_fn
        )
    return total / num_layers
def unet_store_cross_attention_scores(unet, attention_scores, layers=5):
    """Hook selected attention modules so their probabilities are recorded.

    Every matching ``Attention`` module gets its ``get_attention_scores``
    method replaced by a wrapper that calls the original method, stores the
    result in ``attention_scores[<module name>]`` and returns it unchanged.

    Args:
        unet: Model whose ``named_modules()`` are scanned.
        attention_scores: Dict that is filled (and overwritten) on every
            forward pass, keyed by module name.
        layers: How many entries of the block-name list to hook, taken as a
            centered window of that list.

    Returns:
        The same ``unet`` object, patched in place.

    The original method is saved as ``module.old_get_attention_scores`` only
    the first time a module is patched. Calling this function again on the
    same model (for example with a new ``attention_scores`` dict) therefore
    re-targets the hook instead of wrapping the wrapper, which would make the
    wrapper call itself recursively.
    """
    from diffusers.models.attention_processor import Attention

    UNET_LAYER_NAMES = [
        "down_blocks.0",
        "down_blocks.1",
        "down_blocks.2",
        "mid_block",
        "up_blocks.1",
        "up_blocks.2",
        "up_blocks.3",
    ]
    start_layer = (len(UNET_LAYER_NAMES) - layers) // 2
    end_layer = start_layer + layers
    applicable_layers = UNET_LAYER_NAMES[start_layer:end_layer]

    def make_new_get_attention_scores_fn(name):
        def new_get_attention_scores(module, query, key, attention_mask=None):
            attention_probs = module.old_get_attention_scores(
                query, key, attention_mask
            )
            attention_scores[name] = attention_probs
            return attention_probs

        return new_get_attention_scores

    for name, module in unet.named_modules():
        # NOTE(review): "attn1" is matched here although the function name
        # talks about cross-attention — confirm this is the intended module.
        if not (isinstance(module, Attention) and "attn1" in name):
            continue
        if not any(layer in name for layer in applicable_layers):
            continue
        if not hasattr(module, "old_get_attention_scores"):
            # First patch: remember the pristine method exactly once.
            module.old_get_attention_scores = module.get_attention_scores
        module.get_attention_scores = types.MethodType(
            make_new_get_attention_scores_fn(name), module
        )
    return unet
class BalancedL1Loss(nn.Module):
    """Difference between mean attention outside and inside a segmentation map.

    Both inputs are reduced over ``dim=2``. The result is
    ``mean_background_attention - mean_object_attention``, where each mean is
    weighted by the (background or object) segmentation map and its
    denominator is offset by ``1e-5``.

    ``threshold`` is stored but not used by ``forward``. With
    ``normalize=True`` the attention is first divided by its maximum along
    ``dim=2`` (plus ``1e-5``).
    """

    def __init__(self, threshold=1.0, normalize=False):
        super().__init__()
        self.threshold = threshold
        self.normalize = normalize

    def forward(self, object_token_attn_prob, object_segmaps):
        probs = object_token_attn_prob
        if self.normalize:
            peak = probs.max(dim=2, keepdim=True)[0]
            probs = probs / (peak + 1e-5)

        def weighted_mean(weights):
            # Attention averaged under ``weights`` along dim 2.
            return (probs * weights).sum(dim=2) / (weights.sum(dim=2) + 1e-5)

        background_loss = weighted_mean(1 - object_segmaps)
        object_loss = weighted_mean(object_segmaps)
        return background_loss - object_loss
def fetch_mask_raw_image(raw_image, mask_image):
    """Return ``raw_image`` with everything outside ``mask_image`` painted black.

    The mask is resized to the size of ``raw_image`` and used to composite the
    raw image over a black RGB canvas of the same size.
    """
    target_size = raw_image.size
    resized_mask = mask_image.resize(target_size)
    black_canvas = Image.new('RGB', target_size, (0, 0, 0))
    return Image.composite(raw_image, black_canvas, resized_mask)
# Parsing-mask lookup table: one entry per mask value (the list index equals
# the "Mask Value"), giving the body-part name and its display colour.
mapping_table = [
    {"Mask Value": mask_value, "Body Part": body_part, "RGB Color": list(rgb_color)}
    for mask_value, (body_part, rgb_color) in enumerate(
        [
            ("Background", (0, 0, 0)),
            ("Face", (255, 0, 0)),
            ("Left_Eyebrow", (255, 85, 0)),
            ("Right_Eyebrow", (255, 170, 0)),
            ("Left_Eye", (255, 0, 85)),
            ("Right_Eye", (255, 0, 170)),
            ("Hair", (0, 0, 255)),
            ("Left_Ear", (85, 0, 255)),
            ("Right_Ear", (170, 0, 255)),
            ("Mouth_External Contour", (0, 255, 85)),
            ("Nose", (0, 255, 0)),
            ("Mouth_Inner_Contour", (0, 255, 170)),
            ("Upper_Lip", (85, 255, 0)),
            ("Lower_Lip", (170, 255, 0)),
            ("Neck", (0, 85, 255)),
            ("Neck_Inner Contour", (0, 170, 255)),
            ("Cloth", (255, 255, 0)),
            ("Hat", (255, 0, 255)),
            ("Earring", (255, 85, 255)),
            ("Necklace", (255, 255, 85)),
            ("Glasses", (255, 170, 255)),
            ("Hand", (255, 0, 255)),
            ("Wristband", (0, 255, 255)),
            ("Clothes_Upper", (85, 255, 255)),
            ("Clothes_Lower", (170, 255, 255)),
        ]
    )
]
def masks_for_unique_values(image_raw_mask):
    """Build one filled mask image per label value found in ``image_raw_mask``.

    For every distinct value the external contours of that value's region are
    drawn filled (255) onto a zero array. The result is stored under the
    body-part name looked up in ``mapping_table``; values without an entry
    are skipped. For value ``0`` an additional inverted mask is stored under
    ``"WithoutBackground"``.

    Returns:
        Dict mapping body-part names to ``PIL.Image`` masks.
    """
    label_map = np.array(image_raw_mask)
    masks_dict = {}
    for value in np.unique(label_map):
        region = np.uint8(label_map == value) * 255
        contours, _ = cv2.findContours(region, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        mask = np.zeros_like(label_map)
        for contour in contours:
            cv2.drawContours(mask, [contour], -1, (255), thickness=cv2.FILLED)

        if value == 0:
            # Everything that is not background: invert the background mask.
            inverted = np.where(mask == 255, 0, 255).astype(mask.dtype)
            masks_dict["WithoutBackground"] = Image.fromarray(inverted)

        body_part = next(
            (entry["Body Part"] for entry in mapping_table if entry["Mask Value"] == value),
            f"Unknown_{value}",
        )
        if not body_part.startswith("Unknown_"):
            masks_dict[body_part] = Image.fromarray(mask)
    return masks_dict
# FFN
def FeedForward(dim, mult=4):
    """Return a bias-free feed-forward stack.

    Layout: ``LayerNorm(dim) -> Linear(dim, dim * mult) -> GELU ->
    Linear(dim * mult, dim)``, where the hidden width is ``int(dim * mult)``.
    """
    hidden_dim = int(dim * mult)
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias=False),
        nn.GELU(),
        nn.Linear(hidden_dim, dim, bias=False),
    ]
    return nn.Sequential(*stages)
def reshape_tensor(x, heads):
    """Split the last axis of ``x`` into ``heads`` and move heads forward.

    ``x`` is unpacked as ``(bs, length, width)`` and the result has shape
    ``(bs, heads, length, width // heads)``.
    """
    bs, length, _ = x.shape
    head_major = x.view(bs, length, heads, -1).transpose(1, 2)
    return head_major.reshape(bs, heads, length, -1)
class PerceiverAttention(nn.Module):
    """Attention in which latents query the concatenation of features and latents."""

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)

        Returns:
            torch.Tensor: updated latents, shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)
        batch, num_latents, _ = latents.shape

        # Queries come from the latents; keys/values from features + latents.
        q = self.to_q(latents)
        k, v = self.to_kv(torch.cat((x, latents), dim=-2)).chunk(2, dim=-1)
        q, k, v = (reshape_tensor(t, self.heads) for t in (q, k, v))

        # attention: q and k are each scaled by dim_head ** -0.25, so their
        # product carries the usual dim_head ** -0.5 factor.
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)
        # Softmax is evaluated in float32 and cast back to the input dtype.
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)

        merged = (weight @ v).permute(0, 2, 1, 3).reshape(batch, num_latents, -1)
        return self.to_out(merged)
class FacePerceiverResampler(torch.nn.Module):
    """Stack of ``PerceiverAttention`` + ``FeedForward`` layers refining latents.

    The features ``x`` are projected from ``embedding_dim`` to ``dim``; the
    latents then attend to them ``depth`` times (each step with a residual
    connection) and are finally projected to ``output_dim`` and layer-normed.
    """

    def __init__(
        self,
        *,
        dim=768,
        depth=4,
        dim_head=64,
        heads=16,
        embedding_dim=1280,
        output_dim=768,
        ff_mult=4,
    ):
        super().__init__()

        self.proj_in = torch.nn.Linear(embedding_dim, dim)
        self.proj_out = torch.nn.Linear(dim, output_dim)
        self.norm_out = torch.nn.LayerNorm(output_dim)
        self.layers = torch.nn.ModuleList(
            torch.nn.ModuleList(
                [
                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                    FeedForward(dim=dim, mult=ff_mult),
                ]
            )
            for _ in range(depth)
        )

    def forward(self, latents, x):
        # Example shapes from the original author:
        # latents (2, 4, 768), x (2, 257, 1280) -> projected x (2, 257, 768).
        features = self.proj_in(x)
        for attention, feed_forward in self.layers:
            latents = attention(features, latents) + latents
            latents = feed_forward(latents) + latents
        return self.norm_out(self.proj_out(latents))
class ProjPlusModel(torch.nn.Module):
    """Project an ID embedding to tokens and refine them with CLIP features.

    ``id_embeds`` is mapped by a two-layer MLP to ``num_tokens`` vectors of
    size ``cross_attention_dim``, layer-normed, and used as latents of a
    ``FacePerceiverResampler`` that attends to ``clip_embeds``.
    """

    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        hidden_dim = id_embeddings_dim * 2
        self.proj = torch.nn.Sequential(
            torch.nn.Linear(id_embeddings_dim, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_dim, cross_attention_dim * num_tokens),
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

        self.perceiver_resampler = FacePerceiverResampler(
            dim=cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=cross_attention_dim // 64,
            embedding_dim=clip_embeddings_dim,
            output_dim=cross_attention_dim,
            ff_mult=4,
        )

    def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0):
        """Return the refined tokens; with ``shortcut`` add ``scale * tokens``."""
        tokens = self.proj(id_embeds).reshape(-1, self.num_tokens, self.cross_attention_dim)
        tokens = self.norm(tokens)
        out = self.perceiver_resampler(tokens, clip_embeds)
        if not shortcut:
            return out
        return scale * tokens + out
class AttentionMLP(nn.Module):
    """Perceiver-style resampler with learned latents.

    ``single_num_tokens`` learned latent vectors attend to the projected input
    sequence through ``depth`` ``PerceiverAttention`` + ``FeedForward`` layers
    (each with a residual connection) and are projected to ``output_dim``.

    Optional extras:
      * ``apply_pos_emb``: add a learned positional embedding to the input.
      * ``num_latents_mean_pooled > 0``: prepend that many latents derived from
        the mean-pooled projected input.

    ``dtype`` is accepted but not used by the code in this class.
    """

    def __init__(
        self,
        dtype=torch.float16,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        single_num_tokens=1,
        embedding_dim=1280,
        output_dim=768,
        ff_mult=4,
        max_seq_len: int = 257*2,
        apply_pos_emb: bool = False,
        num_latents_mean_pooled: int = 0,
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None

        self.single_num_tokens = single_num_tokens
        self.latents = nn.Parameter(torch.randn(1, self.single_num_tokens, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
            )
        else:
            self.to_latents_from_mean_pooled_seq = None

        self.layers = nn.ModuleList(
            nn.ModuleList(
                [
                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                    FeedForward(dim=dim, mult=ff_mult),
                ]
            )
            for _ in range(depth)
        )

    def forward(self, x):
        if self.pos_emb is not None:
            seq_len = x.shape[1]
            x = x + self.pos_emb(torch.arange(seq_len, device=x.device))

        # Example from the original author: x is (5, 257, 1280) here and
        # (5, 257, 1024) after proj_in.
        latents = self.latents.repeat(x.size(0), 1, 1)
        x = self.proj_in(x)

        if self.to_latents_from_mean_pooled_seq is not None:
            full_mask = torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)
            meanpooled_seq = masked_mean(x, dim=1, mask=full_mask)
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim=-2)

        for attention, feed_forward in self.layers:
            latents = attention(x, latents) + latents
            latents = feed_forward(latents) + latents

        return self.norm_out(self.proj_out(latents))
def masked_mean(t, *, dim, mask=None):
    """Mean of ``t`` along ``dim``, counting only positions where ``mask`` is set.

    Without a mask this is a plain ``t.mean(dim=dim)``. With a mask (pattern
    ``"b n"``), masked-out positions are zeroed before summing and the sum is
    divided by the number of set mask entries, clamped to at least ``1e-5``.
    """
    if mask is None:
        return t.mean(dim=dim)

    denom = mask.sum(dim=dim, keepdim=True)
    # Add a trailing axis so the mask broadcasts over the feature dimension.
    keep = rearrange(mask, "b n -> b n 1")
    total = t.masked_fill(~keep, 0.0).sum(dim=dim)
    return total / denom.clamp(min=1e-5)