diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..b36ffd329d9a393b52d2c799f165567d9671247e 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +asserts/example_images/1.png filter=lfs diff=lfs merge=lfs -text diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..d2de2966a96c11c869dc8f2abfdf470362108f12 --- /dev/null +++ b/app.py @@ -0,0 +1,369 @@ +import os +import math +import time +import numpy +import random +import threading +import gradio as gr +from PIL import Image, ImageOps +from moviepy import VideoFileClip +from datetime import datetime, timedelta +from huggingface_hub import hf_hub_download, snapshot_download + +import insightface +from insightface.app import FaceAnalysis +from facexlib.parsing import init_parsing_model +from facexlib.utils.face_restoration_helper import FaceRestoreHelper + +import torch +from diffusers import CogVideoXDPMScheduler +from diffusers.utils import load_image +from diffusers.image_processor import VaeImageProcessor +from diffusers.training_utils import free_memory + +from util.utils import * +from util.rife_model import load_rife_model, rife_inference_with_latents +from models.utils import process_face_embeddings +from models.transformer_consisid import ConsisIDTransformer3DModel +from models.pipeline_consisid import ConsisIDPipeline +from models.eva_clip import create_model_and_transforms +from models.eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from models.eva_clip.utils_qformer import resize_numpy_image_long + +device = "cuda" if torch.cuda.is_available() else "cpu" + +hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran") +snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife") + +model_path = "BestWishYsh/ConsisID-preview" +lora_path = None +lora_rank = 128 +dtype = torch.bfloat16 + +if os.path.exists(os.path.join(model_path, "transformer_ema")): + subfolder = "transformer_ema" +else: + subfolder = "transformer" + +transformer = ConsisIDTransformer3DModel.from_pretrained_cus(model_path, subfolder=subfolder) +scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler") + +try: + is_kps = transformer.config.is_kps +except: + is_kps = False + +# 1. load face helper models +face_helper = FaceRestoreHelper( + upscale_factor=1, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + device=device, + model_rootpath=os.path.join(model_path, "face_encoder") +) +face_helper.face_parse = None +face_helper.face_parse = init_parsing_model(model_name='bisenet', device=device, model_rootpath=os.path.join(model_path, "face_encoder")) +face_helper.face_det.eval() +face_helper.face_parse.eval() + +model, _, _ = create_model_and_transforms('EVA02-CLIP-L-14-336', os.path.join(model_path, "face_encoder", "EVA02_CLIP_L_336_psz14_s6B.pt"), force_custom_clip=True) +face_clip_model = model.visual +face_clip_model.eval() + +eva_transform_mean = getattr(face_clip_model, 'image_mean', OPENAI_DATASET_MEAN) +eva_transform_std = getattr(face_clip_model, 'image_std', OPENAI_DATASET_STD) +if not isinstance(eva_transform_mean, (list, tuple)): + eva_transform_mean = (eva_transform_mean,) * 3 +if not isinstance(eva_transform_std, (list, tuple)): + eva_transform_std = (eva_transform_std,) * 3 +eva_transform_mean = eva_transform_mean +eva_transform_std = eva_transform_std + +face_main_model = FaceAnalysis(name='antelopev2', root=os.path.join(model_path, "face_encoder"), providers=['CUDAExecutionProvider']) +handler_ante = insightface.model_zoo.get_model(f'{model_path}/face_encoder/models/antelopev2/glintr100.onnx', providers=['CUDAExecutionProvider']) +face_main_model.prepare(ctx_id=0, det_size=(640, 640)) +handler_ante.prepare(ctx_id=0) + +face_clip_model.to(device, dtype=dtype) +face_helper.face_det.to(device) +face_helper.face_parse.to(device) +transformer.to(device, dtype=dtype) +free_memory() + +pipe = ConsisIDPipeline.from_pretrained(model_path, transformer=transformer, scheduler=scheduler, torch_dtype=dtype) +# If you're using with lora, add this code +if lora_path: + pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1") + pipe.fuse_lora(lora_scale=1 / lora_rank) + +scheduler_args = {} +if "variance_type" in pipe.scheduler.config: + variance_type = pipe.scheduler.config.variance_type + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + scheduler_args["variance_type"] = variance_type + +pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args) +pipe.to(device) + +os.makedirs("./output", exist_ok=True) +os.makedirs("./gradio_tmp", exist_ok=True) + +upscale_model = load_sd_upscale("model_real_esran/RealESRGAN_x4.pth", device) +frame_interpolation_model = load_rife_model("model_rife") + + +def infer( + prompt: str, + image_input: str, + num_inference_steps: int, + guidance_scale: float, + seed: int = 42, + progress=gr.Progress(track_tqdm=True), +): + if seed == -1: + seed = random.randint(0, 2**8 - 1) + + id_image = np.array(ImageOps.exif_transpose(Image.fromarray(image_input)).convert("RGB")) + id_image = resize_numpy_image_long(id_image, 1024) + id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings(face_helper, face_clip_model, handler_ante, + eva_transform_mean, eva_transform_std, + face_main_model, device, dtype, id_image, + original_id_image=id_image, is_align_face=True, + cal_uncond=False) + + if is_kps: + kps_cond = face_kps + else: + kps_cond = None + + tensor = align_crop_face_image.cpu().detach() + tensor = tensor.squeeze() + tensor = tensor.permute(1, 2, 0) + tensor = tensor.numpy() * 255 + tensor = tensor.astype(np.uint8) + image = ImageOps.exif_transpose(Image.fromarray(tensor)) + + prompt = prompt.strip('"') + + generator = torch.Generator(device).manual_seed(seed) if seed else None + + video_pt = pipe( + prompt=prompt, + image=image, + num_videos_per_prompt=1, + num_inference_steps=num_inference_steps, + num_frames=49, + use_dynamic_cfg=False, + guidance_scale=guidance_scale, + generator=generator, + id_vit_hidden=id_vit_hidden, + id_cond=id_cond, + kps_cond=kps_cond, + output_type="pt", + ).frames + + free_memory() + return (video_pt, seed) + + +def convert_to_gif(video_path): + clip = VideoFileClip(video_path) + gif_path = video_path.replace(".mp4", ".gif") + clip.write_gif(gif_path, fps=8) + return gif_path + + +def delete_old_files(): + while True: + now = datetime.now() + cutoff = now - timedelta(minutes=10) + directories = ["./output", "./gradio_tmp"] + + for directory in directories: + for filename in os.listdir(directory): + file_path = os.path.join(directory, filename) + if os.path.isfile(file_path): + file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path)) + if file_mtime < cutoff: + os.remove(file_path) + time.sleep(600) + + +threading.Thread(target=delete_old_files, daemon=True).start() +examples_images = [ + ["asserts/example_images/1.png", "A woman adorned with a delicate flower crown, is standing amidst a field of gently swaying wildflowers. Her eyes sparkle with a serene gaze, and a faint smile graces her lips, suggesting a moment of peaceful contentment. The shot is framed from the waist up, highlighting the gentle breeze lightly tousling her hair. The background reveals an expansive meadow under a bright blue sky, capturing the tranquility of a sunny afternoon."], + ["asserts/example_images/2.png", "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel."], + ["asserts/example_images/3.png", "The video depicts a man sitting at an office desk, engaged in his work. He is dressed in a formal suit and appears to be focused on his computer screen. The office environment is well-organized, with shelves filled with binders and other office supplies neatly arranged. The man is holding a red cup, possibly containing a beverage, which he drinks from before setting it down on the desk. He then proceeds to type on the keyboard, indicating that he is working on something on his computer. The overall atmosphere of the video suggests a professional setting where the man is diligently working on his tasks."] +] + +with gr.Blocks() as demo: + gr.Markdown(""" +
+ ConsisID Space🤗 +
+
+ 🤗 Model Hub | + 📚 Dataset | + 🌐 Github | + 📝 Page | + 📜 arxiv +
+
+ If the Space is too busy, duplicate it to use privately + +
+
+ ⚠️ This demo is for academic research and experiential use only. +
+ """) + with gr.Row(): + with gr.Column(): + with gr.Accordion("IPT2V: Face Input", open=True): + image_input = gr.Image(label="Input Image (should contain clear face)") + prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5) + with gr.Accordion("Examples", open=False): + examples_component_images = gr.Examples( + examples_images, + inputs=[image_input, prompt], + cache_examples=False, + ) + + with gr.Group(): + with gr.Column(): + with gr.Row(): + seed_param = gr.Number( + label="Inference Seed (Enter a positive number, -1 for random)", value=42 + ) + with gr.Row(): + enable_scale = gr.Checkbox(label="Super-Resolution (720 × 480 -> 2880 × 1920)", value=False) + enable_rife = gr.Checkbox(label="Frame Interpolation (8fps -> 16fps)", value=False) + gr.Markdown( + "✨In this demo, we use [RIFE](https://github.com/hzwer/ECCV2022-RIFE) for frame interpolation and [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) for upscaling(Super-Resolution)." + ) + + generate_button = gr.Button("🎬 Generate Video") + + with gr.Column(): + video_output = gr.Video(label="ConsisID Generate Video", width=720, height=480) + with gr.Row(): + download_video_button = gr.File(label="📥 Download Video", visible=False) + download_gif_button = gr.File(label="📥 Download GIF", visible=False) + seed_text = gr.Number(label="Seed Used for Video Generation", visible=False) + + gr.Markdown(""" + +
+ 🎥 Video Gallery +
+ + + + + + + + + + + + + + + + + + + + + + + + +
+

The video features a woman in exquisite hybrid armor adorned with iridescent gemstones, standing amidst gently falling cherry blossoms. Her piercing yet serene gaze hints at quiet determination, as a breeze catches a loose strand of her hair. She stands in a tranquil courtyard framed by moss-covered stone walls and wooden arches, with blossoms casting soft shadows on the ground. The petals swirl around her, adding a dreamlike quality, while the blurred backdrop emphasizes her poised figure. The scene conveys elegance, strength, and tranquil readiness, capturing a moment of peace before an upcoming challenge.

+
+ + +

The video features a baby wearing a bright superhero cape, standing confidently with arms raised in a powerful pose. The baby has a determined look on their face, with eyes wide and lips pursed in concentration, as if ready to take on a challenge. The setting appears playful, with colorful toys scattered around and a soft rug underfoot, while sunlight streams through a nearby window, highlighting the fluttering cape and adding to the impression of heroism. The overall atmosphere is lighthearted and fun, with the baby's expressions capturing a mix of innocence and an adorable attempt at bravery, as if truly ready to save the day.

+
+ +
+

The video features a man standing next to an airplane, engaged in a conversation on his cell phone. he is wearing sunglasses and a black top, and he appears to be talking seriously. The airplane has a green stripe running along its side, and there is a large engine visible behind his. The man seems to be standing near the entrance of the airplane, possibly preparing to board or just having disembarked. The setting suggests that he might be at an airport or a private airfield. The overall atmosphere of the video is professional and focused, with the man's attire and the presence of the airplane indicating a business or travel context.

+
+ + +

The video features a woman with blonde hair standing on a beach near the water's edge. She is wearing a black swimsuit and appears to be enjoying her time by the sea. The sky above is clear with some clouds, and the ocean waves gently lap against the shore. The woman seems to be holding something white in her hand, possibly a piece of driftwood or a small object found on the beach. The overall atmosphere of the video is serene and relaxing, capturing the beauty of nature and the simple pleasure of being by the ocean.

+
+ +
+

The video features a man sitting in a red armchair, enjoying a cup of coffee or tea. he is dressed in a light-colored outfit and has long dark-haired hair. The setting appears to be indoors, with large windows providing a view of a misty or foggy coastal landscape outside. The room has a modern design with geometric structures visible in the background. There is a small round table next to the armchair, also holding a cup. The overall atmosphere suggests a calm and serene moment, possibly during a cold or rainy day by the sea.

+
+ + +

The video shows a young boy sitting at a table, eating a piece of food. He appears to be enjoying his meal, as he takes a bite and chews it. The boy is wearing a blue shirt and has short hair. The background is dark, with some light coming from the left side of the frame. There is a straw visible on the right side of the frame, suggesting that there may be a drink next to the boy's plate. The overall atmosphere of the video seems casual and relaxed, with the focus on the boy's enjoyment of his food.

+
+ +
+

The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel.

+
+ + +

The video features a young man standing outdoors in a snowy park. he is wearing a colorful winter jacket with a floral pattern and a white knit hat. The background shows a snowy landscape with trees, benches, and a metal fence. The ground is covered in snow, and there is a light snowfall in the air. The man appears to be enjoying the winter weather, as he smiles and gives a thumbs-up gesture towards the camera. The overall atmosphere of the video is cheerful and festive, capturing the beauty of a snowy day in a park.

+
+ +
+ """) + + def generate( + prompt, + image_input, + seed_value, + scale_status, + rife_status, + progress=gr.Progress(track_tqdm=True) + ): + latents, seed = infer( + prompt, + image_input, + num_inference_steps=50, + guidance_scale=7.0, + seed=seed_value, + progress=progress, + ) + if scale_status: + latents = upscale_batch_and_concatenate(upscale_model, latents, device) + if rife_status: + latents = rife_inference_with_latents(frame_interpolation_model, latents) + + batch_size = latents.shape[0] + batch_video_frames = [] + for batch_idx in range(batch_size): + pt_image = latents[batch_idx] + pt_image = torch.stack([pt_image[i] for i in range(pt_image.shape[0])]) + + image_np = VaeImageProcessor.pt_to_numpy(pt_image) + image_pil = VaeImageProcessor.numpy_to_pil(image_np) + batch_video_frames.append(image_pil) + + video_path = save_video(batch_video_frames[0], fps=math.ceil((len(batch_video_frames[0]) - 1) / 6)) + video_update = gr.update(visible=True, value=video_path) + gif_path = convert_to_gif(video_path) + gif_update = gr.update(visible=True, value=gif_path) + seed_update = gr.update(visible=True, value=seed) + + return video_path, video_update, gif_update, seed_update + + generate_button.click( + generate, + inputs=[prompt, image_input, seed_param, enable_scale, enable_rife], + outputs=[video_output, download_video_button, download_gif_button, seed_text], + ) + +if __name__ == "__main__": + demo.queue(max_size=15) + demo.launch() diff --git a/asserts/example_images/1.png b/asserts/example_images/1.png new file mode 100644 index 0000000000000000000000000000000000000000..846825db715fb8ad83e168ce76d5d032ebf548aa --- /dev/null +++ b/asserts/example_images/1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:434856739faadf1c89bc38d8b940fcdbe027595de89645f721d8585fc2fe2459 +size 1639385 diff --git a/asserts/example_images/2.png b/asserts/example_images/2.png new file mode 100644 index 0000000000000000000000000000000000000000..cf68b72128b70cf448e8a589e5e78f0ef10c3efc Binary files /dev/null and b/asserts/example_images/2.png differ diff --git a/asserts/example_images/3.png b/asserts/example_images/3.png new file mode 100644 index 0000000000000000000000000000000000000000..422c567fa537359740343fa78b83d25b7b357d32 Binary files /dev/null and b/asserts/example_images/3.png differ diff --git a/models/eva_clip/__init__.py b/models/eva_clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fa2d014bbfe644b1e247758116bbf1b184738fe5 --- /dev/null +++ b/models/eva_clip/__init__.py @@ -0,0 +1,11 @@ +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_transforms +from .factory import list_models, add_model_config, get_model_config, load_checkpoint +from .loss import ClipLoss +from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\ + convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype +from .openai import load_openai_model, list_openai_models +from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\ + get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained +from .tokenizer import SimpleTokenizer, tokenize +from .transform import image_transform \ No newline at end of file diff --git a/models/eva_clip/bpe_simple_vocab_16e6.txt.gz b/models/eva_clip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/models/eva_clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/models/eva_clip/constants.py b/models/eva_clip/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..a670bb3fab442baeb9af53b91c312e6982af57ee --- /dev/null +++ b/models/eva_clip/constants.py @@ -0,0 +1,2 @@ +OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) diff --git a/models/eva_clip/eva_vit_model.py b/models/eva_clip/eva_vit_model.py new file mode 100644 index 0000000000000000000000000000000000000000..51db88cf0c7b5d7a43f2be80bc59abb6c859c4b4 --- /dev/null +++ b/models/eva_clip/eva_vit_model.py @@ -0,0 +1,548 @@ +# -------------------------------------------------------- +# Adapted from https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- +import math +import os +from functools import partial +import torch +import torch.nn as nn +import torch.nn.functional as F +try: + from timm.models.layers import drop_path, to_2tuple, trunc_normal_ +except: + from timm.layers import drop_path, to_2tuple, trunc_normal_ + +from .transformer import PatchDropout +from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast + +if os.getenv('ENV_TYPE') == 'deepspeed': + try: + from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint + except: + from torch.utils.checkpoint import checkpoint +else: + from torch.utils.checkpoint import checkpoint + +try: + import xformers + import xformers.ops as xops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 'p={}'.format(self.drop_prob) + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + drop=0., + subln=False, + + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + + self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity() + + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + # x = self.drop(x) + # commit this for the orignal BERT implement + x = self.ffn_ln(x) + + x = self.fc2(x) + x = self.drop(x) + return x + +class SwiGLU(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0., + norm_layer=nn.LayerNorm, subln=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.w1 = nn.Linear(in_features, hidden_features) + self.w2 = nn.Linear(in_features, hidden_features) + + self.act = act_layer() + self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity() + self.w3 = nn.Linear(hidden_features, out_features) + + self.drop = nn.Dropout(drop) + + def forward(self, x): + x1 = self.w1(x) + x2 = self.w2(x) + hidden = self.act(x1) * x2 + x = self.ffn_ln(hidden) + x = self.w3(x) + x = self.drop(x) + return x + +class Attention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.subln = subln + if self.subln: + self.q_proj = nn.Linear(dim, all_head_dim, bias=False) + self.k_proj = nn.Linear(dim, all_head_dim, bias=False) + self.v_proj = nn.Linear(dim, all_head_dim, bias=False) + else: + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.v_bias = None + + if window_size: + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + else: + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity() + # self.proj = nn.Linear(all_head_dim, all_head_dim) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.xattn = xattn + self.xattn_drop = attn_drop + + self.rope = rope + + def forward(self, x, rel_pos_bias=None, attn_mask=None): + B, N, C = x.shape + if self.subln: + q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias) + k = F.linear(input=x, weight=self.k_proj.weight, bias=None) + v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias) + + q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C + k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) + v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) + else: + + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C + q, k, v = qkv[0], qkv[1], qkv[2] + + if self.rope: + # slightly fast impl + q_t = q[:, :, 1:, :] + ro_q_t = self.rope(q_t) + q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v) + + k_t = k[:, :, 1:, :] + ro_k_t = self.rope(k_t) + k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v) + + if self.xattn: + q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + x = xops.memory_efficient_attention( + q, k, v, + p=self.xattn_drop, + scale=self.scale, + ) + x = x.reshape(B, N, -1) + x = self.inner_attn_ln(x) + x = self.proj(x) + x = self.proj_drop(x) + else: + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0).type_as(attn) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias.type_as(attn) + + if attn_mask is not None: + attn_mask = attn_mask.bool() + attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf")) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.inner_attn_ln(x) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, + window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False, + subln=False, naiveswiglu=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim, + xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + if naiveswiglu: + self.mlp = SwiGLU( + in_features=dim, + hidden_features=mlp_hidden_dim, + subln=subln, + norm_layer=norm_layer, + ) + else: + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + subln=subln, + drop=drop + ) + + if init_values is not None and init_values > 0: + self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + self.postnorm = postnorm + + def forward(self, x, rel_pos_bias=None, attn_mask=None): + if self.gamma_1 is None: + if self.postnorm: + x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))) + x = x + self.drop_path(self.norm2(self.mlp(x))) + else: + x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + if self.postnorm: + x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))) + x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class RelativePositionBias(nn.Module): + + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + def forward(self): + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +class EVAVisionTransformer(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0., + use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False, + use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False, + pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False): + super().__init__() + + if not XFORMERS_IS_AVAILBLE: + xattn = False + + self.image_size = img_size + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + if use_abs_pos_emb: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + else: + self.pos_embed = None + self.pos_drop = nn.Dropout(p=drop_rate) + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) + else: + self.rel_pos_bias = None + + if rope: + half_head_dim = embed_dim // num_heads // 2 + hw_seq_len = img_size // patch_size + self.rope = VisionRotaryEmbeddingFast( + dim=half_head_dim, + pt_seq_len=pt_hw_seq_len, + ft_seq_len=hw_seq_len if intp_freq else None, + # patch_dropout=patch_dropout + ) + else: + self.rope = None + + self.naiveswiglu = naiveswiglu + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.use_rel_pos_bias = use_rel_pos_bias + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None, + xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu) + for i in range(depth)]) + self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) + self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + + trunc_normal_(self.cls_token, std=.02) + # trunc_normal_(self.mask_token, std=.02) + + self.apply(self._init_weights) + self.fix_init_weight() + + if isinstance(self.head, nn.Linear): + trunc_normal_(self.head.weight, std=.02) + self.head.weight.data.mul_(init_scale) + self.head.bias.data.mul_(init_scale) + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() + + self.grad_checkpointing = grad_checkpointing + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + if self.naiveswiglu: + rescale(layer.mlp.w3.weight.data, layer_id + 1) + else: + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def get_cast_dtype(self) -> torch.dtype: + return self.blocks[0].mlp.fc2.weight.dtype + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_num_layers(self): + return len(self.blocks) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, 'partial locking not currently supported for this model' + for param in self.parameters(): + param.requires_grad = False + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x, return_all_features=False, return_hidden=False, shuffle=False): + + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + if shuffle: + idx = torch.randperm(x.shape[1]) + 1 + zero = torch.LongTensor([0, ]) + idx = torch.cat([zero, idx]) + pos_embed = self.pos_embed[:, idx] + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if shuffle: + x = x + pos_embed + elif self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in + if os.getenv('RoPE') == '1': + if self.training and not isinstance(self.patch_dropout, nn.Identity): + x, patch_indices_keep = self.patch_dropout(x) + self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep) + else: + self.rope.forward = partial(self.rope.forward, patch_indices_keep=None) + x = self.patch_dropout(x) + else: + x = self.patch_dropout(x) + + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + hidden_states = [] + for idx, blk in enumerate(self.blocks): + if (0 < idx <= 20) and (idx % 4 == 0) and return_hidden: + hidden_states.append(x) + if self.grad_checkpointing: + x = checkpoint(blk, x, (rel_pos_bias,)) + else: + x = blk(x, rel_pos_bias=rel_pos_bias) + + if not return_all_features: + x = self.norm(x) + if self.fc_norm is not None: + return self.fc_norm(x.mean(1)), hidden_states + else: + return x[:, 0], hidden_states + return x + + def forward(self, x, return_all_features=False, return_hidden=False, shuffle=False): + if return_all_features: + return self.forward_features(x, return_all_features, return_hidden, shuffle) + x, hidden_states = self.forward_features(x, return_all_features, return_hidden, shuffle) + x = self.head(x) + if return_hidden: + return x, hidden_states + return x diff --git a/models/eva_clip/factory.py b/models/eva_clip/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..ced8999997bf374b69f846bc73ea635fe8a6eb63 --- /dev/null +++ b/models/eva_clip/factory.py @@ -0,0 +1,517 @@ +import json +import logging +import os +import pathlib +import re +from copy import deepcopy +from pathlib import Path +from typing import Optional, Tuple, Union, Dict, Any +import torch + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .model import CLIP, CustomCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ + get_cast_dtype +from .openai import load_openai_model +from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model +from .transform import image_transform +from .tokenizer import HFTokenizer, tokenize +from .utils import resize_clip_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed, resize_eva_pos_embed + + +_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] +_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def _rescan_model_configs(): + global _MODEL_CONFIGS + + config_ext = ('.json',) + config_files = [] + for config_path in _MODEL_CONFIG_PATHS: + if config_path.is_file() and config_path.suffix in config_ext: + config_files.append(config_path) + elif config_path.is_dir(): + for ext in config_ext: + config_files.extend(config_path.glob(f'*{ext}')) + + for cf in config_files: + with open(cf, "r", encoding="utf8") as f: + model_cfg = json.load(f) + if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): + _MODEL_CONFIGS[cf.stem] = model_cfg + + _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))) + + +_rescan_model_configs() # initial populate of model config registry + + +def list_models(): + """ enumerate available model architectures based on config files """ + return list(_MODEL_CONFIGS.keys()) + + +def add_model_config(path): + """ add model config path or file and update registry """ + if not isinstance(path, Path): + path = Path(path) + _MODEL_CONFIG_PATHS.append(path) + _rescan_model_configs() + + +def get_model_config(model_name): + if model_name in _MODEL_CONFIGS: + return deepcopy(_MODEL_CONFIGS[model_name]) + else: + return None + + +def get_tokenizer(model_name): + config = get_model_config(model_name) + tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize + return tokenizer + + +# loading openai CLIP weights when is_openai=True for training +def load_state_dict(checkpoint_path: str, map_location: str='cpu', model_key: str='model|module|state_dict', is_openai: bool=False, skip_list: list=[]): + if is_openai: + model = torch.jit.load(checkpoint_path, map_location="cpu").eval() + state_dict = model.state_dict() + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + else: + checkpoint = torch.load(checkpoint_path, map_location=map_location) + for mk in model_key.split('|'): + if isinstance(checkpoint, dict) and mk in checkpoint: + state_dict = checkpoint[mk] + break + else: + state_dict = checkpoint + if next(iter(state_dict.items()))[0].startswith('module'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + for k in skip_list: + if k in list(state_dict.keys()): + logging.info(f"Removing key {k} from pretrained checkpoint") + del state_dict[k] + + if os.getenv('RoPE') == '1': + for k in list(state_dict.keys()): + if 'freqs_cos' in k or 'freqs_sin' in k: + del state_dict[k] + return state_dict + + + +def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True): + state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False) + # detect old format and make compatible with new format + if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): + state_dict = convert_to_custom_text_state_dict(state_dict) + if 'text.logit_scale' in state_dict and hasattr(model, 'logit_scale'): + state_dict['logit_scale'] = state_dict['text.logit_scale'] + del state_dict['text.logit_scale'] + + # resize_clip_pos_embed for CLIP and open CLIP + if 'visual.positional_embedding' in state_dict: + resize_clip_pos_embed(state_dict, model) + # specified to eva_vit_model + elif 'visual.pos_embed' in state_dict: + resize_evaclip_pos_embed(state_dict, model) + + # resize_clip_pos_embed(state_dict, model) + incompatible_keys = model.load_state_dict(state_dict, strict=strict) + logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}") + return incompatible_keys + +def load_clip_visual_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]): + state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list) + + for k in list(state_dict.keys()): + if not k.startswith('visual.'): + del state_dict[k] + for k in list(state_dict.keys()): + if k.startswith('visual.'): + new_k = k[7:] + state_dict[new_k] = state_dict[k] + del state_dict[k] + return state_dict + +def load_clip_text_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]): + state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list) + + for k in list(state_dict.keys()): + if k.startswith('visual.'): + del state_dict[k] + return state_dict + +def get_pretrained_tag(pretrained_model): + pretrained_model = pretrained_model.lower() + if "laion" in pretrained_model or "open_clip" in pretrained_model: + return "open_clip" + elif "openai" in pretrained_model: + return "clip" + elif "eva" in pretrained_model and "clip" in pretrained_model: + return "eva_clip" + else: + return "other" + +def load_pretrained_checkpoint( + model, + visual_checkpoint_path, + text_checkpoint_path, + strict=True, + visual_model=None, + text_model=None, + model_key="model|module|state_dict", + skip_list=[]): + visual_tag = get_pretrained_tag(visual_model) + text_tag = get_pretrained_tag(text_model) + + logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}") + visual_incompatible_keys, text_incompatible_keys = None, None + if visual_checkpoint_path: + if visual_tag == "eva_clip" or visual_tag == "open_clip": + visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=False, skip_list=skip_list) + elif visual_tag == "clip": + visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list) + else: + visual_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list) + + # resize_clip_pos_embed for CLIP and open CLIP + if 'positional_embedding' in visual_state_dict: + resize_visual_pos_embed(visual_state_dict, model) + # specified to EVA model + elif 'pos_embed' in visual_state_dict: + resize_eva_pos_embed(visual_state_dict, model) + + visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict) + logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}") + logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}") + + if text_checkpoint_path: + if text_tag == "eva_clip" or text_tag == "open_clip": + text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list) + elif text_tag == "clip": + text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list) + else: + text_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list) + + text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict) + + logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}") + logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}") + + return visual_incompatible_keys, text_incompatible_keys + +def create_model( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_clip: bool = False, + force_patch_dropout: Optional[float] = None, + pretrained_image: str = '', + pretrained_text: str = '', + pretrained_hf: bool = True, + pretrained_visual_model: str = None, + pretrained_text_model: str = None, + cache_dir: Optional[str] = None, + skip_list: list = [], +): + model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names + if isinstance(device, str): + device = torch.device(device) + + if pretrained and pretrained.lower() == 'openai': + logging.info(f'Loading pretrained {model_name} from OpenAI.') + model = load_openai_model( + model_name, + precision=precision, + device=device, + jit=jit, + cache_dir=cache_dir, + ) + else: + model_cfg = get_model_config(model_name) + if model_cfg is not None: + logging.info(f'Loaded {model_name} model config.') + else: + logging.error(f'Model config for {model_name} not found; available models {list_models()}.') + raise RuntimeError(f'Model config for {model_name} not found.') + + if 'rope' in model_cfg.get('vision_cfg', {}): + if model_cfg['vision_cfg']['rope']: + os.environ['RoPE'] = "1" + else: + os.environ['RoPE'] = "0" + + if force_quick_gelu: + # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + if force_patch_dropout is not None: + # override the default patch dropout value + model_cfg['vision_cfg']["patch_dropout"] = force_patch_dropout + + cast_dtype = get_cast_dtype(precision) + custom_clip = model_cfg.pop('custom_text', False) or force_custom_clip or ('hf_model_name' in model_cfg['text_cfg']) + + + if custom_clip: + if 'hf_model_name' in model_cfg.get('text_cfg', {}): + model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf + model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype) + else: + model = CLIP(**model_cfg, cast_dtype=cast_dtype) + + pretrained_cfg = {} + if pretrained: + checkpoint_path = '' + pretrained_cfg = get_pretrained_cfg(model_name, pretrained) + if pretrained_cfg: + checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) + elif os.path.exists(pretrained): + checkpoint_path = pretrained + + if checkpoint_path: + logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') + load_checkpoint(model, + checkpoint_path, + model_key="model|module|state_dict", + strict=False + ) + else: + error_str = ( + f'Pretrained weights ({pretrained}) not found for model {model_name}.' + f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') + logging.warning(error_str) + raise RuntimeError(error_str) + else: + visual_checkpoint_path = '' + text_checkpoint_path = '' + + if pretrained_image: + pretrained_visual_model = pretrained_visual_model.replace('/', '-') # for callers using old naming with / in ViT names + pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image) + if 'timm_model_name' in model_cfg.get('vision_cfg', {}): + # pretrained weight loading for timm models set via vision_cfg + model_cfg['vision_cfg']['timm_model_pretrained'] = True + elif pretrained_image_cfg: + visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir) + elif os.path.exists(pretrained_image): + visual_checkpoint_path = pretrained_image + else: + logging.warning(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.') + raise RuntimeError(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.') + + if pretrained_text: + pretrained_text_model = pretrained_text_model.replace('/', '-') # for callers using old naming with / in ViT names + pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text) + if pretrained_image_cfg: + text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir) + elif os.path.exists(pretrained_text): + text_checkpoint_path = pretrained_text + else: + logging.warning(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.') + raise RuntimeError(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.') + + if visual_checkpoint_path: + logging.info(f'Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).') + if text_checkpoint_path: + logging.info(f'Loading pretrained {model_name}.text weights ({text_checkpoint_path}).') + + if visual_checkpoint_path or text_checkpoint_path: + load_pretrained_checkpoint( + model, + visual_checkpoint_path, + text_checkpoint_path, + strict=False, + visual_model=pretrained_visual_model, + text_model=pretrained_text_model, + model_key="model|module|state_dict", + skip_list=skip_list + ) + + if "fp16" in precision or "bf16" in precision: + logging.info(f'convert precision to {precision}') + model = model.to(torch.bfloat16) if 'bf16' in precision else model.to(torch.float16) + + model.to(device=device) + + # set image / mean metadata from pretrained_cfg if available, or use default + model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN + model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD + + if jit: + model = torch.jit.script(model) + + return model + + +def create_model_and_transforms( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_clip: bool = False, + force_patch_dropout: Optional[float] = None, + pretrained_image: str = '', + pretrained_text: str = '', + pretrained_hf: bool = True, + pretrained_visual_model: str = None, + pretrained_text_model: str = None, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + cache_dir: Optional[str] = None, + skip_list: list = [], +): + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_clip=force_custom_clip, + force_patch_dropout=force_patch_dropout, + pretrained_image=pretrained_image, + pretrained_text=pretrained_text, + pretrained_hf=pretrained_hf, + pretrained_visual_model=pretrained_visual_model, + pretrained_text_model=pretrained_text_model, + cache_dir=cache_dir, + skip_list=skip_list, + ) + + image_mean = image_mean or getattr(model.visual, 'image_mean', None) + image_std = image_std or getattr(model.visual, 'image_std', None) + preprocess_train = image_transform( + model.visual.image_size, + is_train=True, + mean=image_mean, + std=image_std + ) + preprocess_val = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std + ) + + return model, preprocess_train, preprocess_val + + +def create_transforms( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_clip: bool = False, + force_patch_dropout: Optional[float] = None, + pretrained_image: str = '', + pretrained_text: str = '', + pretrained_hf: bool = True, + pretrained_visual_model: str = None, + pretrained_text_model: str = None, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + cache_dir: Optional[str] = None, + skip_list: list = [], +): + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_clip=force_custom_clip, + force_patch_dropout=force_patch_dropout, + pretrained_image=pretrained_image, + pretrained_text=pretrained_text, + pretrained_hf=pretrained_hf, + pretrained_visual_model=pretrained_visual_model, + pretrained_text_model=pretrained_text_model, + cache_dir=cache_dir, + skip_list=skip_list, + ) + + + image_mean = image_mean or getattr(model.visual, 'image_mean', None) + image_std = image_std or getattr(model.visual, 'image_std', None) + preprocess_train = image_transform( + model.visual.image_size, + is_train=True, + mean=image_mean, + std=image_std + ) + preprocess_val = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std + ) + del model + + return preprocess_train, preprocess_val + +def create_model_from_pretrained( + model_name: str, + pretrained: str, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_clip: bool = False, + force_patch_dropout: Optional[float] = None, + return_transform: bool = True, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + cache_dir: Optional[str] = None, + is_frozen: bool = False, +): + if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained): + raise RuntimeError( + f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.' + f' Use open_clip.list_pretrained() to find one.') + + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_clip=force_custom_clip, + force_patch_dropout=force_patch_dropout, + cache_dir=cache_dir, + ) + + if is_frozen: + for param in model.parameters(): + param.requires_grad = False + + if not return_transform: + return model + + image_mean = image_mean or getattr(model.visual, 'image_mean', None) + image_std = image_std or getattr(model.visual, 'image_std', None) + preprocess = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std + ) + + return model, preprocess diff --git a/models/eva_clip/hf_configs.py b/models/eva_clip/hf_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..a8c9b704db1879676aed5cef26796303b65fe987 --- /dev/null +++ b/models/eva_clip/hf_configs.py @@ -0,0 +1,57 @@ +# HF architecture dict: +arch_dict = { + # https://huggingface.co/docs/transformers/model_doc/roberta#roberta + "roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig + "xlm-roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 + "mt5": { + "config_names": { + # unlimited seqlen + # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 + # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 + "context_length": "", + "vocab_size": "vocab_size", + "width": "d_model", + "heads": "num_heads", + "layers": "num_layers", + "layer_attr": "block", + "token_embeddings_attr": "embed_tokens" + }, + "pooler": "mean_pooler", + }, + "bert": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + } +} diff --git a/models/eva_clip/hf_model.py b/models/eva_clip/hf_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c4b9fd85b4066ba31db2bda5767ed1ce15de479d --- /dev/null +++ b/models/eva_clip/hf_model.py @@ -0,0 +1,248 @@ +""" huggingface model adapter + +Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. +""" + +import re + +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch import TensorType +try: + import transformers + from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig + from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ + BaseModelOutputWithPoolingAndCrossAttentions +except ImportError as e: + transformers = None + + + class BaseModelOutput: + pass + + + class PretrainedConfig: + pass + +from .hf_configs import arch_dict + +# utils +def _camel2snake(s): + return re.sub(r'(? TensorType: + # image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device) + # attn_mask = (x != self.config.pad_token_id).long() + # out = self.transformer( + # input_ids=x, + # attention_mask=attn_mask, + # encoder_hidden_states = image_embeds, + # encoder_attention_mask = image_atts, + # ) + # pooled_out = self.pooler(out, attn_mask) + + # return self.itm_proj(pooled_out) + + def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None): + if masked_indices is None: + masked_indices = torch.bernoulli(probability_matrix).bool() + + masked_indices[input_ids == self.tokenizer.pad_token_id] = False + masked_indices[input_ids == self.tokenizer.cls_token_id] = False + + if targets is not None: + targets[~masked_indices] = -100 # We only compute loss on masked tokens + + # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) + indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices + input_ids[indices_replaced] = self.tokenizer.mask_token_id + + # 10% of the time, we replace masked input tokens with random word + indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced + random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device) + input_ids[indices_random] = random_words[indices_random] + # The rest of the time (10% of the time) we keep the masked input tokens unchanged + + if targets is not None: + return input_ids, targets + else: + return input_ids + + def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25): + labels = input_ids.clone() + attn_mask = (input_ids != self.config.pad_token_id).long() + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(input_ids.device) + vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"]) + probability_matrix = torch.full(labels.shape, mlm_probability) + input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels, + probability_matrix = probability_matrix) + mlm_output = self.transformer(input_ids, + attention_mask = attn_mask, + encoder_hidden_states = image_embeds, + encoder_attention_mask = image_atts, + return_dict = True, + labels = labels, + ) + return mlm_output.loss + # mlm_output = self.transformer(input_ids, + # attention_mask = attn_mask, + # encoder_hidden_states = image_embeds, + # encoder_attention_mask = image_atts, + # return_dict = True, + # ).last_hidden_state + # logits = self.mlm_proj(mlm_output) + + # # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size) + # logits = logits[:, 1:, :].contiguous().view(-1, vocab_size) + # labels = labels[:, 1:].contiguous().view(-1) + + # mlm_loss = F.cross_entropy( + # logits, + # labels, + # # label_smoothing=0.1, + # ) + # return mlm_loss + + + def forward(self, x:TensorType) -> TensorType: + attn_mask = (x != self.config.pad_token_id).long() + out = self.transformer(input_ids=x, attention_mask=attn_mask) + pooled_out = self.pooler(out, attn_mask) + + return self.proj(pooled_out) + + def lock(self, unlocked_layers:int=0, freeze_layer_norm:bool=True): + if not unlocked_layers: # full freezing + for n, p in self.transformer.named_parameters(): + p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False + return + + encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer + layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"]) + print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model") + embeddings = getattr( + self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"]) + modules = [embeddings, *layer_list][:-unlocked_layers] + # freeze layers + for module in modules: + for n, p in module.named_parameters(): + p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False + + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.gradient_checkpointing_enable() + + def get_num_layers(self): + encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer + layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"]) + return len(layer_list) + + def init_parameters(self): + pass diff --git a/models/eva_clip/loss.py b/models/eva_clip/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..473f60d98d501067e85ace2dd089b00e249b6d17 --- /dev/null +++ b/models/eva_clip/loss.py @@ -0,0 +1,138 @@ +import math +import torch +import torch.nn as nn +from torch.nn import functional as F + +try: + import torch.distributed.nn + from torch import distributed as dist + has_distributed = True +except ImportError: + has_distributed = False + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +from timm.loss import LabelSmoothingCrossEntropy + + +def gather_features( + image_features, + text_features, + local_loss=False, + gather_with_grad=False, + rank=0, + world_size=1, + use_horovod=False +): + assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.' + if use_horovod: + assert hvd is not None, 'Please install horovod' + if gather_with_grad: + all_image_features = hvd.allgather(image_features) + all_text_features = hvd.allgather(text_features) + else: + with torch.no_grad(): + all_image_features = hvd.allgather(image_features) + all_text_features = hvd.allgather(text_features) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) + gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) + gathered_image_features[rank] = image_features + gathered_text_features[rank] = text_features + all_image_features = torch.cat(gathered_image_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + else: + # We gather tensors from all gpus + if gather_with_grad: + all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) + all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) + # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0) + # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0) + else: + gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] + gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] + dist.all_gather(gathered_image_features, image_features) + dist.all_gather(gathered_text_features, text_features) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_image_features[rank] = image_features + gathered_text_features[rank] = text_features + all_image_features = torch.cat(gathered_image_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + + return all_image_features, all_text_features + + +class ClipLoss(nn.Module): + + def __init__( + self, + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + smoothing=0., + ): + super().__init__() + self.local_loss = local_loss + self.gather_with_grad = gather_with_grad + self.cache_labels = cache_labels + self.rank = rank + self.world_size = world_size + self.use_horovod = use_horovod + self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None + + # cache state + self.prev_num_logits = 0 + self.labels = {} + + def forward(self, image_features, text_features, logit_scale=1.): + device = image_features.device + if self.world_size > 1: + all_image_features, all_text_features = gather_features( + image_features, text_features, + self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) + + if self.local_loss: + logits_per_image = logit_scale * image_features @ all_text_features.T + logits_per_text = logit_scale * text_features @ all_image_features.T + else: + logits_per_image = logit_scale * all_image_features @ all_text_features.T + logits_per_text = logits_per_image.T + else: + logits_per_image = logit_scale * image_features @ text_features.T + logits_per_text = logit_scale * text_features @ image_features.T + # calculated ground-truth and cache if enabled + num_logits = logits_per_image.shape[0] + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + + if self.label_smoothing_cross_entropy: + total_loss = ( + self.label_smoothing_cross_entropy(logits_per_image, labels) + + self.label_smoothing_cross_entropy(logits_per_text, labels) + ) / 2 + else: + total_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + acc = None + i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image) + t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text) + acc = {"i2t": i2t_acc, "t2i": t2i_acc} + return total_loss, acc \ No newline at end of file diff --git a/models/eva_clip/model.py b/models/eva_clip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..da3bbd755799ced672385d1029ba7ce6d5215b0b --- /dev/null +++ b/models/eva_clip/model.py @@ -0,0 +1,439 @@ +""" CLIP Model + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +import os +from dataclasses import dataclass +from typing import Optional, Tuple, Union +from functools import partial + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +try: + from .hf_model import HFTextEncoder +except: + HFTextEncoder = None +from .modified_resnet import ModifiedResNet +from .timm_model import TimmModel +from .eva_vit_model import EVAVisionTransformer +from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer + +try: + from apex.normalization import FusedLayerNorm +except: + FusedLayerNorm = LayerNorm + print("Please 'pip install apex'") + +try: + import xformers.ops as xops +except ImportError: + xops = None + print("Please 'pip install xformers'") + +@dataclass +class CLIPVisionCfg: + layers: Union[Tuple[int, int, int, int], int] = 12 + width: int = 768 + head_width: int = 64 + mlp_ratio: float = 4.0 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + ls_init_value: Optional[float] = None # layer scale initial value + patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results + global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) + drop_path_rate: Optional[float] = None # drop path rate + timm_model_name: str = None # a valid model name overrides layers, width, patch_size + timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model + timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') + timm_proj_bias: bool = False # enable bias final projection + eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size + qkv_bias: bool = True + fusedLN: bool = False + xattn: bool = False + postnorm: bool = False + rope: bool = False + pt_hw_seq_len: int = 16 # 224/14 + intp_freq: bool = False + naiveswiglu: bool = False + subln: bool = False + + +@dataclass +class CLIPTextCfg: + context_length: int = 77 + vocab_size: int = 49408 + width: int = 512 + heads: int = 8 + layers: int = 12 + ls_init_value: Optional[float] = None # layer scale initial value + hf_model_name: str = None + hf_tokenizer_name: str = None + hf_model_pretrained: bool = True + proj: str = 'mlp' + pooler_type: str = 'mean_pooler' + masked_language_modeling: bool = False + fusedLN: bool = False + xattn: bool = False + attn_mask: bool = True + +def get_cast_dtype(precision: str): + cast_dtype = None + if precision == 'bf16': + cast_dtype = torch.bfloat16 + elif precision == 'fp16': + cast_dtype = torch.float16 + return cast_dtype + + +def _build_vision_tower( + embed_dim: int, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None +): + if isinstance(vision_cfg, dict): + vision_cfg = CLIPVisionCfg(**vision_cfg) + + # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. + act_layer = QuickGELU if quick_gelu else nn.GELU + + if vision_cfg.eva_model_name: + vision_heads = vision_cfg.width // vision_cfg.head_width + norm_layer = LayerNorm + + visual = EVAVisionTransformer( + img_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + num_classes=embed_dim, + use_mean_pooling=vision_cfg.global_average_pool, #False + init_values=vision_cfg.ls_init_value, + patch_dropout=vision_cfg.patch_dropout, + embed_dim=vision_cfg.width, + depth=vision_cfg.layers, + num_heads=vision_heads, + mlp_ratio=vision_cfg.mlp_ratio, + qkv_bias=vision_cfg.qkv_bias, + drop_path_rate=vision_cfg.drop_path_rate, + norm_layer= partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6), + xattn=vision_cfg.xattn, + rope=vision_cfg.rope, + postnorm=vision_cfg.postnorm, + pt_hw_seq_len= vision_cfg.pt_hw_seq_len, # 224/14 + intp_freq= vision_cfg.intp_freq, + naiveswiglu= vision_cfg.naiveswiglu, + subln= vision_cfg.subln + ) + elif vision_cfg.timm_model_name: + visual = TimmModel( + vision_cfg.timm_model_name, + pretrained=vision_cfg.timm_model_pretrained, + pool=vision_cfg.timm_pool, + proj=vision_cfg.timm_proj, + proj_bias=vision_cfg.timm_proj_bias, + embed_dim=embed_dim, + image_size=vision_cfg.image_size + ) + act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models + elif isinstance(vision_cfg.layers, (tuple, list)): + vision_heads = vision_cfg.width * 32 // vision_cfg.head_width + visual = ModifiedResNet( + layers=vision_cfg.layers, + output_dim=embed_dim, + heads=vision_heads, + image_size=vision_cfg.image_size, + width=vision_cfg.width + ) + else: + vision_heads = vision_cfg.width // vision_cfg.head_width + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + visual = VisionTransformer( + image_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + width=vision_cfg.width, + layers=vision_cfg.layers, + heads=vision_heads, + mlp_ratio=vision_cfg.mlp_ratio, + ls_init_value=vision_cfg.ls_init_value, + patch_dropout=vision_cfg.patch_dropout, + global_average_pool=vision_cfg.global_average_pool, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return visual + + +def _build_text_tower( + embed_dim: int, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + if isinstance(text_cfg, dict): + text_cfg = CLIPTextCfg(**text_cfg) + + if text_cfg.hf_model_name: + text = HFTextEncoder( + text_cfg.hf_model_name, + output_dim=embed_dim, + tokenizer_name=text_cfg.hf_tokenizer_name, + proj=text_cfg.proj, + pooler_type=text_cfg.pooler_type, + masked_language_modeling=text_cfg.masked_language_modeling + ) + else: + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = LayerNorm + + text = TextTransformer( + context_length=text_cfg.context_length, + vocab_size=text_cfg.vocab_size, + width=text_cfg.width, + heads=text_cfg.heads, + layers=text_cfg.layers, + ls_init_value=text_cfg.ls_init_value, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer= FusedLayerNorm if text_cfg.fusedLN else norm_layer, + xattn=text_cfg.xattn, + attn_mask=text_cfg.attn_mask, + ) + return text + +class CLIP(nn.Module): + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + + text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.transformer = text.transformer + self.vocab_size = text.vocab_size + self.token_embedding = text.token_embedding + self.positional_embedding = text.positional_embedding + self.ln_final = text.ln_final + self.text_projection = text.text_projection + self.register_buffer('attn_mask', text.attn_mask, persistent=False) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.transformer.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + return {'logit_scale'} + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + cast_dtype = self.transformer.get_cast_dtype() + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + return F.normalize(x, dim=-1) if normalize else x + + def forward(self, image, text): + image_features = self.encode_image(image, normalize=True) + text_features = self.encode_text(text, normalize=True) + return image_features, text_features, self.logit_scale.exp() + + +class CustomCLIP(nn.Module): + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + itm_task: bool = False, + ): + super().__init__() + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + def lock_text_tower(self, unlocked_layers:int=0, freeze_layer_norm:bool=True): + self.text.lock(unlocked_layers, freeze_layer_norm) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + + @torch.jit.ignore + def no_weight_decay(self): + return {'logit_scale'} + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + features = self.text(text) + return F.normalize(features, dim=-1) if normalize else features + + def forward(self, image, text): + image_features = self.encode_image(image, normalize=True) + text_features = self.encode_text(text, normalize=True) + return image_features, text_features, self.logit_scale.exp() + + +def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): + """Convert applicable model parameters to low-precision (bf16 or fp16)""" + + def _convert_weights(l): + + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.to(dtype) + if l.bias is not None: + l.bias.data = l.bias.data.to(dtype) + + if isinstance(l, (nn.MultiheadAttention, Attention)): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr, None) + if tensor is not None: + tensor.data = tensor.data.to(dtype) + + if isinstance(l, nn.Parameter): + l.data = l.data.to(dtype) + + for name in ["text_projection", "proj"]: + if hasattr(l, name) and isinstance(l, nn.Parameter): + attr = getattr(l, name, None) + if attr is not None: + attr.data = attr.data.to(dtype) + + model.apply(_convert_weights) + + +convert_weights_to_fp16 = convert_weights_to_lp # backwards compat + + +# used to maintain checkpoint compatibility +def convert_to_custom_text_state_dict(state_dict: dict): + if 'text_projection' in state_dict: + # old format state_dict, move text tower -> .text + new_state_dict = {} + for k, v in state_dict.items(): + if any(k.startswith(p) for p in ( + 'text_projection', + 'positional_embedding', + 'token_embedding', + 'transformer', + 'ln_final', + 'logit_scale' + )): + k = 'text.' + k + new_state_dict[k] = v + return new_state_dict + return state_dict + + +def build_model_from_openai_state_dict( + state_dict: dict, + quick_gelu=True, + cast_dtype=torch.float16, +): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len( + [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_size = vision_patch_size * grid_size + else: + counts: list = [ + len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_size = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + vision_cfg = CLIPVisionCfg( + layers=vision_layers, + width=vision_width, + patch_size=vision_patch_size, + image_size=image_size, + ) + text_cfg = CLIPTextCfg( + context_length=context_length, + vocab_size=vocab_size, + width=transformer_width, + heads=transformer_heads, + layers=transformer_layers + ) + model = CLIP( + embed_dim, + vision_cfg=vision_cfg, + text_cfg=text_cfg, + quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU + cast_dtype=cast_dtype, + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + + convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 + model.load_state_dict(state_dict) + return model.eval() + + +def trace_model(model, batch_size=256, device=torch.device('cpu')): + model.eval() + image_size = model.visual.image_size + example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) + example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) + model = torch.jit.trace_module( + model, + inputs=dict( + forward=(example_images, example_text), + encode_text=(example_text,), + encode_image=(example_images,) + )) + model.visual.image_size = image_size + return model diff --git a/models/eva_clip/model_configs/EVA01-CLIP-B-16.json b/models/eva_clip/model_configs/EVA01-CLIP-B-16.json new file mode 100644 index 0000000000000000000000000000000000000000..aad2058003962a4ab286bf4e1ae956288af34e62 --- /dev/null +++ b/models/eva_clip/model_configs/EVA01-CLIP-B-16.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 16, + "eva_model_name": "eva-clip-b-16", + "ls_init_value": 0.1, + "drop_path_rate": 0.0 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/models/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json b/models/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..100279572ff6d1bcca601f0eb526b4d4ff174c7d --- /dev/null +++ b/models/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 40, + "width": 1408, + "head_width": 88, + "mlp_ratio": 4.3637, + "patch_size": 14, + "eva_model_name": "eva-clip-g-14-x", + "drop_path_rate": 0, + "xattn": true, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/models/eva_clip/model_configs/EVA01-CLIP-g-14.json b/models/eva_clip/model_configs/EVA01-CLIP-g-14.json new file mode 100644 index 0000000000000000000000000000000000000000..5d338b4e6104241d1f0304ee82400035d5385332 --- /dev/null +++ b/models/eva_clip/model_configs/EVA01-CLIP-g-14.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 40, + "width": 1408, + "head_width": 88, + "mlp_ratio": 4.3637, + "patch_size": 14, + "eva_model_name": "eva-clip-g-14-x", + "drop_path_rate": 0.4, + "xattn": true, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/models/eva_clip/model_configs/EVA02-CLIP-B-16.json b/models/eva_clip/model_configs/EVA02-CLIP-B-16.json new file mode 100644 index 0000000000000000000000000000000000000000..e4a6e723f77033caa341ddf9b5be1787d64ad42c --- /dev/null +++ b/models/eva_clip/model_configs/EVA02-CLIP-B-16.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "head_width": 64, + "patch_size": 16, + "mlp_ratio": 2.6667, + "eva_model_name": "eva-clip-b-16-X", + "drop_path_rate": 0.0, + "xattn": true, + "fusedLN": true, + "rope": true, + "pt_hw_seq_len": 16, + "intp_freq": true, + "naiveswiglu": true, + "subln": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "xattn": true, + "fusedLN": true + } +} \ No newline at end of file diff --git a/models/eva_clip/model_configs/EVA02-CLIP-L-14-336.json b/models/eva_clip/model_configs/EVA02-CLIP-L-14-336.json new file mode 100644 index 0000000000000000000000000000000000000000..3e1d124e1118911c5ad7b1ce85df195aca363ac4 --- /dev/null +++ b/models/eva_clip/model_configs/EVA02-CLIP-L-14-336.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 336, + "layers": 24, + "width": 1024, + "drop_path_rate": 0, + "head_width": 64, + "mlp_ratio": 2.6667, + "patch_size": 14, + "eva_model_name": "eva-clip-l-14-336", + "xattn": true, + "fusedLN": true, + "rope": true, + "pt_hw_seq_len": 16, + "intp_freq": true, + "naiveswiglu": true, + "subln": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/models/eva_clip/model_configs/EVA02-CLIP-L-14.json b/models/eva_clip/model_configs/EVA02-CLIP-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..03b22ad3cfb92f9c843b9ec8d672e57e7a9ba4a2 --- /dev/null +++ b/models/eva_clip/model_configs/EVA02-CLIP-L-14.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "drop_path_rate": 0, + "head_width": 64, + "mlp_ratio": 2.6667, + "patch_size": 14, + "eva_model_name": "eva-clip-l-14", + "xattn": true, + "fusedLN": true, + "rope": true, + "pt_hw_seq_len": 16, + "intp_freq": true, + "naiveswiglu": true, + "subln": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/models/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json b/models/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..aa04e2545ac1e015daae2c10133956ce969524f7 --- /dev/null +++ b/models/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json @@ -0,0 +1,25 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 64, + "width": 1792, + "head_width": 112, + "mlp_ratio": 8.571428571428571, + "patch_size": 14, + "eva_model_name": "eva-clip-4b-14-x", + "drop_path_rate": 0, + "xattn": true, + "postnorm": true, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32, + "xattn": false, + "fusedLN": true + } +} diff --git a/models/eva_clip/model_configs/EVA02-CLIP-bigE-14.json b/models/eva_clip/model_configs/EVA02-CLIP-bigE-14.json new file mode 100644 index 0000000000000000000000000000000000000000..747ffccc8bd49dbb6701b58e15843b7fe3754e64 --- /dev/null +++ b/models/eva_clip/model_configs/EVA02-CLIP-bigE-14.json @@ -0,0 +1,25 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 64, + "width": 1792, + "head_width": 112, + "mlp_ratio": 8.571428571428571, + "patch_size": 14, + "eva_model_name": "eva-clip-4b-14-x", + "drop_path_rate": 0, + "xattn": true, + "postnorm": true, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/models/eva_clip/modified_resnet.py b/models/eva_clip/modified_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..299080850061a0dab433322e5f8fe2a55fb4e9a2 --- /dev/null +++ b/models/eva_clip/modified_resnet.py @@ -0,0 +1,188 @@ +import os +import sys + +import torch +from torch import nn +from torch.nn import functional as F +from collections import OrderedDict + +current_file_path = os.path.abspath(__file__) +project_roots = [os.path.dirname(current_file_path)] +for project_root in project_roots: + sys.path.insert(0, project_root) if project_root not in sys.path else None + +from utils import freeze_batch_norm_2d + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.act1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.act2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.act3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.act1(self.bn1(self.conv1(x))) + out = self.act2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.act3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0., + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, image_size=224, width=64): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.act1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.act2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.act3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + if self.attnpool is not None: + std = self.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, 'partial locking not currently supported for this model' + for param in self.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + # FIXME support for non-transformer + pass + + def stem(self, x): + x = self.act1(self.bn1(self.conv1(x))) + x = self.act2(self.bn2(self.conv2(x))) + x = self.act3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + def forward(self, x): + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x diff --git a/models/eva_clip/openai.py b/models/eva_clip/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..cc4e13e876d6a7a3463b457e62c517cb063b1356 --- /dev/null +++ b/models/eva_clip/openai.py @@ -0,0 +1,144 @@ +""" OpenAI pretrained model functions + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" + +import os +import warnings +from typing import List, Optional, Union + +import torch + +from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype +from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url + +__all__ = ["list_openai_models", "load_openai_model"] + + +def list_openai_models() -> List[str]: + """Returns the names of available CLIP models""" + return list_pretrained_models_by_tag('openai') + + +def load_openai_model( + name: str, + precision: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, + jit: bool = True, + cache_dir: Optional[str] = None, +): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + precision: str + Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. + device : Union[str, torch.device] + The device to put the loaded model + jit : bool + Whether to load the optimized JIT model (default) or more hackable non-JIT model. + cache_dir : Optional[str] + The directory to cache the downloaded model weights + + Returns + ------- + model : torch.nn.Module + The CLIP model + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + if precision is None: + precision = 'fp32' if device == 'cpu' else 'fp16' + + if get_pretrained_url(name, 'openai'): + model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(model_path, map_location="cpu") + + if not jit: + # Build a non-jit model from the OpenAI jitted model state dict + cast_dtype = get_cast_dtype(precision) + try: + model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) + except KeyError: + sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} + model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) + + # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use + model = model.to(device) + if precision.startswith('amp') or precision == 'fp32': + model.float() + elif precision == 'bf16': + convert_weights_to_lp(model, dtype=torch.bfloat16) + + return model + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 (typically for CPU) + if precision == 'fp32': + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + model.float() + + # ensure image_size attr available at consistent location for both jit and non-jit + model.visual.image_size = model.input_resolution.item() + return model diff --git a/models/eva_clip/pretrained.py b/models/eva_clip/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..a1e55dcf36a0e7dbd4c13b4ca2d7cb460e4c3547 --- /dev/null +++ b/models/eva_clip/pretrained.py @@ -0,0 +1,332 @@ +import hashlib +import os +import urllib +import warnings +from functools import partial +from typing import Dict, Union + +from tqdm import tqdm + +try: + from huggingface_hub import hf_hub_download + _has_hf_hub = True +except ImportError: + hf_hub_download = None + _has_hf_hub = False + + +def _pcfg(url='', hf_hub='', filename='', mean=None, std=None): + return dict( + url=url, + hf_hub=hf_hub, + mean=mean, + std=std, + ) + +_VITB32 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), + laion2b_e16=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), + laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/') +) + +_VITB32_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), +) + +_VITB16 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), +) + +_EVAB16 = dict( + eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'), + eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'), + eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'), + eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'), +) + +_VITB16_PLUS_240 = dict( + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), +) + +_VITL14 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), + laion2b_s32b_b82k=_pcfg( + hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), +) + +_EVAL14 = dict( + eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'), + eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'), + eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'), + eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'), +) + +_VITL14_336 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), +) + +_EVAL14_336 = dict( + eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'), + eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'), + eva_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'), + eva02_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'), +) + +_VITH14 = dict( + laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), +) + +_VITg14 = dict( + laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'), +) + +_EVAg14 = dict( + eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'), + eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'), + eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'), + eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'), +) + +_EVAg14_PLUS = dict( + eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'), + eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'), + eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'), + eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'), +) + +_VITbigG14 = dict( + laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), +) + +_EVAbigE14 = dict( + eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'), + eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'), + eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'), + eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'), +) + +_EVAbigE14_PLUS = dict( + eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'), + eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'), + eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'), + eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'), +) + + +_PRETRAINED = { + # "ViT-B-32": _VITB32, + "OpenaiCLIP-B-32": _VITB32, + "OpenCLIP-B-32": _VITB32, + + # "ViT-B-32-quickgelu": _VITB32_quickgelu, + "OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu, + "OpenCLIP-B-32-quickgelu": _VITB32_quickgelu, + + # "ViT-B-16": _VITB16, + "OpenaiCLIP-B-16": _VITB16, + "OpenCLIP-B-16": _VITB16, + + "EVA02-B-16": _EVAB16, + "EVA02-CLIP-B-16": _EVAB16, + + # "ViT-B-16-plus-240": _VITB16_PLUS_240, + "OpenCLIP-B-16-plus-240": _VITB16_PLUS_240, + + # "ViT-L-14": _VITL14, + "OpenaiCLIP-L-14": _VITL14, + "OpenCLIP-L-14": _VITL14, + + "EVA02-L-14": _EVAL14, + "EVA02-CLIP-L-14": _EVAL14, + + # "ViT-L-14-336": _VITL14_336, + "OpenaiCLIP-L-14-336": _VITL14_336, + + "EVA02-CLIP-L-14-336": _EVAL14_336, + + # "ViT-H-14": _VITH14, + # "ViT-g-14": _VITg14, + "OpenCLIP-H-14": _VITH14, + "OpenCLIP-g-14": _VITg14, + + "EVA01-CLIP-g-14": _EVAg14, + "EVA01-CLIP-g-14-plus": _EVAg14_PLUS, + + # "ViT-bigG-14": _VITbigG14, + "OpenCLIP-bigG-14": _VITbigG14, + + "EVA02-CLIP-bigE-14": _EVAbigE14, + "EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS, +} + + +def _clean_tag(tag: str): + # normalize pretrained tags + return tag.lower().replace('-', '_') + + +def list_pretrained(as_str: bool = False): + """ returns list of pretrained models + Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True + """ + return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] + + +def list_pretrained_models_by_tag(tag: str): + """ return all models having the specified pretrain tag """ + models = [] + tag = _clean_tag(tag) + for k in _PRETRAINED.keys(): + if tag in _PRETRAINED[k]: + models.append(k) + return models + + +def list_pretrained_tags_by_model(model: str): + """ return all pretrain tags for the specified model architecture """ + tags = [] + if model in _PRETRAINED: + tags.extend(_PRETRAINED[model].keys()) + return tags + + +def is_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return False + return _clean_tag(tag) in _PRETRAINED[model] + + +def get_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return {} + model_pretrained = _PRETRAINED[model] + return model_pretrained.get(_clean_tag(tag), {}) + + +def get_pretrained_url(model: str, tag: str): + cfg = get_pretrained_cfg(model, _clean_tag(tag)) + return cfg.get('url', '') + + +def download_pretrained_from_url( + url: str, + cache_dir: Union[str, None] = None, +): + if not cache_dir: + cache_dir = os.path.expanduser("~/.cache/clip") + os.makedirs(cache_dir, exist_ok=True) + filename = os.path.basename(url) + + if 'openaipublic' in url: + expected_sha256 = url.split("/")[-2] + elif 'mlfoundations' in url: + expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] + else: + expected_sha256 = '' + + download_target = os.path.join(cache_dir, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if expected_sha256: + if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + else: + return download_target + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def has_hf_hub(necessary=False): + if not _has_hf_hub and necessary: + # if no HF Hub module installed, and it is necessary to continue, raise error + raise RuntimeError( + 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') + return _has_hf_hub + + +def download_pretrained_from_hf( + model_id: str, + filename: str = 'open_clip_pytorch_model.bin', + revision=None, + cache_dir: Union[str, None] = None, +): + has_hf_hub(True) + cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) + return cached_file + + +def download_pretrained( + cfg: Dict, + force_hf_hub: bool = False, + cache_dir: Union[str, None] = None, +): + target = '' + if not cfg: + return target + + download_url = cfg.get('url', '') + download_hf_hub = cfg.get('hf_hub', '') + if download_hf_hub and force_hf_hub: + # use HF hub even if url exists + download_url = '' + + if download_url: + target = download_pretrained_from_url(download_url, cache_dir=cache_dir) + elif download_hf_hub: + has_hf_hub(True) + # we assume the hf_hub entries in pretrained config combine model_id + filename in + # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and + # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. + model_id, filename = os.path.split(download_hf_hub) + if filename: + target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) + else: + target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + + return target diff --git a/models/eva_clip/rope.py b/models/eva_clip/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..69030c35ea7b6b4f298daebbee5717f3fa1254ab --- /dev/null +++ b/models/eva_clip/rope.py @@ -0,0 +1,137 @@ +from math import pi +import torch +from torch import nn +from einops import rearrange, repeat +import logging + +def broadcat(tensors, dim = -1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions' + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation' + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim = dim) + +def rotate_half(x): + x = rearrange(x, '... (d r) -> ... d r', r = 2) + x1, x2 = x.unbind(dim = -1) + x = torch.stack((-x2, x1), dim = -1) + return rearrange(x, '... d r -> ... (d r)') + + +class VisionRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + pt_seq_len, + ft_seq_len=None, + custom_freqs = None, + freqs_for = 'lang', + theta = 10000, + max_freq = 10, + num_freqs = 1, + ): + super().__init__() + if custom_freqs: + freqs = custom_freqs + elif freqs_for == 'lang': + freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + elif freqs_for == 'pixel': + freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi + elif freqs_for == 'constant': + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f'unknown modality {freqs_for}') + + if ft_seq_len is None: ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs_h = torch.einsum('..., f -> ... f', t, freqs) + freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2) + + freqs_w = torch.einsum('..., f -> ... f', t, freqs) + freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2) + + freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1) + + self.register_buffer("freqs_cos", freqs.cos()) + self.register_buffer("freqs_sin", freqs.sin()) + + logging.info(f'Shape of rope freq: {self.freqs_cos.shape}') + + def forward(self, t, start_index = 0): + rot_dim = self.freqs_cos.shape[-1] + end_index = start_index + rot_dim + assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' + t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] + t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) + + return torch.cat((t_left, t, t_right), dim = -1) + +class VisionRotaryEmbeddingFast(nn.Module): + def __init__( + self, + dim, + pt_seq_len, + ft_seq_len=None, + custom_freqs = None, + freqs_for = 'lang', + theta = 10000, + max_freq = 10, + num_freqs = 1, + patch_dropout = 0. + ): + super().__init__() + if custom_freqs: + freqs = custom_freqs + elif freqs_for == 'lang': + freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + elif freqs_for == 'pixel': + freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi + elif freqs_for == 'constant': + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f'unknown modality {freqs_for}') + + if ft_seq_len is None: ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs = torch.einsum('..., f -> ... f', t, freqs) + freqs = repeat(freqs, '... n -> ... (n r)', r = 2) + freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1) + + freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) + freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) + + self.patch_dropout = patch_dropout + + self.register_buffer("freqs_cos", freqs_cos) + self.register_buffer("freqs_sin", freqs_sin) + + logging.info(f'Shape of rope freq: {self.freqs_cos.shape}') + + def forward(self, t, patch_indices_keep=None): + if patch_indices_keep is not None: + batch = t.size()[0] + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]) + freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]) + + freqs_cos = freqs_cos[batch_indices, patch_indices_keep] + freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j') + freqs_sin = freqs_sin[batch_indices, patch_indices_keep] + freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j') + + return t * freqs_cos + rotate_half(t) * freqs_sin + + return t * self.freqs_cos + rotate_half(t) * self.freqs_sin \ No newline at end of file diff --git a/models/eva_clip/timm_model.py b/models/eva_clip/timm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b58122c0b84fbda9e51867342823222234e17505 --- /dev/null +++ b/models/eva_clip/timm_model.py @@ -0,0 +1,122 @@ +""" timm model adapter + +Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. +""" +import logging +from collections import OrderedDict + +import torch +import torch.nn as nn + +try: + import timm + from timm.models.layers import Mlp, to_2tuple + try: + # old timm imports < 0.8.1 + from timm.models.layers.attention_pool2d import RotAttentionPool2d + from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d + except ImportError: + # new timm imports >= 0.8.1 + from timm.layers import RotAttentionPool2d + from timm.layers import AttentionPool2d as AbsAttentionPool2d +except ImportError: + timm = None + +from .utils import freeze_batch_norm_2d + + +class TimmModel(nn.Module): + """ timm model adapter + # FIXME this adapter is a work in progress, may change in ways that break weight compat + """ + + def __init__( + self, + model_name, + embed_dim, + image_size=224, + pool='avg', + proj='linear', + proj_bias=False, + drop=0., + pretrained=False): + super().__init__() + if timm is None: + raise RuntimeError("Please `pip install timm` to use timm models.") + + self.image_size = to_2tuple(image_size) + self.trunk = timm.create_model(model_name, pretrained=pretrained) + feat_size = self.trunk.default_cfg.get('pool_size', None) + feature_ndim = 1 if not feat_size else 2 + if pool in ('abs_attn', 'rot_attn'): + assert feature_ndim == 2 + # if attn pooling used, remove both classifier and default pool + self.trunk.reset_classifier(0, global_pool='') + else: + # reset global pool if pool config set, otherwise leave as network default + reset_kwargs = dict(global_pool=pool) if pool else {} + self.trunk.reset_classifier(0, **reset_kwargs) + prev_chs = self.trunk.num_features + + head_layers = OrderedDict() + if pool == 'abs_attn': + head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) + prev_chs = embed_dim + elif pool == 'rot_attn': + head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) + prev_chs = embed_dim + else: + assert proj, 'projection layer needed if non-attention pooling is used.' + + # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used + if proj == 'linear': + head_layers['drop'] = nn.Dropout(drop) + head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) + elif proj == 'mlp': + head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias)) + + self.head = nn.Sequential(head_layers) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + """ lock modules + Args: + unlocked_groups (int): leave last n layer groups unlocked (default: 0) + """ + if not unlocked_groups: + # lock full model + for param in self.trunk.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self.trunk) + else: + # NOTE: partial freeze requires latest timm (master) branch and is subject to change + try: + # FIXME import here until API stable and in an official release + from timm.models.helpers import group_parameters, group_modules + except ImportError: + raise RuntimeError( + 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') + matcher = self.trunk.group_matcher() + gparams = group_parameters(self.trunk, matcher) + max_layer_id = max(gparams.keys()) + max_layer_id = max_layer_id - unlocked_groups + for group_idx in range(max_layer_id + 1): + group = gparams[group_idx] + for param in group: + self.trunk.get_parameter(param).requires_grad = False + if freeze_bn_stats: + gmodules = group_modules(self.trunk, matcher, reverse=True) + gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} + freeze_batch_norm_2d(self.trunk, gmodules) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + try: + self.trunk.set_grad_checkpointing(enable) + except Exception as e: + logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') + + def forward(self, x): + x = self.trunk(x) + x = self.head(x) + return x diff --git a/models/eva_clip/tokenizer.py b/models/eva_clip/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..41482f82aebbf197f4ee4e6c07c845a0d69dd7d6 --- /dev/null +++ b/models/eva_clip/tokenizer.py @@ -0,0 +1,201 @@ +""" CLIP tokenizer + +Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +import gzip +import html +import os +from functools import lru_cache +from typing import Union, List + +import ftfy +import regex as re +import torch + +# https://stackoverflow.com/q/62691279 +import os +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + if not special_tokens: + special_tokens = ['', ''] + else: + special_tokens = ['', ''] + special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t:t for t in special_tokens} + special = "|".join(special_tokens) + self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text + + +_tokenizer = SimpleTokenizer() + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder[""] + eot_token = _tokenizer.encoder[""] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + tokens[-1] = eot_token + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + +class HFTokenizer: + "HuggingFace tokenizer wrapper" + def __init__(self, tokenizer_name:str): + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor: + # same cleaning as for default tokenizer, except lowercasing + # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance + if isinstance(texts, str): + texts = [texts] + texts = [whitespace_clean(basic_clean(text)) for text in texts] + input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids + return input_ids diff --git a/models/eva_clip/transform.py b/models/eva_clip/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..39f3e4cf6cf9985131ae2ef254b59540904b02e7 --- /dev/null +++ b/models/eva_clip/transform.py @@ -0,0 +1,103 @@ +from typing import Optional, Sequence, Tuple + +import torch +import torch.nn as nn +import torchvision.transforms.functional as F + +from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ + CenterCrop + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD + + +class ResizeMaxSize(nn.Module): + + def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): + super().__init__() + if not isinstance(max_size, int): + raise TypeError(f"Size should be int. Got {type(max_size)}") + self.max_size = max_size + self.interpolation = interpolation + self.fn = min if fn == 'min' else min + self.fill = fill + + def forward(self, img): + if isinstance(img, torch.Tensor): + height, width = img.shape[:2] + else: + width, height = img.size + scale = self.max_size / float(max(height, width)) + if scale != 1.0: + new_size = tuple(round(dim * scale) for dim in (height, width)) + img = F.resize(img, new_size, self.interpolation) + pad_h = self.max_size - new_size[0] + pad_w = self.max_size - new_size[1] + img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) + return img + + +def _convert_to_rgb(image): + return image.convert('RGB') + + +# class CatGen(nn.Module): +# def __init__(self, num=4): +# self.num = num +# def mixgen_batch(image, text): +# batch_size = image.shape[0] +# index = np.random.permutation(batch_size) + +# cat_images = [] +# for i in range(batch_size): +# # image mixup +# image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:] +# # text concat +# text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0] +# text = torch.stack(text) +# return image, text + + +def image_transform( + image_size: int, + is_train: bool, + mean: Optional[Tuple[float, ...]] = None, + std: Optional[Tuple[float, ...]] = None, + resize_longest_max: bool = False, + fill_color: int = 0, +): + mean = mean or OPENAI_DATASET_MEAN + if not isinstance(mean, (list, tuple)): + mean = (mean,) * 3 + + std = std or OPENAI_DATASET_STD + if not isinstance(std, (list, tuple)): + std = (std,) * 3 + + if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: + # for square size, pass size as int so that Resize() uses aspect preserving shortest edge + image_size = image_size[0] + + normalize = Normalize(mean=mean, std=std) + if is_train: + return Compose([ + RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), + _convert_to_rgb, + ToTensor(), + normalize, + ]) + else: + if resize_longest_max: + transforms = [ + ResizeMaxSize(image_size, fill=fill_color) + ] + else: + transforms = [ + Resize(image_size, interpolation=InterpolationMode.BICUBIC), + CenterCrop(image_size), + ] + transforms.extend([ + _convert_to_rgb, + ToTensor(), + normalize, + ]) + return Compose(transforms) diff --git a/models/eva_clip/transformer.py b/models/eva_clip/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..33e89ff7aa8ff60ae65dcfc5d21cf9af4d214510 --- /dev/null +++ b/models/eva_clip/transformer.py @@ -0,0 +1,737 @@ +import os +import logging +from collections import OrderedDict +import math +from typing import Callable, Optional, Sequence +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +try: + from timm.models.layers import trunc_normal_ +except: + from timm.layers import trunc_normal_ + +from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast +from .utils import to_2tuple + +if os.getenv('ENV_TYPE') == 'deepspeed': + try: + import deepspeed + from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint + except: + print("Please 'pip install deepspeed'") + deepspeed = None + from torch.utils.checkpoint import checkpoint +else: + from torch.utils.checkpoint import checkpoint + +try: + import xformers.ops as xops +except ImportError: + xops = None + print("Please 'pip install xformers'") + +class LayerNormFp32(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x: torch.Tensor): + output = F.layer_norm( + x.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(x) + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm (with cast back to input dtype).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1. + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}") + + def forward(self, x): + if not self.training or self.prob == 0.: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + if self.training and os.getenv('RoPE') == '1': + return x, patch_indices_keep + + return x + + +def _in_projection_packed( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + b: Optional[torch.Tensor] = None, + ): + """ + https://github.com/pytorch/pytorch/blob/db2a237763eb8693a20788be94f8c192e762baa8/torch/nn/functional.py#L4726 + """ + E = q.size(-1) + if k is v: + if q is k: + # self-attention + return F.linear(q, w, b).chunk(3, dim=-1) + else: + # encoder-decoder attention + w_q, w_kv = w.split([E, E * 2]) + if b is None: + b_q = b_kv = None + else: + b_q, b_kv = b.split([E, E * 2]) + return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1) + else: + w_q, w_k, w_v = w.chunk(3) + if b is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = b.chunk(3) + return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v) + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + scaled_cosine=False, + scale_heads=False, + logit_scale_max=math.log(1. / 0.01), + attn_drop=0., + proj_drop=0., + xattn=False, + rope=False + ): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None + + if self.scaled_cosine: + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) + self.out_drop = nn.Dropout(proj_drop) + self.xattn = xattn + self.xattn_drop = attn_drop + self.rope = rope + + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + L, N, C = x.shape + q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) + if self.xattn: + q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1) + + x = xops.memory_efficient_attention( + q, k, v, + p=self.xattn_drop, + scale=self.scale if self.logit_scale is None else None, + attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None, + ) + else: + q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + + if self.logit_scale is not None: + attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(N, self.num_heads, L, L) * logit_scale + attn = attn.view(-1, L, L) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = torch.bmm(attn, v) + + if self.head_scale is not None: + x = x.view(N, self.num_heads, L, C) * self.head_scale + x = x.view(-1, L, C) + x = x.transpose(0, 1).reshape(L, N, C) + x = self.out_proj(x) + x = self.out_drop(x) + return x + +class CustomAttention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + scaled_cosine=True, + scale_heads=False, + logit_scale_max=math.log(1. / 0.01), + attn_drop=0., + proj_drop=0., + xattn=False + ): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None + + if self.scaled_cosine: + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) + self.out_drop = nn.Dropout(proj_drop) + self.xattn = xattn + self.xattn_drop = attn_drop + + def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias) + N_q, B_q, C_q = q.shape + N_k, B_k, C_k = k.shape + N_v, B_v, C_v = v.shape + if self.xattn: + # B, N, C -> B, N, num_heads, C + q = q.permute(1, 0, 2).reshape(B_q, N_q, self.num_heads, -1) + k = k.permute(1, 0, 2).reshape(B_k, N_k, self.num_heads, -1) + v = v.permute(1, 0, 2).reshape(B_v, N_v, self.num_heads, -1) + + x = xops.memory_efficient_attention( + q, k, v, + p=self.xattn_drop, + scale=self.scale if self.logit_scale is None else None, + attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None + ) + else: + # B*H, L, C + q = q.contiguous().view(N_q, B_q * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(N_k, B_k * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(N_v, B_v * self.num_heads, -1).transpose(0, 1) + + if self.logit_scale is not None: + # B*H, N_q, N_k + attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(B_q, self.num_heads, N_q, N_k) * logit_scale + attn = attn.view(-1, N_q, N_k) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = torch.bmm(attn, v) + + if self.head_scale is not None: + x = x.view(B_q, self.num_heads, N_q, C_q) * self.head_scale + x = x.view(-1, N_q, C_q) + x = x.transpose(0, 1).reshape(N_q, B_q, C_q) + x = self.out_proj(x) + x = self.out_drop(x) + return x + +class CustomResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + scale_cosine_attn: bool = False, + scale_heads: bool = False, + scale_attn: bool = False, + scale_fc: bool = False, + cross_attn: bool = False, + xattn: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.ln_1_k = norm_layer(d_model) if cross_attn else self.ln_1 + self.ln_1_v = norm_layer(d_model) if cross_attn else self.ln_1 + self.attn = CustomAttention( + d_model, n_head, + qkv_bias=True, + attn_drop=0., + proj_drop=0., + scaled_cosine=scale_cosine_attn, + scale_heads=scale_heads, + xattn=xattn + ) + + self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + q = q + self.ls_1(self.ln_attn(self.attn(self.ln_1(q), self.ln_1_k(k), self.ln_1_v(v), attn_mask=attn_mask))) + q = q + self.ls_2(self.mlp(self.ln_2(q))) + return q + +class CustomTransformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + scale_cosine_attn: bool = True, + scale_heads: bool = False, + scale_attn: bool = False, + scale_fc: bool = False, + cross_attn: bool = False, + xattn: bool = False, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + self.xattn = xattn + + self.resblocks = nn.ModuleList([ + CustomResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + scale_cosine_attn=scale_cosine_attn, + scale_heads=scale_heads, + scale_attn=scale_attn, + scale_fc=scale_fc, + cross_attn=cross_attn, + xattn=xattn) + for _ in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + + def forward(self, q: torch.Tensor, k: torch.Tensor = None, v: torch.Tensor = None, attn_mask: Optional[torch.Tensor] = None): + if k is None and v is None: + k = v = q + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + q = checkpoint(r, q, k, v, attn_mask) + else: + q = r(q, k, v, attn_mask=attn_mask) + return q + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + xattn: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + if xattn: + self.attn = Attention(d_model, n_head, xattn=True) + else: + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + self.xattn = xattn + + def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None + if self.xattn: + return self.attn(x, attn_mask=attn_mask) + return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask)) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + +class Transformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + xattn: bool = False, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn) + for _ in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + +class VisionTransformer(nn.Module): + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + ls_init_value: float = None, + patch_dropout: float = 0., + global_average_pool: bool = False, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + xattn: bool = False, + ): + super().__init__() + self.image_size = to_2tuple(image_size) + self.patch_size = to_2tuple(patch_size) + self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1]) + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() + self.ln_pre = norm_layer(width) + + self.transformer = Transformer( + width, + layers, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + xattn=xattn + ) + + self.global_average_pool = global_average_pool + self.ln_post = norm_layer(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + for param in self.parameters(): + param.requires_grad = False + + if unlocked_groups != 0: + groups = [ + [ + self.conv1, + self.class_embedding, + self.positional_embedding, + self.ln_pre, + ], + *self.transformer.resblocks[:-1], + [ + self.transformer.resblocks[-1], + self.ln_post, + ], + self.proj, + ] + + def _unlock(x): + if isinstance(x, Sequence): + for g in x: + _unlock(g) + else: + if isinstance(x, torch.nn.Parameter): + x.requires_grad = True + else: + for p in x.parameters(): + p.requires_grad = True + + _unlock(groups[-unlocked_groups:]) + + def get_num_layers(self): + return self.transformer.layers + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + return {'positional_embedding', 'class_embedding'} + + def forward(self, x: torch.Tensor, return_all_features: bool=False): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat( + [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + + # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in + x = self.patch_dropout(x) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + if not return_all_features: + if self.global_average_pool: + x = x.mean(dim=1) #x = x[:,1:,:].mean(dim=1) + else: + x = x[:, 0] + + x = self.ln_post(x) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class TextTransformer(nn.Module): + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + ls_init_value: float = None, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + xattn: bool= False, + attn_mask: bool = True + ): + super().__init__() + self.context_length = context_length + self.vocab_size = vocab_size + self.width = width + self.output_dim = output_dim + + self.token_embedding = nn.Embedding(vocab_size, width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width)) + self.transformer = Transformer( + width=width, + layers=layers, + heads=heads, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + xattn=xattn + ) + + self.xattn = xattn + self.ln_final = norm_layer(width) + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + if attn_mask: + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + else: + self.attn_mask = None + + self.init_parameters() + + def init_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + # return {'positional_embedding', 'token_embedding'} + return {'positional_embedding'} + + def get_num_layers(self): + return self.transformer.layers + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, text, return_all_features: bool=False): + cast_dtype = self.transformer.get_cast_dtype() + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + # x = self.transformer(x) # no attention mask is applied + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + if not return_all_features: + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + return x diff --git a/models/eva_clip/utils.py b/models/eva_clip/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bdc5a7a451fdf8911ebbc816afbd2664ff348836 --- /dev/null +++ b/models/eva_clip/utils.py @@ -0,0 +1,326 @@ +from itertools import repeat +import collections.abc +import logging +import math +import numpy as np + +import torch +from torch import nn as nn +from torchvision.ops.misc import FrozenBatchNorm2d +import torch.nn.functional as F + +# open CLIP +def resize_clip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): + # Rescale the grid of position embeddings when loading from state_dict + old_pos_embed = state_dict.get('visual.positional_embedding', None) + if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): + return + grid_size = to_2tuple(model.visual.grid_size) + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + return + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) + + logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode=interpolation, + align_corners=True, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + state_dict['visual.positional_embedding'] = new_pos_embed + + +def resize_visual_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): + # Rescale the grid of position embeddings when loading from state_dict + old_pos_embed = state_dict.get('positional_embedding', None) + if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): + return + grid_size = to_2tuple(model.visual.grid_size) + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + return + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) + + logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode=interpolation, + align_corners=True, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + state_dict['positional_embedding'] = new_pos_embed + +def resize_evaclip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): + all_keys = list(state_dict.keys()) + # interpolate position embedding + if 'visual.pos_embed' in state_dict: + pos_embed_checkpoint = state_dict['visual.pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.visual.patch_embed.num_patches + num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + state_dict['visual.pos_embed'] = new_pos_embed + + patch_embed_proj = state_dict['visual.patch_embed.proj.weight'] + patch_size = model.visual.patch_embed.patch_size + state_dict['visual.patch_embed.proj.weight'] = torch.nn.functional.interpolate( + patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False) + + +def resize_eva_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): + all_keys = list(state_dict.keys()) + # interpolate position embedding + if 'pos_embed' in state_dict: + pos_embed_checkpoint = state_dict['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.visual.patch_embed.num_patches + num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + state_dict['pos_embed'] = new_pos_embed + + patch_embed_proj = state_dict['patch_embed.proj.weight'] + patch_size = model.visual.patch_embed.patch_size + state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate( + patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False) + + +def resize_rel_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): + all_keys = list(state_dict.keys()) + for key in all_keys: + if "relative_position_index" in key: + state_dict.pop(key) + + if "relative_position_bias_table" in key: + rel_pos_bias = state_dict[key] + src_num_pos, num_attn_heads = rel_pos_bias.size() + dst_num_pos, _ = model.visual.state_dict()[key].size() + dst_patch_shape = model.visual.patch_embed.patch_shape + if dst_patch_shape[0] != dst_patch_shape[1]: + raise NotImplementedError() + num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1) + src_size = int((src_num_pos - num_extra_tokens) ** 0.5) + dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) + if src_size != dst_size: + print("Position interpolate for %s from %dx%d to %dx%d" % ( + key, src_size, src_size, dst_size, dst_size)) + extra_tokens = rel_pos_bias[-num_extra_tokens:, :] + rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] + + def geometric_progression(a, r, n): + return a * (1.0 - r ** n) / (1.0 - r) + + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src_size // 2) + if gp > dst_size // 2: + right = q + else: + left = q + + # if q > 1.090307: + # q = 1.090307 + + dis = [] + cur = 1 + for i in range(src_size // 2): + dis.append(cur) + cur += q ** (i + 1) + + r_ids = [-_ for _ in reversed(dis)] + + x = r_ids + [0] + dis + y = r_ids + [0] + dis + + t = dst_size // 2.0 + dx = np.arange(-t, t + 0.1, 1.0) + dy = np.arange(-t, t + 0.1, 1.0) + + print("Original positions = %s" % str(x)) + print("Target positions = %s" % str(dx)) + + all_rel_pos_bias = [] + + for i in range(num_attn_heads): + z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() + f = F.interpolate.interp2d(x, y, z, kind='cubic') + all_rel_pos_bias.append( + torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device)) + + rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) + + new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) + state_dict[key] = new_rel_pos_bias + + # interpolate position embedding + if 'pos_embed' in state_dict: + pos_embed_checkpoint = state_dict['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.visual.patch_embed.num_patches + num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + state_dict['pos_embed'] = new_pos_embed + + patch_embed_proj = state_dict['patch_embed.proj.weight'] + patch_size = model.visual.patch_embed.patch_size + state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate( + patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False) + + +def freeze_batch_norm_2d(module, module_match={}, name=''): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + module_match (dict): Dictionary of full module names to freeze (all if empty) + name (str): Full module name (prefix) + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + is_match = True + if module_match: + is_match = name in module_match + if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for child_name, child in module.named_children(): + full_child_name = '.'.join([name, child_name]) if name else child_name + new_child = freeze_batch_norm_2d(child, module_match, full_child_name) + if new_child is not child: + res.add_module(child_name, new_child) + return res + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = lambda n, x: _ntuple(n)(x) + + +def is_logging(args): + def is_global_master(args): + return args.rank == 0 + + def is_local_master(args): + return args.local_rank == 0 + + def is_master(args, local=False): + return is_local_master(args) if local else is_global_master(args) + return is_master + + +class AllGather(torch.autograd.Function): + """An autograd function that performs allgather on a tensor. + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + + @staticmethod + def forward(ctx, tensor, rank, world_size): + tensors_gather = [torch.empty_like(tensor) for _ in range(world_size)] + torch.distributed.all_gather(tensors_gather, tensor) + ctx.rank = rank + ctx.batch_size = tensor.shape[0] + return torch.cat(tensors_gather, 0) + + @staticmethod + def backward(ctx, grad_output): + return ( + grad_output[ctx.batch_size * ctx.rank: ctx.batch_size * (ctx.rank + 1)], + None, + None + ) + +allgather = AllGather.apply \ No newline at end of file diff --git a/models/eva_clip/utils_qformer.py b/models/eva_clip/utils_qformer.py new file mode 100644 index 0000000000000000000000000000000000000000..809767280a303b666ef98300f398877d219bc207 --- /dev/null +++ b/models/eva_clip/utils_qformer.py @@ -0,0 +1,166 @@ +import importlib +import math +import os +import random + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from torchvision.utils import make_grid +from transformers import PretrainedConfig + + +def seed_everything(seed): + os.environ["PL_GLOBAL_SEED"] = str(seed) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def is_torch2_available(): + return hasattr(F, "scaled_dot_product_attention") + + +def instantiate_from_config(config): + if "target" not in config: + if config == '__is_first_stage__' or config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", {})) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def drop_seq_token(seq, drop_rate=0.5): + idx = torch.randperm(seq.size(1)) + num_keep_tokens = int(len(idx) * (1 - drop_rate)) + idx = idx[:num_keep_tokens] + seq = seq[:, idx] + return seq + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": # noqa RET505 + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + +def resize_numpy_image_long(image, resize_long_edge=768): + h, w = image.shape[:2] + if max(h, w) <= resize_long_edge: + return image + k = resize_long_edge / max(h, w) + h = int(h * k) + w = int(w * k) + image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4) + return image + + +# from basicsr +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + return _totensor(imgs, bgr2rgb, float32) + + +def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): + """Convert torch Tensors into image numpy arrays. + + After clamping to [min, max], values will be normalized to [0, 1]. + + Args: + tensor (Tensor or list[Tensor]): Accept shapes: + 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); + 2) 3D Tensor of shape (3/1 x H x W); + 3) 2D Tensor of shape (H x W). + Tensor channel should be in RGB order. + rgb2bgr (bool): Whether to change rgb to bgr. + out_type (numpy type): output types. If ``np.uint8``, transform outputs + to uint8 type with range [0, 255]; otherwise, float type with + range [0, 1]. Default: ``np.uint8``. + min_max (tuple[int]): min and max values for clamp. + + Returns: + (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of + shape (H x W). The channel order is BGR. + """ + if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): + raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') + + if torch.is_tensor(tensor): + tensor = [tensor] + result = [] + for _tensor in tensor: + _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) + + n_dim = _tensor.dim() + if n_dim == 4: + img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() + img_np = img_np.transpose(1, 2, 0) + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 3: + img_np = _tensor.numpy() + img_np = img_np.transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image + img_np = np.squeeze(img_np, axis=2) + else: + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 2: + img_np = _tensor.numpy() + else: + raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}') + if out_type == np.uint8: + # Unlike MATLAB, numpy.unit8() WILL NOT round by default. + img_np = (img_np * 255.0).round() + img_np = img_np.astype(out_type) + result.append(img_np) + if len(result) == 1: + result = result[0] + return result diff --git a/models/local_facial_extractor.py b/models/local_facial_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..a2755b510560d53d28f102d763826db80d9f5f26 --- /dev/null +++ b/models/local_facial_extractor.py @@ -0,0 +1,269 @@ +import math +import torch +import torch.nn as nn + + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8, kv_dim=None): + super().__init__() + self.scale = dim_head ** -0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, seq_len, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1) + + return self.to_out(out) + + +class LocalFacialExtractor(nn.Module): + def __init__( + self, + dim=1024, + depth=10, + dim_head=64, + heads=16, + num_id_token=5, + num_queries=32, + output_dim=2048, + ff_mult=4, + ): + """ + Initializes the LocalFacialExtractor class. + + Parameters: + - dim (int): The dimensionality of latent features. + - depth (int): Total number of PerceiverAttention and FeedForward layers. + - dim_head (int): Dimensionality of each attention head. + - heads (int): Number of attention heads. + - num_id_token (int): Number of tokens used for identity features. + - num_queries (int): Number of query tokens for the latent representation. + - output_dim (int): Output dimension after projection. + - ff_mult (int): Multiplier for the feed-forward network hidden dimension. + """ + super().__init__() + + # Storing identity token and query information + self.num_id_token = num_id_token + self.dim = dim + self.num_queries = num_queries + assert depth % 5 == 0 + self.depth = depth // 5 + scale = dim ** -0.5 + + # Learnable latent query embeddings + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) * scale) + # Projection layer to map the latent output to the desired dimension + self.proj_out = nn.Parameter(scale * torch.randn(dim, output_dim)) + + # Attention and FeedForward layer stack + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), # Perceiver Attention layer + FeedForward(dim=dim, mult=ff_mult), # FeedForward layer + ] + ) + ) + + # Mappings for each of the 5 different ViT features + for i in range(5): + setattr( + self, + f'mapping_{i}', + nn.Sequential( + nn.Linear(1024, 1024), + nn.LayerNorm(1024), + nn.LeakyReLU(), + nn.Linear(1024, 1024), + nn.LayerNorm(1024), + nn.LeakyReLU(), + nn.Linear(1024, dim), + ), + ) + + # Mapping for identity embedding vectors + self.id_embedding_mapping = nn.Sequential( + nn.Linear(1280, 1024), + nn.LayerNorm(1024), + nn.LeakyReLU(), + nn.Linear(1024, 1024), + nn.LayerNorm(1024), + nn.LeakyReLU(), + nn.Linear(1024, dim * num_id_token), + ) + + def forward(self, x, y): + """ + Forward pass for LocalFacialExtractor. + + Parameters: + - x (Tensor): The input identity embedding tensor of shape (batch_size, 1280). + - y (list of Tensor): A list of 5 visual feature tensors each of shape (batch_size, 1024). + + Returns: + - Tensor: The extracted latent features of shape (batch_size, num_queries, output_dim). + """ + + # Repeat latent queries for the batch size + latents = self.latents.repeat(x.size(0), 1, 1) + + # Map the identity embedding to tokens + x = self.id_embedding_mapping(x) + x = x.reshape(-1, self.num_id_token, self.dim) + + # Concatenate identity tokens with the latent queries + latents = torch.cat((latents, x), dim=1) + + # Process each of the 5 visual feature inputs + for i in range(5): + vit_feature = getattr(self, f'mapping_{i}')(y[i]) + ctx_feature = torch.cat((x, vit_feature), dim=1) + + # Pass through the PerceiverAttention and FeedForward layers + for attn, ff in self.layers[i * self.depth: (i + 1) * self.depth]: + latents = attn(ctx_feature, latents) + latents + latents = ff(latents) + latents + + # Retain only the query latents + latents = latents[:, :self.num_queries] + # Project the latents to the output dimension + latents = latents @ self.proj_out + return latents + + +class PerceiverCrossAttention(nn.Module): + """ + + Args: + dim (int): Dimension of the input latent and output. Default is 3072. + dim_head (int): Dimension of each attention head. Default is 128. + heads (int): Number of attention heads. Default is 16. + kv_dim (int): Dimension of the key/value input, allowing flexible cross-attention. Default is 2048. + + Attributes: + scale (float): Scaling factor used in dot-product attention for numerical stability. + norm1 (nn.LayerNorm): Layer normalization applied to the input image features. + norm2 (nn.LayerNorm): Layer normalization applied to the latent features. + to_q (nn.Linear): Linear layer for projecting the latent features into queries. + to_kv (nn.Linear): Linear layer for projecting the input features into keys and values. + to_out (nn.Linear): Linear layer for outputting the final result after attention. + + """ + def __init__(self, *, dim=3072, dim_head=128, heads=16, kv_dim=2048): + super().__init__() + self.scale = dim_head ** -0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + # Layer normalization to stabilize training + self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim) + self.norm2 = nn.LayerNorm(dim) + + # Linear transformations to produce queries, keys, and values + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + + Args: + x (torch.Tensor): Input image features with shape (batch_size, n1, D), where: + - batch_size (b): Number of samples in the batch. + - n1: Sequence length (e.g., number of patches or tokens). + - D: Feature dimension. + + latents (torch.Tensor): Latent feature representations with shape (batch_size, n2, D), where: + - n2: Number of latent elements. + + Returns: + torch.Tensor: Attention-modulated features with shape (batch_size, n2, D). + + """ + # Apply layer normalization to the input image and latent features + x = self.norm1(x) + latents = self.norm2(latents) + + b, seq_len, _ = latents.shape + + # Compute queries, keys, and values + q = self.to_q(latents) + k, v = self.to_kv(x).chunk(2, dim=-1) + + # Reshape tensors to split into attention heads + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # Compute attention weights + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable scaling than post-division + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + + # Compute the output via weighted combination of values + out = weight @ v + + # Reshape and permute to prepare for final linear transformation + out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1) + + return self.to_out(out) \ No newline at end of file diff --git a/models/pipeline_cogvideox.py b/models/pipeline_cogvideox.py new file mode 100644 index 0000000000000000000000000000000000000000..5cd27c7dbf5df2bcdf7738e99889888a63d32a44 --- /dev/null +++ b/models/pipeline_cogvideox.py @@ -0,0 +1,748 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.loaders import CogVideoXLoraLoaderMixin +from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from diffusers.models.embeddings import get_3d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from diffusers.utils import logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogVideoXPipeline + >>> from diffusers.utils import export_to_video + + >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b" + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda") + >>> prompt = ( + ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + ... "atmosphere of this unique musical performance." + ... ) + >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + self.vae_scaling_factor_image = ( + self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae_scaling_factor_image * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + id_vit_hidden: Optional[torch.Tensor] = None, + id_cond: Optional[torch.Tensor] = None, + ) -> Union[CogVideoXPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The width in pixels of the generated image. This is set to 720 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if num_frames > 49: + raise ValueError( + "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + id_vit_hidden = id_vit_hidden, + id_cond = id_cond, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) diff --git a/models/pipeline_consisid.py b/models/pipeline_consisid.py new file mode 100644 index 0000000000000000000000000000000000000000..67680724d8cd5eb721e7eb4aa9f0ce7bdbd69949 --- /dev/null +++ b/models/pipeline_consisid.py @@ -0,0 +1,894 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import inspect +import math +from typing import Callable, Dict, List, Optional, Tuple, Union + +import os +import sys +import PIL +import numpy as np +import cv2 +from PIL import Image +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import PipelineImageInput +from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from diffusers.models.embeddings import get_3d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from diffusers.utils import logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput + +from models.transformer_consisid import ConsisIDTransformer3DModel + +current_file_path = os.path.abspath(__file__) +project_roots = [os.path.dirname(os.path.dirname(current_file_path))] +for project_root in project_roots: + sys.path.insert(0, project_root) if project_root not in sys.path else None + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import CogVideoXImageToVideoPipeline + >>> from diffusers.utils import export_to_video, load_image + + >>> pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" + ... ) + >>> video = pipe(image, prompt, use_dynamic_cfg=True) + >>> export_to_video(video.frames[0], "output.mp4", fps=8) + ``` +""" + +def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]): + stickwidth = 4 + limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]]) + kps = np.array(kps) + + w, h = image_pil.size + out_img = np.zeros([h, w, 3]) + + for i in range(len(limbSeq)): + index = limbSeq[i] + color = color_list[index[0]] + + x = kps[index][:, 0] + y = kps[index][:, 1] + length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])) + polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color) + out_img = (out_img * 0.6).astype(np.uint8) + + for idx_kp, kp in enumerate(kps): + color = color_list[idx_kp] + x, y = kp + out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1) + + out_img_pil = Image.fromarray(out_img.astype(np.uint8)) + return out_img_pil + +def process_image(image, vae): + image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=image.device) + image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=image.dtype) + noisy_image = torch.randn_like(image) * image_noise_sigma[:, None, None, None, None] + input_image = image + noisy_image + image_latent_dist = vae.encode(input_image).latent_dist + return image_latent_dist + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class ConsisIDPipeline(DiffusionPipeline): + r""" + Pipeline for image-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`ConsisIDTransformer3DModel`]): + A text conditioned `ConsisIDTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: Union[ConsisIDTransformer3DModel, CogVideoXTransformer3DModel], + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + self.vae_scaling_factor_image = ( + self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, + image: torch.Tensor, + batch_size: int = 1, + num_channels_latents: int = 16, + num_frames: int = 13, + height: int = 60, + width: int = 90, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + kps_cond: Optional[torch.Tensor] = None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_frames, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + image = image.unsqueeze(2) # [B, C, F, H, W] + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + if kps_cond is not None: + kps_cond = kps_cond.unsqueeze(2) + kps_cond_latents = [ + retrieve_latents(self.vae.encode(kps_cond[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image] + if kps_cond is not None: + kps_cond = kps_cond.unsqueeze(2) + kps_cond_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in kps_cond] + + image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + image_latents = self.vae_scaling_factor_image * image_latents + + if kps_cond is not None: + kps_cond_latents = torch.cat(kps_cond_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + kps_cond_latents = self.vae_scaling_factor_image * kps_cond_latents + + padding_shape = ( + batch_size, + num_frames - 2, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + else: + padding_shape = ( + batch_size, + num_frames - 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype) + if kps_cond is not None: + image_latents = torch.cat([image_latents, kps_cond_latents, latent_padding], dim=1) + else: + image_latents = torch.cat([image_latents, latent_padding], dim=1) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents, image_latents + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae_scaling_factor_image * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, timesteps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + image, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + latents=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.unfuse_qkv_projections + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + id_vit_hidden: Optional[torch.Tensor] = None, + id_cond: Optional[torch.Tensor] = None, + kps_cond: Optional[torch.Tensor] = None, + ) -> Union[CogVideoXPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The width in pixels of the generated image. This is set to 720 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + if num_frames > 49: + raise ValueError( + "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + image=image, + prompt=prompt, + height=height, + width=width, + negative_prompt=negative_prompt, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents + if kps_cond is not None: + kps_cond = draw_kps(image, kps_cond) + kps_cond = self.video_processor.preprocess(kps_cond, height=height, width=width).to( + device, dtype=prompt_embeds.dtype + ) + + image = self.video_processor.preprocess(image, height=height, width=width).to( + device, dtype=prompt_embeds.dtype + ) + + latent_channels = self.transformer.config.in_channels // 2 + latents, image_latents = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + kps_cond + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents + latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + return_dict=False, + id_vit_hidden = id_vit_hidden, + id_cond = id_cond, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) diff --git a/models/transformer_consisid.py b/models/transformer_consisid.py new file mode 100644 index 0000000000000000000000000000000000000000..97aa48c6f0ee6ce7adec0a6c0c9da6c119f821d4 --- /dev/null +++ b/models/transformer_consisid.py @@ -0,0 +1,697 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, Optional, Tuple, Union +import os +import sys +import json +import glob + +import torch +from torch import nn +from einops import rearrange, reduce + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin +from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.attention import Attention, FeedForward +from diffusers.models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 +from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero + +import os +import sys +current_file_path = os.path.abspath(__file__) +project_roots = [os.path.dirname(current_file_path)] +for project_root in project_roots: + sys.path.insert(0, project_root) if project_root not in sys.path else None + +from local_facial_extractor import LocalFacialExtractor, PerceiverCrossAttention + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@maybe_allow_in_graph +class CogVideoXBlock(nn.Module): + r""" + Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model. + + Parameters: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + time_embed_dim (`int`): + The number of channels in timestep embedding. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to be used in feed-forward. + attention_bias (`bool`, defaults to `False`): + Whether or not to use bias in attention projection layers. + qk_norm (`bool`, defaults to `True`): + Whether or not to use normalization after query and key projections in Attention. + norm_elementwise_affine (`bool`, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_eps (`float`, defaults to `1e-5`): + Epsilon value for normalization layers. + final_dropout (`bool` defaults to `False`): + Whether to apply a final dropout after the last feed-forward layer. + ff_inner_dim (`int`, *optional*, defaults to `None`): + Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used. + ff_bias (`bool`, defaults to `True`): + Whether or not to use bias in Feed-forward layer. + attention_out_bias (`bool`, defaults to `True`): + Whether or not to use bias in Attention output projection layer. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + time_embed_dim: int, + dropout: float = 0.0, + activation_fn: str = "gelu-approximate", + attention_bias: bool = False, + qk_norm: bool = True, + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + final_dropout: bool = True, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + + # 1. Self Attention + self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) + + self.attn1 = Attention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + qk_norm="layer_norm" if qk_norm else None, + eps=1e-6, + bias=attention_bias, + out_bias=attention_out_bias, + processor=CogVideoXAttnProcessor2_0(), + ) + + # 2. Feed Forward + self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + # norm & modulate + norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( + hidden_states, encoder_hidden_states, temb + ) + + # insert here + # pass + + # attention + attn_hidden_states, attn_encoder_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = hidden_states + gate_msa * attn_hidden_states + encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states + + # norm & modulate + norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( + hidden_states, encoder_hidden_states, temb + ) + + # feed-forward + norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) + ff_output = self.ff(norm_hidden_states) + + hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:] + encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length] + + return hidden_states, encoder_hidden_states + + +class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + """ + A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo). + + Parameters: + num_attention_heads (`int`, defaults to `30`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `64`): + The number of channels in each head. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `16`): + The number of channels in the output. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + time_embed_dim (`int`, defaults to `512`): + Output dimension of timestep embeddings. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. + num_layers (`int`, defaults to `30`): + The number of layers of Transformer blocks to use. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + attention_bias (`bool`, defaults to `True`): + Whether or not to use bias in the attention projection layers. + sample_width (`int`, defaults to `90`): + The width of the input latents. + sample_height (`int`, defaults to `60`): + The height of the input latents. + sample_frames (`int`, defaults to `49`): + The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49 + instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings, + but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with + K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1). + patch_size (`int`, defaults to `2`): + The size of the patches to use in the patch embedding layer. + temporal_compression_ratio (`int`, defaults to `4`): + The compression ratio across the temporal dimension. See documentation for `sample_frames`. + max_text_seq_length (`int`, defaults to `226`): + The maximum sequence length of the input text embeddings. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + timestep_activation_fn (`str`, defaults to `"silu"`): + Activation function to use when generating the timestep embeddings. + norm_elementwise_affine (`bool`, defaults to `True`): + Whether or not to use elementwise affine in normalization layers. + norm_eps (`float`, defaults to `1e-5`): + The epsilon value to use in normalization layers. + spatial_interpolation_scale (`float`, defaults to `1.875`): + Scaling factor to apply in 3D positional embeddings across spatial dimensions. + temporal_interpolation_scale (`float`, defaults to `1.0`): + Scaling factor to apply in 3D positional embeddings across temporal dimensions. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 30, + attention_head_dim: int = 64, + in_channels: int = 16, + out_channels: Optional[int] = 16, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + time_embed_dim: int = 512, + text_embed_dim: int = 4096, + num_layers: int = 30, + dropout: float = 0.0, + attention_bias: bool = True, + sample_width: int = 90, + sample_height: int = 60, + sample_frames: int = 49, + patch_size: int = 2, + temporal_compression_ratio: int = 4, + max_text_seq_length: int = 226, + activation_fn: str = "gelu-approximate", + timestep_activation_fn: str = "silu", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + spatial_interpolation_scale: float = 1.875, + temporal_interpolation_scale: float = 1.0, + use_rotary_positional_embeddings: bool = False, + use_learned_positional_embeddings: bool = False, + is_train_face: bool = False, + is_kps: bool = False, + cross_attn_interval: int = 1, + LFE_num_tokens: int = 32, + LFE_output_dim: int = 768, + LFE_heads: int = 12, + local_face_scale: float = 1.0, + ): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + + if not use_rotary_positional_embeddings and use_learned_positional_embeddings: + raise ValueError( + "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional " + "embeddings. If you're using a custom model and/or believe this should be supported, please open an " + "issue at https://github.com/huggingface/diffusers/issues." + ) + + # 1. Patch embedding + self.patch_embed = CogVideoXPatchEmbed( + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + text_embed_dim=text_embed_dim, + bias=True, + sample_width=sample_width, + sample_height=sample_height, + sample_frames=sample_frames, + temporal_compression_ratio=temporal_compression_ratio, + max_text_seq_length=max_text_seq_length, + spatial_interpolation_scale=spatial_interpolation_scale, + temporal_interpolation_scale=temporal_interpolation_scale, + use_positional_embeddings=not use_rotary_positional_embeddings, + use_learned_positional_embeddings=use_learned_positional_embeddings, + ) + self.embedding_dropout = nn.Dropout(dropout) + + # 2. Time embeddings + self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) + self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) + + # 3. Define spatio-temporal transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + CogVideoXBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + time_embed_dim=time_embed_dim, + dropout=dropout, + activation_fn=activation_fn, + attention_bias=attention_bias, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + ) + for _ in range(num_layers) + ] + ) + self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) + + # 4. Output blocks + self.norm_out = AdaLayerNorm( + embedding_dim=time_embed_dim, + output_dim=2 * inner_dim, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + chunk_dim=1, + ) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) + + self.gradient_checkpointing = False + + self.is_train_face = is_train_face + self.is_kps = is_kps + + if is_train_face: + self.inner_dim = inner_dim + self.cross_attn_interval = cross_attn_interval + self.num_ca = num_layers // cross_attn_interval + self.LFE_num_tokens = LFE_num_tokens + self.LFE_output_dim = LFE_output_dim + self.LFE_heads = LFE_heads + self.LFE_final_output_dim = int(self.inner_dim / 3 * 2) + self.local_face_scale = local_face_scale + self._init_face_inputs() + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + def _init_face_inputs(self): + device = self.device + weight_dtype = next(self.transformer_blocks.parameters()).dtype + self.local_facial_extractor = LocalFacialExtractor() + self.local_facial_extractor.to(device, dtype=weight_dtype) + self.perceiver_cross_attention = nn.ModuleList([ + PerceiverCrossAttention(dim=self.inner_dim, dim_head=128, heads=16, kv_dim=self.LFE_final_output_dim).to(device, dtype=weight_dtype) for _ in range(self.num_ca) + ]) + + def save_face_modules(self, path: str): + save_dict = { + 'local_facial_extractor': self.local_facial_extractor.state_dict(), + 'perceiver_cross_attention': [ca.state_dict() for ca in self.perceiver_cross_attention], + } + torch.save(save_dict, path) + + def load_face_modules(self, path: str): + checkpoint = torch.load(path, map_location=self.device) + self.local_facial_extractor.load_state_dict(checkpoint['local_facial_extractor']) + for ca, state_dict in zip(self.perceiver_cross_attention, checkpoint['perceiver_cross_attention']): + ca.load_state_dict(state_dict) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0 + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedCogVideoXAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + timestep_cond: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + id_cond: Optional[torch.Tensor] = None, + id_vit_hidden: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + # fuse clip and insightface + if self.is_train_face: + assert id_cond is not None and id_vit_hidden is not None + valid_face_emb = self.local_facial_extractor(id_cond, id_vit_hidden) # torch.Size([1, 1280]), list[5](torch.Size([1, 577, 1024])) -> torch.Size([1, 32, 2048]) + + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, num_frames, channels, height, width = hidden_states.shape + + # 1. Time embedding + timesteps = timestep + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + # 2. Patch embedding + # torch.Size([1, 226, 4096]) torch.Size([1, 13, 32, 60, 90]) + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) # torch.Size([1, 17776, 3072]) + hidden_states = self.embedding_dropout(hidden_states) # torch.Size([1, 17776, 3072]) + + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] # torch.Size([1, 226, 3072]) + hidden_states = hidden_states[:, text_seq_length:] # torch.Size([1, 17550, 3072]) + + # 3. Transformer blocks + ca_idx = 0 + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + + if self.is_train_face: + if i % self.cross_attn_interval == 0 and valid_face_emb is not None: + hidden_states = hidden_states + self.local_face_scale * self.perceiver_cross_attention[ca_idx](valid_face_emb, hidden_states) # torch.Size([2, 32, 2048]) torch.Size([2, 17550, 3072]) + ca_idx += 1 + + if not self.config.use_rotary_positional_embeddings: + # CogVideoX-2B + hidden_states = self.norm_final(hidden_states) + else: + # CogVideoX-5B + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + hidden_states = self.norm_final(hidden_states) + hidden_states = hidden_states[:, text_seq_length:] + + # 4. Final block + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # 5. Unpatchify + # Note: we use `-1` instead of `channels`: + # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels) + # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels) + p = self.config.patch_size + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + @classmethod + def from_pretrained_cus(cls, pretrained_model_path, subfolder=None, config_path=None, transformer_additional_kwargs={}): + if subfolder: + config_path = config_path or pretrained_model_path + config_file = os.path.join(config_path, subfolder, 'config.json') + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + else: + config_file = os.path.join(config_path or pretrained_model_path, 'config.json') + + print(f"Loading 3D transformer's pretrained weights from {pretrained_model_path} ...") + + # Check if config file exists + if not os.path.isfile(config_file): + raise RuntimeError(f"Configuration file '{config_file}' does not exist") + + # Load the configuration + with open(config_file, "r") as f: + config = json.load(f) + + from diffusers.utils import WEIGHTS_NAME + model = cls.from_config(config, **transformer_additional_kwargs) + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + model_file_safetensors = model_file.replace(".bin", ".safetensors") + if os.path.exists(model_file): + state_dict = torch.load(model_file, map_location="cpu") + elif os.path.exists(model_file_safetensors): + from safetensors.torch import load_file + state_dict = load_file(model_file_safetensors) + else: + from safetensors.torch import load_file + model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) + state_dict = {} + for model_file_safetensors in model_files_safetensors: + _state_dict = load_file(model_file_safetensors) + for key in _state_dict: + state_dict[key] = _state_dict[key] + + if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size(): + new_shape = model.state_dict()['patch_embed.proj.weight'].size() + if len(new_shape) == 5: + state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone() + state_dict['patch_embed.proj.weight'][:, :, :-1] = 0 + else: + if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]: + model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight'] + model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0 + state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight'] + else: + model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :] + state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight'] + + tmp_state_dict = {} + for key in state_dict: + if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size(): + tmp_state_dict[key] = state_dict[key] + else: + print(key, "Size don't match, skip") + state_dict = tmp_state_dict + + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + print(m) + + params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()] + print(f"### Mamba Parameters: {sum(params) / 1e6} M") + + params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()] + print(f"### attn1 Parameters: {sum(params) / 1e6} M") + + return model + +if __name__ == '__main__': + device = "cuda:0" + weight_dtype = torch.bfloat16 + pretrained_model_name_or_path = "BestWishYsh/ConsisID-preview" + + transformer_additional_kwargs={ + 'torch_dtype': weight_dtype, + 'revision': None, + 'variant': None, + 'is_train_face': True, + 'is_kps': False, + 'LFE_num_tokens': 32, + 'LFE_output_dim': 768, + 'LFE_heads': 12, + 'cross_attn_interval': 2, + } + + transformer = ConsisIDTransformer3DModel.from_pretrained_cus( + pretrained_model_name_or_path, + subfolder="transformer", + transformer_additional_kwargs=transformer_additional_kwargs, + ) + + transformer.to(device, dtype=weight_dtype) + for param in transformer.parameters(): + param.requires_grad = False + transformer.eval() + + b = 1 + dim = 32 + pixel_values = torch.ones(b, 49, 3, 480, 720).to(device, dtype=weight_dtype) + noisy_latents = torch.ones(b, 13, dim, 60, 90).to(device, dtype=weight_dtype) + target = torch.ones(b, 13, dim, 60, 90).to(device, dtype=weight_dtype) + latents = torch.ones(b, 13, dim, 60, 90).to(device, dtype=weight_dtype) + prompt_embeds = torch.ones(b, 226, 4096).to(device, dtype=weight_dtype) + image_rotary_emb = (torch.ones(17550, 64).to(device, dtype=weight_dtype), torch.ones(17550, 64).to(device, dtype=weight_dtype)) + timesteps = torch.tensor([311]).to(device, dtype=weight_dtype) + id_vit_hidden = [torch.ones([1, 577, 1024]).to(device, dtype=weight_dtype)] * 5 + id_cond = torch.ones(b, 1280).to(device, dtype=weight_dtype) + assert len(timesteps) == b + + model_output = transformer( + hidden_states=noisy_latents, + encoder_hidden_states=prompt_embeds, + timestep=timesteps, + image_rotary_emb=image_rotary_emb, + return_dict=False, + id_vit_hidden=id_vit_hidden if id_vit_hidden is not None else None, + id_cond=id_cond if id_cond is not None else None, + )[0] + + print(model_output) + # transformer.save_pretrained(os.path.join("./test_ckpt", "transformer")) + diff --git a/models/utils.py b/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..08260b3fe2e28c359a0e94d5fc3572e16398a1a4 --- /dev/null +++ b/models/utils.py @@ -0,0 +1,273 @@ +import cv2 +import math +import numpy as np +from PIL import Image + +import torch +from torchvision.transforms import InterpolationMode +from torchvision.transforms.functional import normalize, resize +from transformers import T5EncoderModel, T5Tokenizer +from typing import List, Optional, Tuple, Union +from diffusers.models.embeddings import get_3d_rotary_pos_embed +from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid + + +def tensor_to_pil(src_img_tensor): + img = src_img_tensor.clone().detach() + if img.dtype == torch.bfloat16: + img = img.to(torch.float32) + img = img.cpu().numpy() + img = np.transpose(img, (1, 2, 0)) + img = img.astype(np.uint8) + pil_image = Image.fromarray(img) + return pil_image + + +def _get_t5_prompt_embeds( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.") + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + +def encode_prompt( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = _get_t5_prompt_embeds( + tokenizer, + text_encoder, + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + text_input_ids=text_input_ids, + ) + return prompt_embeds + + +def compute_prompt_embeddings( + tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False +): + if requires_grad: + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompt, + num_videos_per_prompt=1, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + else: + with torch.no_grad(): + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompt, + num_videos_per_prompt=1, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + return prompt_embeds + + +def prepare_rotary_positional_embeddings( + height: int, + width: int, + num_frames: int, + vae_scale_factor_spatial: int = 8, + patch_size: int = 2, + attention_head_dim: int = 64, + device: Optional[torch.device] = None, + base_height: int = 480, + base_width: int = 720, +) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (vae_scale_factor_spatial * patch_size) + grid_width = width // (vae_scale_factor_spatial * patch_size) + base_size_width = base_width // (vae_scale_factor_spatial * patch_size) + base_size_height = base_height // (vae_scale_factor_spatial * patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + return _totensor(imgs, bgr2rgb, float32) + + +def to_gray(img): + x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3] + x = x.repeat(1, 3, 1, 1) + return x + + +def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]): + stickwidth = 4 + limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]]) + kps = np.array(kps) + + w, h = image_pil.size + out_img = np.zeros([h, w, 3]) + + for i in range(len(limbSeq)): + index = limbSeq[i] + color = color_list[index[0]] + + x = kps[index][:, 0] + y = kps[index][:, 1] + length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])) + polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color) + out_img = (out_img * 0.6).astype(np.uint8) + + for idx_kp, kp in enumerate(kps): + color = color_list[idx_kp] + x, y = kp + out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1) + + out_img_pil = Image.fromarray(out_img.astype(np.uint8)) + return out_img_pil + + +def process_face_embeddings(face_helper, clip_vision_model, handler_ante, eva_transform_mean, eva_transform_std, app, device, weight_dtype, image, original_id_image=None, is_align_face=True, cal_uncond=False): + """ + Args: + image: numpy rgb image, range [0, 255] + """ + face_helper.clean_all() + image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # (724, 502, 3) + # get antelopev2 embedding + face_info = app.get(image_bgr) + if len(face_info) > 0: + face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[ + -1 + ] # only use the maximum face + id_ante_embedding = face_info['embedding'] # (512,) + face_kps = face_info['kps'] + else: + id_ante_embedding = None + face_kps = None + + # using facexlib to detect and align face + face_helper.read_image(image_bgr) + face_helper.get_face_landmarks_5(only_center_face=True) + if face_kps is None: + face_kps = face_helper.all_landmarks_5[0] + face_helper.align_warp_face() + if len(face_helper.cropped_faces) == 0: + raise RuntimeError('facexlib align face fail') + align_face = face_helper.cropped_faces[0] # (512, 512, 3) # RGB + + # incase insightface didn't detect face + if id_ante_embedding is None: + print('fail to detect face using insightface, extract embedding on align face') + id_ante_embedding = handler_ante.get_feat(align_face) + + id_ante_embedding = torch.from_numpy(id_ante_embedding).to(device, weight_dtype) # torch.Size([512]) + if id_ante_embedding.ndim == 1: + id_ante_embedding = id_ante_embedding.unsqueeze(0) # torch.Size([1, 512]) + + # parsing + if is_align_face: + input = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0 # torch.Size([1, 3, 512, 512]) + input = input.to(device) + parsing_out = face_helper.face_parse(normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0] + parsing_out = parsing_out.argmax(dim=1, keepdim=True) # torch.Size([1, 1, 512, 512]) + bg_label = [0, 16, 18, 7, 8, 9, 14, 15] + bg = sum(parsing_out == i for i in bg_label).bool() + white_image = torch.ones_like(input) # torch.Size([1, 3, 512, 512]) + # only keep the face features + return_face_features_image = torch.where(bg, white_image, to_gray(input)) # torch.Size([1, 3, 512, 512]) + return_face_features_image_2 = torch.where(bg, white_image, input) # torch.Size([1, 3, 512, 512]) + else: + original_image_bgr = cv2.cvtColor(original_id_image, cv2.COLOR_RGB2BGR) + input = img2tensor(original_image_bgr, bgr2rgb=True).unsqueeze(0) / 255.0 # torch.Size([1, 3, 512, 512]) + input = input.to(device) + return_face_features_image = return_face_features_image_2 = input + + # transform img before sending to eva-clip-vit + face_features_image = resize(return_face_features_image, clip_vision_model.image_size, + InterpolationMode.BICUBIC) # torch.Size([1, 3, 336, 336]) + face_features_image = normalize(face_features_image, eva_transform_mean, eva_transform_std) + id_cond_vit, id_vit_hidden = clip_vision_model(face_features_image.to(weight_dtype), return_all_features=False, return_hidden=True, shuffle=False) # torch.Size([1, 768]), list(torch.Size([1, 577, 1024])) + id_cond_vit_norm = torch.norm(id_cond_vit, 2, 1, True) + id_cond_vit = torch.div(id_cond_vit, id_cond_vit_norm) + + id_cond = torch.cat([id_ante_embedding, id_cond_vit], dim=-1) # torch.Size([1, 512]), torch.Size([1, 768]) -> torch.Size([1, 1280]) + + return id_cond, id_vit_hidden, return_face_features_image_2, face_kps # torch.Size([1, 1280]), list(torch.Size([1, 577, 1024])) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..7ff5f821a4f7150cc734162555f38088db564cc6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,36 @@ +torch==2.5.1 +torchaudio==2.5.1 +torchvision==0.20.1 +xformers==0.0.28.post3 +onnx==1.17.0 +onnxruntime-gpu==1.19.2 +deepspeed==0.15.2 +accelerate==1.1.1 +diffusers==0.31.0 +transformers==4.46.3 +tokenizers==0.20.1 +peft==0.12.0 +decord==0.6.0 +sentencepiece==0.2.0 +opencv-python==4.10.0.84 +pyfacer==0.0.4 +numpy==1.26.4 +numba==0.60.0 +insightface==0.7.3 +huggingface-hub==0.26.1 +facexlib==0.3.0 +timm==1.0.9 +func_timeout==4.3.5 +tensorboard==2.17.1 +gradio==5.6.0 +spaces==0.30.4 +pillow==10.4.0 +spandrel==0.4.0 +scikit-video==1.1.11 +moviepy +wandb +imageio-ffmpeg +ftfy +Jinja2 +einops +nvitop \ No newline at end of file diff --git a/util/dataloader.py b/util/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..e701206f7c2fab4dfe1caa5f9440d4ee5816c10d --- /dev/null +++ b/util/dataloader.py @@ -0,0 +1,1010 @@ +import os +import gc +import cv2 +import json +import math +import decord +import random +import numpy as np +from PIL import Image +from tqdm import tqdm +from decord import VideoReader +from contextlib import contextmanager +from func_timeout import FunctionTimedOut +from typing import Optional, Sized, Iterator + +import torch +from torch.utils.data import Dataset, Sampler +import torch.nn.functional as F +from torchvision.transforms import ToPILImage +from torchvision import transforms +from accelerate.logging import get_logger + +logger = get_logger(__name__) + +import threading +log_lock = threading.Lock() + +def log_error_to_file(error_message, video_path): + with log_lock: + with open("error_log.txt", "a") as f: + f.write(f"Error: {error_message}\n") + f.write(f"Video Path: {video_path}\n") + f.write("-" * 50 + "\n") + +def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]): + stickwidth = 4 + limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]]) + kps = np.array(kps) + + w, h = image_pil.size + out_img = np.zeros([h, w, 3]) + + for i in range(len(limbSeq)): + index = limbSeq[i] + color = color_list[index[0]] + + x = kps[index][:, 0] + y = kps[index][:, 1] + length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])) + polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color) + out_img = (out_img * 0.6).astype(np.uint8) + + for idx_kp, kp in enumerate(kps): + color = color_list[idx_kp] + x, y = kp + out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1) + + out_img_pil = Image.fromarray(out_img.astype(np.uint8)) + return out_img_pil + +@contextmanager +def VideoReader_contextmanager(*args, **kwargs): + vr = VideoReader(*args, **kwargs) + try: + yield vr + finally: + del vr + gc.collect() + +def get_valid_segments(valid_frame, tolerance=5): + valid_positions = sorted(set(valid_frame['face']).union(set(valid_frame['head']))) + + valid_segments = [] + current_segment = [valid_positions[0]] + + for i in range(1, len(valid_positions)): + if valid_positions[i] - valid_positions[i - 1] <= tolerance: + current_segment.append(valid_positions[i]) + else: + valid_segments.append(current_segment) + current_segment = [valid_positions[i]] + + if current_segment: + valid_segments.append(current_segment) + + return valid_segments + + +def get_frame_indices_adjusted_for_face(valid_frames, n_frames): + valid_length = len(valid_frames) + if valid_length >= n_frames: + return valid_frames[:n_frames] + + additional_frames_needed = n_frames - valid_length + repeat_indices = [] + + for i in range(additional_frames_needed): + index_to_repeat = i % valid_length + repeat_indices.append(valid_frames[index_to_repeat]) + + all_indices = valid_frames + repeat_indices + all_indices.sort() + + return all_indices + + +def generate_frame_indices_for_face(n_frames, sample_stride, valid_frame, tolerance=7, skip_frames_start_percent=0.0, skip_frames_end_percent=1.0, skip_frames_start=0, skip_frames_end=0): + valid_segments = get_valid_segments(valid_frame, tolerance) + selected_segment = max(valid_segments, key=len) + + valid_length = len(selected_segment) + if skip_frames_start_percent != 0.0 or skip_frames_end_percent != 1.0: + # print("use skip frame percent") + valid_start = int(valid_length * skip_frames_start_percent) + valid_end = int(valid_length * skip_frames_end_percent) + elif skip_frames_start != 0 or skip_frames_end != 0: + # print("use skip frame") + valid_start = skip_frames_start + valid_end = valid_length - skip_frames_end + else: + # print("no use skip frame") + valid_start = 0 + valid_end = valid_length + + if valid_length <= n_frames: + return get_frame_indices_adjusted_for_face(selected_segment, n_frames), valid_length + else: + adjusted_length = valid_end - valid_start + if adjusted_length <= 0: + print(f"video_length: {valid_length}, adjusted_length: {adjusted_length}, valid_start:{valid_start}, skip_frames_end: {valid_end}") + raise ValueError("Skipping too many frames results in no frames left to sample.") + + clip_length = min(adjusted_length, (n_frames - 1) * sample_stride + 1) + start_idx_position = random.randint(valid_start, valid_end - clip_length) + start_frame = selected_segment[start_idx_position] + + selected_frames = [] + for i in range(n_frames): + next_frame = start_frame + i * sample_stride + if next_frame in selected_segment: + selected_frames.append(next_frame) + else: + break + + if len(selected_frames) < n_frames: + return get_frame_indices_adjusted_for_face(selected_frames, n_frames), len(selected_frames) + + return selected_frames, len(selected_frames) + +def frame_has_required_confidence(bbox_data, frame, ID, conf_threshold=0.88): + frame_str = str(frame) + if frame_str not in bbox_data: + return False + + frame_data = bbox_data[frame_str] + + face_conf = any( + item['confidence'] > conf_threshold and item['new_track_id'] == ID + for item in frame_data.get('face', []) + ) + + head_conf = any( + item['confidence'] > conf_threshold and item['new_track_id'] == ID + for item in frame_data.get('head', []) + ) + + return face_conf and head_conf + +def select_mask_frames_from_index(batch_frame, original_batch_frame, valid_id, corresponding_data, control_sam2_frame, + valid_frame, bbox_data, base_dir, min_distance=3, min_frames=1, max_frames=5, + mask_type='face', control_mask_type='head', dense_masks=False, + ensure_control_frame=True): + """ + Selects frames with corresponding mask images while ensuring a minimum distance constraint between frames, + and that the frames exist in both batch_frame and valid_frame. + + Parameters: + base_path (str): Base directory where the JSON files and mask results are located. + min_distance (int): Minimum distance between selected frames. + min_frames (int): Minimum number of frames to select. + max_frames (int): Maximum number of frames to select. + mask_type (str): Type of mask to select frames for ('face' or 'head'). + control_mask_type (str): Type of mask used for control frame selection ('face' or 'head'). + + Returns: + dict: A dictionary where keys are IDs and values are lists of selected mask PNG paths. + """ + # Helper function to randomly select frames with at least X frames apart + def select_frames_with_distance_constraint(frames, num_frames, min_distance, control_frame, bbox_data, ID, + ensure_control_frame=True, fallback=True): + """ + Selects frames with a minimum distance constraint. If not enough frames can be selected, a fallback plan is applied. + + Parameters: + frames (list): List of frame indices to select from. + num_frames (int): Number of frames to select. + min_distance (int): Minimum distance between selected frames. + control_frame (int): The control frame that must always be included. + fallback (bool): Whether to apply a fallback strategy if not enough frames meet the distance constraint. + + Returns: + list: List of selected frames. + """ + conf_thresholds = [0.95, 0.94, 0.93, 0.92, 0.91, 0.90] + if ensure_control_frame: + selected_frames = [control_frame] # Ensure control frame is always included + else: + valid_initial_frames = [] + for conf_threshold in conf_thresholds: + valid_initial_frames = [ + f for f in frames + if frame_has_required_confidence(bbox_data, f, ID, conf_threshold=conf_threshold) + ] + if valid_initial_frames: + break + if valid_initial_frames: + selected_frames = [random.choice(valid_initial_frames)] + else: + # If no frame meets the initial confidence, fall back to a random frame (or handle as per your preference) + selected_frames = [random.choice(frames)] + + available_frames = [f for f in frames if f != selected_frames[0]] # Exclude control frame for random selection + + random.shuffle(available_frames) # Shuffle to introduce randomness + + while available_frames and len(selected_frames) < num_frames: + last_selected_frame = selected_frames[-1] + + valid_choices = [] + for conf_threshold in conf_thresholds: + valid_choices = [ + f for f in available_frames + if abs(f - last_selected_frame) >= min_distance and + frame_has_required_confidence(bbox_data, f, ID, conf_threshold=conf_threshold) + ] + if valid_choices: + break + + if valid_choices: + frame = random.choice(valid_choices) + available_frames.remove(frame) + selected_frames.append(frame) + else: + if fallback: + # Fallback strategy: uniformly distribute remaining frames if distance constraint cannot be met + remaining_needed = num_frames - len(selected_frames) + remaining_frames = available_frames[:remaining_needed] + + # Distribute the remaining frames evenly if possible + if remaining_frames: + step = max(1, len(remaining_frames) // remaining_needed) + evenly_selected = remaining_frames[::step][:remaining_needed] + selected_frames.extend(evenly_selected) + break + else: + break # No valid choices remain and no fallback strategy is allowed + + if len(selected_frames) < num_frames: + return None + + return selected_frames + + # Convert batch_frame list to a set to remove duplicates + batch_frame_set = set(batch_frame) + + # Dictionary to store selected mask PNGs + selected_masks_dict = {} + selected_bboxs_dict = {} + dense_masks_dict = {} + selected_frames_dict = {} + + # ID + try: + mask_valid_frames = valid_frame[mask_type] # Select frames based on the specified mask type + control_valid_frames = valid_frame[control_mask_type] # Control frames for control_mask_type + except KeyError: + if mask_type not in valid_frame.keys(): + print(f"no valid {mask_type}") + if control_mask_type not in valid_frame.keys(): + print(f"no valid {control_mask_type}") + + # Get the control frame for the control mask type + control_frame = control_sam2_frame[valid_id][control_mask_type] + + # Filter frames to only those which are in both valid_frame and batch_frame_set + valid_frames = [] + # valid_frames = [frame for frame in mask_valid_frames if frame in control_valid_frames and frame in batch_frame_set] + for frame in mask_valid_frames: + if frame in control_valid_frames and frame in batch_frame_set: + # Check if bbox_data has 'head' or 'face' for the frame + if str(frame) in bbox_data: + frame_data = bbox_data[str(frame)] + if 'head' in frame_data or 'face' in frame_data: + valid_frames.append(frame) + + # Ensure the control frame is included in the valid frames + if ensure_control_frame and (control_frame not in valid_frames): + valid_frames.append(control_frame) + + # Select a random number of frames between min_frames and max_frames + num_frames_to_select = random.randint(min_frames, max_frames) + selected_frames = select_frames_with_distance_constraint(valid_frames, num_frames_to_select, min_distance, + control_frame, bbox_data, valid_id, ensure_control_frame) + + # Store the selected frames as mask PNGs and bbox + selected_masks_dict[valid_id] = [] + selected_bboxs_dict[valid_id] = [] + + # Initialize the dense_masks_dict entry for the current ID + dense_masks_dict[valid_id] = [] + + # Store the selected frames in the dictionary + selected_frames_dict[valid_id] = selected_frames + + if dense_masks: + for frame in original_batch_frame: + mask_data_path = f"{base_dir}/{valid_id}/annotated_frame_{int(frame):05d}.png" + mask_array = np.array(Image.open(mask_data_path)) + binary_mask = np.where(mask_array > 0, 1, 0).astype(np.uint8) + dense_masks_dict[valid_id].append(binary_mask) + + for frame in selected_frames: + mask_data_path = f"{base_dir}/{valid_id}/annotated_frame_{frame:05d}.png" + mask_array = np.array(Image.open(mask_data_path)) + binary_mask = np.where(mask_array > 0, 1, 0).astype(np.uint8) + selected_masks_dict[valid_id].append(binary_mask) + + try: + for item in bbox_data[f"{frame}"]["head"]: + if item['new_track_id'] == int(valid_id): + temp_bbox = item['box'] + break + except (KeyError, StopIteration): + try: + for item in bbox_data[f"{frame}"]["face"]: + if item['new_track_id'] == int(valid_id): + temp_bbox = item['box'] + break + except (KeyError, StopIteration): + temp_bbox = None + + selected_bboxs_dict[valid_id].append(temp_bbox) + + return selected_frames_dict, selected_masks_dict, selected_bboxs_dict, dense_masks_dict + +def pad_tensor(tensor, target_size, dim=0): + padding_size = target_size - tensor.size(dim) + if padding_size > 0: + pad_shape = list(tensor.shape) + pad_shape[dim] = padding_size + padding_tensor = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device) + return torch.cat([tensor, padding_tensor], dim=dim) + else: + return tensor[:target_size] + +def crop_images(selected_frame_index, selected_bboxs_dict, video_reader, return_ori=False): + """ + Crop images based on given bounding boxes and frame indices from a video. + + Args: + selected_frame_index (list): List of frame indices to be cropped. + selected_bboxs_dict (list of dict): List of dictionaries, each containing 'x1', 'y1', 'x2', 'y2' bounding box coordinates. + video_reader (VideoReader or list of numpy arrays): Video frames accessible by index, where each frame is a numpy array (H, W, C). + + Returns: + list: A list of cropped images in PIL Image format. + """ + expanded_cropped_images = [] + original_cropped_images = [] + for frame_idx, bbox in zip(selected_frame_index, selected_bboxs_dict): + # Get the specific frame from the video reader using the frame index + frame = video_reader[frame_idx] # torch.tensor # (H, W, C) + + # Extract bounding box coordinates and convert them to integers + x1, y1, x2, y2 = int(bbox['x1']), int(bbox['y1']), int(bbox['x2']), int(bbox['y2']) + # Crop to minimize the bounding box to a square + width = x2 - x1 # Calculate the width of the bounding box + height = y2 - y1 # Calculate the height of the bounding box + side_length = max(width, height) # Determine the side length of the square (max of width or height) + + # Calculate the center of the bounding box + center_x = (x1 + x2) // 2 + center_y = (y1 + y2) // 2 + + # Calculate new coordinates for the square region centered around the original bounding box + new_x1 = max(0, center_x - side_length // 2) # Ensure x1 is within image bounds + new_y1 = max(0, center_y - side_length // 2) # Ensure y1 is within image bounds + new_x2 = min(frame.shape[1], new_x1 + side_length) # Ensure x2 does not exceed image width + new_y2 = min(frame.shape[0], new_y1 + side_length) # Ensure y2 does not exceed image height + + # Adjust coordinates if the cropped area is smaller than the desired side length + # Ensure final width and height are equal, keeping it a square + actual_width = new_x2 - new_x1 + actual_height = new_y2 - new_y1 + + if actual_width < side_length: + # Adjust x1 or x2 to ensure the correct side length, while staying in bounds + if new_x1 == 0: + new_x2 = min(frame.shape[1], new_x1 + side_length) + else: + new_x1 = max(0, new_x2 - side_length) + + if actual_height < side_length: + # Adjust y1 or y2 to ensure the correct side length, while staying in bounds + if new_y1 == 0: + new_y2 = min(frame.shape[0], new_y1 + side_length) + else: + new_y1 = max(0, new_y2 - side_length) + + # Expand the square by 20% + expansion_ratio = 0.2 # Define the expansion ratio + expansion_amount = int(side_length * expansion_ratio) # Calculate the number of pixels to expand by + + # Calculate expanded coordinates, ensuring they stay within image bounds + expanded_x1 = max(0, new_x1 - expansion_amount) # Expand left, ensuring x1 is within bounds + expanded_y1 = max(0, new_y1 - expansion_amount) # Expand up, ensuring y1 is within bounds + expanded_x2 = min(frame.shape[1], new_x2 + expansion_amount) # Expand right, ensuring x2 does not exceed bounds + expanded_y2 = min(frame.shape[0], new_y2 + expansion_amount) # Expand down, ensuring y2 does not exceed bounds + + # Ensure the expanded area is still a square + expanded_width = expanded_x2 - expanded_x1 + expanded_height = expanded_y2 - expanded_y1 + final_side_length = min(expanded_width, expanded_height) + + # Adjust to ensure square shape if necessary + if expanded_width != expanded_height: + if expanded_width > expanded_height: + expanded_x2 = expanded_x1 + final_side_length + else: + expanded_y2 = expanded_y1 + final_side_length + + expanded_cropped_rgb_tensor = frame[expanded_y1:expanded_y2, expanded_x1:expanded_x2, :] + expanded_cropped_rgb = Image.fromarray(np.array(expanded_cropped_rgb_tensor)).convert('RGB') + expanded_cropped_images.append(expanded_cropped_rgb) + + if return_ori: + original_cropped_rgb_tensor = frame[new_y1:new_y2, new_x1:new_x2, :] + original_cropped_rgb = Image.fromarray(np.array(original_cropped_rgb_tensor)).convert('RGB') + original_cropped_images.append(original_cropped_rgb) + return expanded_cropped_images, original_cropped_images + + return expanded_cropped_images, None + +def process_cropped_images(expand_images_pil, original_images_pil, target_size=(480, 480)): + """ + Process a list of cropped images in PIL format. + + Parameters: + expand_images_pil (list of PIL.Image): List of cropped images in PIL format. + target_size (tuple of int): The target size for resizing images, default is (480, 480). + + Returns: + torch.Tensor: A tensor containing the processed images. + """ + expand_face_imgs = [] + original_face_imgs = [] + if len(original_images_pil) != 0: + for expand_img, original_img in zip(expand_images_pil, original_images_pil): + expand_resized_img = expand_img.resize(target_size, Image.LANCZOS) + expand_src_img = np.array(expand_resized_img) + expand_src_img = np.transpose(expand_src_img, (2, 0, 1)) + expand_src_img = torch.from_numpy(expand_src_img).unsqueeze(0).float() + expand_face_imgs.append(expand_src_img) + + original_resized_img = original_img.resize(target_size, Image.LANCZOS) + original_src_img = np.array(original_resized_img) + original_src_img = np.transpose(original_src_img, (2, 0, 1)) + original_src_img = torch.from_numpy(original_src_img).unsqueeze(0).float() + original_face_imgs.append(original_src_img) + + expand_face_imgs = torch.cat(expand_face_imgs, dim=0) + original_face_imgs = torch.cat(original_face_imgs, dim=0) + else: + for expand_img in expand_images_pil: + expand_resized_img = expand_img.resize(target_size, Image.LANCZOS) + expand_src_img = np.array(expand_resized_img) + expand_src_img = np.transpose(expand_src_img, (2, 0, 1)) + expand_src_img = torch.from_numpy(expand_src_img).unsqueeze(0).float() + expand_face_imgs.append(expand_src_img) + expand_face_imgs = torch.cat(expand_face_imgs, dim=0) + original_face_imgs = None + + return expand_face_imgs, original_face_imgs + +class RandomSampler(Sampler[int]): + r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset. + + If with replacement, then user can specify :attr:`num_samples` to draw. + + Args: + data_source (Dataset): dataset to sample from + replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False`` + num_samples (int): number of samples to draw, default=`len(dataset)`. + generator (Generator): Generator used in sampling. + """ + + data_source: Sized + replacement: bool + + def __init__(self, data_source: Sized, replacement: bool = False, + num_samples: Optional[int] = None, generator=None) -> None: + self.data_source = data_source + self.replacement = replacement + self._num_samples = num_samples + self.generator = generator + self._pos_start = 0 + + if not isinstance(self.replacement, bool): + raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}") + + if not isinstance(self.num_samples, int) or self.num_samples <= 0: + raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}") + + @property + def num_samples(self) -> int: + # dataset size might change at runtime + if self._num_samples is None: + return len(self.data_source) + return self._num_samples + + def __iter__(self) -> Iterator[int]: + n = len(self.data_source) + if self.generator is None: + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + generator = torch.Generator() + generator.manual_seed(seed) + else: + generator = self.generator + + if self.replacement: + for _ in range(self.num_samples // 32): + yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist() + yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist() + else: + for _ in range(self.num_samples // n): + xx = torch.randperm(n, generator=generator).tolist() + if self._pos_start >= n: + self._pos_start = 0 + print("xx top 10", xx[:10], self._pos_start) + for idx in range(self._pos_start, n): + yield xx[idx] + self._pos_start = (self._pos_start + 1) % n + self._pos_start = 0 + yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n] + + def __len__(self) -> int: + return self.num_samples + +class SequentialSampler(Sampler[int]): + r"""Samples elements sequentially, always in the same order. + + Args: + data_source (Dataset): dataset to sample from + """ + + data_source: Sized + + def __init__(self, data_source: Sized) -> None: + self.data_source = data_source + self._pos_start = 0 + + def __iter__(self) -> Iterator[int]: + n = len(self.data_source) + for idx in range(self._pos_start, n): + yield idx + self._pos_start = (self._pos_start + 1) % n + self._pos_start = 0 + + def __len__(self) -> int: + return len(self.data_source) + +class ConsisID_Dataset(Dataset): + def __init__( + self, + instance_data_root: Optional[str] = None, + id_token: Optional[str] = None, + height=480, + width=640, + max_num_frames=49, + sample_stride=3, + skip_frames_start_percent=0.0, + skip_frames_end_percent=1.0, + skip_frames_start=0, + skip_frames_end=0, + text_drop_ratio=-1, + is_train_face=False, + is_single_face=False, + miss_tolerance=6, + min_distance=3, + min_frames=1, + max_frames=5, + is_cross_face=False, + is_reserve_face=False, + ): + self.id_token = id_token or "" + + # ConsisID + self.skip_frames_start_percent = skip_frames_start_percent + self.skip_frames_end_percent = skip_frames_end_percent + self.skip_frames_start = skip_frames_start + self.skip_frames_end = skip_frames_end + self.is_train_face = is_train_face + self.is_single_face = is_single_face + + if is_train_face: + self.miss_tolerance = miss_tolerance + self.min_distance = min_distance + self.min_frames = min_frames + self.max_frames = max_frames + self.is_cross_face = is_cross_face + self.is_reserve_face = is_reserve_face + + # Loading annotations from files + print(f"loading annotations from {instance_data_root} ...") + with open(instance_data_root, 'r') as f: + folder_anno = [i.strip().split(',') for i in f.readlines() if len(i.strip()) > 0] + + self.instance_prompts = [] + self.instance_video_paths = [] + self.instance_annotation_base_paths = [] + for sub_root, anno, anno_base in tqdm(folder_anno): + print(anno) + self.instance_annotation_base_paths.append(anno_base) + with open(anno, 'r') as f: + sub_list = json.load(f) + for i in tqdm(sub_list): + path = os.path.join(sub_root, os.path.basename(i['path'])) + cap = i.get('cap', None) + fps = i.get('fps', 0) + duration = i.get('duration', 0) + + if fps * duration < 49.0: + continue + + self.instance_prompts.append(cap) + self.instance_video_paths.append(path) + + self.num_instance_videos = len(self.instance_video_paths) + + self.text_drop_ratio = text_drop_ratio + + # Video params + self.sample_stride = sample_stride + self.max_num_frames = max_num_frames + self.height = height + self.width = width + + def _get_frame_indices_adjusted(self, video_length, n_frames): + indices = list(range(video_length)) + additional_frames_needed = n_frames - video_length + + repeat_indices = [] + for i in range(additional_frames_needed): + index_to_repeat = i % video_length + repeat_indices.append(indices[index_to_repeat]) + + all_indices = indices + repeat_indices + all_indices.sort() + + return all_indices + + + def _generate_frame_indices(self, video_length, n_frames, sample_stride, skip_frames_start_percent=0.0, skip_frames_end_percent=1.0, skip_frames_start=0, skip_frames_end=0): + if skip_frames_start_percent != 0.0 or skip_frames_end_percent != 1.0: + print("use skip frame percent") + valid_start = int(video_length * skip_frames_start_percent) + valid_end = int(video_length * skip_frames_end_percent) + elif skip_frames_start != 0 or skip_frames_end != 0: + print("use skip frame") + valid_start = skip_frames_start + valid_end = video_length - skip_frames_end + else: + print("no use skip frame") + valid_start = 0 + valid_end = video_length + + adjusted_length = valid_end - valid_start + + if adjusted_length <= 0: + print(f"video_length: {video_length}, adjusted_length: {adjusted_length}, valid_start:{valid_start}, skip_frames_end: {valid_end}") + raise ValueError("Skipping too many frames results in no frames left to sample.") + + if video_length <= n_frames: + return self._get_frame_indices_adjusted(video_length, n_frames) + else: + # clip_length = min(video_length, (n_frames - 1) * sample_stride + 1) + # start_idx = random.randint(0, video_length - clip_length) + # frame_indices = np.linspace(start_idx, start_idx + clip_length - 1, n_frames, dtype=int).tolist() + + clip_length = min(adjusted_length, (n_frames - 1) * sample_stride + 1) + start_idx = random.randint(valid_start, valid_end - clip_length) + frame_indices = np.linspace(start_idx, start_idx + clip_length - 1, n_frames, dtype=int).tolist() + return frame_indices + + def _short_resize_and_crop(self, frames, target_width, target_height): + """ + Resize frames and crop to the specified size. + + Args: + frames (torch.Tensor): Input frames of shape [T, H, W, C]. + target_width (int): Desired width. + target_height (int): Desired height. + + Returns: + torch.Tensor: Cropped frames of shape [T, target_height, target_width, C]. + """ + T, H, W, C = frames.shape + aspect_ratio = W / H + + # Determine new dimensions ensuring they are at least target size + if aspect_ratio > target_width / target_height: + new_width = target_width + new_height = int(target_width / aspect_ratio) + if new_height < target_height: + new_height = target_height + new_width = int(target_height * aspect_ratio) + else: + new_height = target_height + new_width = int(target_height * aspect_ratio) + if new_width < target_width: + new_width = target_width + new_height = int(target_width / aspect_ratio) + + resize_transform = transforms.Resize((new_height, new_width)) + crop_transform = transforms.CenterCrop((target_height, target_width)) + + frames_tensor = frames.permute(0, 3, 1, 2) # (T, H, W, C) -> (T, C, H, W) + resized_frames = resize_transform(frames_tensor) + cropped_frames = crop_transform(resized_frames) + sample = cropped_frames.permute(0, 2, 3, 1) + + return sample + + def _resize_with_aspect_ratio(self, frames, target_width, target_height): + """ + Resize frames while maintaining the aspect ratio by padding or cropping. + + Args: + frames (torch.Tensor): Input frames of shape [T, H, W, C]. + target_width (int): Desired width. + target_height (int): Desired height. + + Returns: + torch.Tensor: Resized and padded frames of shape [T, target_height, target_width, C]. + """ + T, frame_height, frame_width, C = frames.shape + aspect_ratio = frame_width / frame_height # 1.77, 1280 720 -> 720 406 + target_aspect_ratio = target_width / target_height # 1.50, 720 480 -> + + # If the frame is wider than the target, resize based on width + if aspect_ratio > target_aspect_ratio: + new_width = target_width + new_height = int(target_width / aspect_ratio) + else: + new_height = target_height + new_width = int(target_height * aspect_ratio) + + # Resize using batch processing + frames = frames.permute(0, 3, 1, 2) # [T, C, H, W] + frames = F.interpolate(frames, size=(new_height, new_width), mode='bilinear', align_corners=False) + + # Calculate padding + pad_top = (target_height - new_height) // 2 + pad_bottom = target_height - new_height - pad_top + pad_left = (target_width - new_width) // 2 + pad_right = target_width - new_width - pad_left + + # Apply padding + frames = F.pad(frames, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0) + + frames = frames.permute(0, 2, 3, 1) # [T, H, W, C] + + return frames + + + def _save_frame(self, frame, name="1.png"): + # [H, W, C] -> [C, H, W] + img = frame + img = img.permute(2, 0, 1) + to_pil = ToPILImage() + img = to_pil(img) + img.save(name) + + + def _save_video(self, torch_frames, name="output.mp4"): + from moviepy.editor import ImageSequenceClip + frames_np = torch_frames.cpu().numpy() + if frames_np.dtype != 'uint8': + frames_np = frames_np.astype('uint8') + frames_list = [frame for frame in frames_np] + desired_fps = 24 + clip = ImageSequenceClip(frames_list, fps=desired_fps) + clip.write_videofile(name, codec="libx264") + + + def get_batch(self, idx): + decord.bridge.set_bridge("torch") + + video_dir = self.instance_video_paths[idx] + text = self.instance_prompts[idx] + + train_transforms = transforms.Compose( + [ + transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0), + ] + ) + + with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader: + video_num_frames = len(video_reader) + + if self.is_train_face: + reserve_face_imgs = None + file_base_name = os.path.basename(video_dir.replace(".mp4", "")) + + anno_base_path = self.instance_annotation_base_paths[idx] + valid_frame_path = os.path.join(anno_base_path, "track_masks_data", file_base_name, "valid_frame.json") + control_sam2_frame_path = os.path.join(anno_base_path, "track_masks_data", file_base_name, "control_sam2_frame.json") + corresponding_data_path = os.path.join(anno_base_path, "track_masks_data", file_base_name, "corresponding_data.json") + masks_data_path = os.path.join(anno_base_path, "track_masks_data", file_base_name, "tracking_mask_results") + bboxs_data_path = os.path.join(anno_base_path, "refine_bbox_jsons", f"{file_base_name}.json") + + with open(corresponding_data_path, 'r') as f: + corresponding_data = json.load(f) + + with open(control_sam2_frame_path, 'r') as f: + control_sam2_frame = json.load(f) + + with open(valid_frame_path, 'r') as f: + valid_frame = json.load(f) + + with open(bboxs_data_path, 'r') as f: + bbox_data = json.load(f) + + if self.is_single_face: + if len(corresponding_data) != 1: + raise ValueError(f"Using single face, but {idx} is multi person.") + + # get random valid id + valid_ids = [] + backup_ids = [] + for id_key, data in corresponding_data.items(): + if 'face' in data and 'head' in data: + valid_ids.append(id_key) + + valid_id = random.choice(valid_ids) if valid_ids else (random.choice(backup_ids) if backup_ids else None) + if valid_id is None: + raise ValueError("No valid ID found: both valid_ids and backup_ids are empty.") + + # get video + total_index = list(range(video_num_frames)) + batch_index, _ = generate_frame_indices_for_face(self.max_num_frames, self.sample_stride, valid_frame[valid_id], + self.miss_tolerance, self.skip_frames_start_percent, self.skip_frames_end_percent, + self.skip_frames_start, self.skip_frames_end) + + if self.is_cross_face: + remaining_batch_index_index = [i for i in total_index if i not in batch_index] + try: + selected_frame_index, selected_masks_dict, selected_bboxs_dict, dense_masks_dict = select_mask_frames_from_index( + remaining_batch_index_index, + batch_index, valid_id, + corresponding_data, control_sam2_frame, + valid_frame[valid_id], bbox_data, masks_data_path, + min_distance=self.min_distance, min_frames=self.min_frames, + max_frames=self.max_frames, dense_masks=True, + ensure_control_frame=False, + ) + except: + selected_frame_index, selected_masks_dict, selected_bboxs_dict, dense_masks_dict = select_mask_frames_from_index( + batch_index, + batch_index, valid_id, + corresponding_data, control_sam2_frame, + valid_frame[valid_id], bbox_data, masks_data_path, + min_distance=self.min_distance, min_frames=self.min_frames, + max_frames=self.max_frames, dense_masks=True, + ensure_control_frame=False, + ) + else: + selected_frame_index, selected_masks_dict, selected_bboxs_dict, dense_masks_dict = select_mask_frames_from_index( + batch_index, + batch_index, valid_id, + corresponding_data, control_sam2_frame, + valid_frame[valid_id], bbox_data, masks_data_path, + min_distance=self.min_distance, min_frames=self.min_frames, + max_frames=self.max_frames, dense_masks=True, + ensure_control_frame=True, + ) + if self.is_reserve_face: + reserve_frame_index, _, reserve_bboxs_dict, _ = select_mask_frames_from_index( + batch_index, + batch_index, valid_id, + corresponding_data, control_sam2_frame, + valid_frame[valid_id], bbox_data, masks_data_path, + min_distance=3, min_frames=4, + max_frames=4, dense_masks=False, + ensure_control_frame=False, + ) + + # get mask and aligned_face_img + selected_frame_index = selected_frame_index[valid_id] + valid_frame = valid_frame[valid_id] + selected_masks_dict = selected_masks_dict[valid_id] + selected_bboxs_dict = selected_bboxs_dict[valid_id] + dense_masks_dict = dense_masks_dict[valid_id] + + if self.is_reserve_face: + reserve_frame_index = reserve_frame_index[valid_id] + reserve_bboxs_dict = reserve_bboxs_dict[valid_id] + + selected_masks_tensor = torch.stack([torch.tensor(mask) for mask in selected_masks_dict]) + temp_dense_masks_tensor = torch.stack([torch.tensor(mask) for mask in dense_masks_dict]) + dense_masks_tensor = self._short_resize_and_crop(temp_dense_masks_tensor.unsqueeze(-1), self.width, self.height).squeeze(-1) # [T, H, W] -> [T, H, W, 1] -> [T, H, W] + + expand_images_pil, original_images_pil = crop_images(selected_frame_index, selected_bboxs_dict, video_reader, return_ori=True) + expand_face_imgs, original_face_imgs = process_cropped_images(expand_images_pil, original_images_pil, target_size=(480, 480)) + if self.is_reserve_face: + reserve_images_pil, _ = crop_images(reserve_frame_index, reserve_bboxs_dict, video_reader, return_ori=False) + reserve_face_imgs, _ = process_cropped_images(reserve_images_pil, [], target_size=(480, 480)) + + if len(expand_face_imgs) == 0 or len(original_face_imgs) == 0: + raise ValueError(f"No face detected in input image pool") + + # post process id related data + expand_face_imgs = pad_tensor(expand_face_imgs, self.max_frames, dim=0) + original_face_imgs = pad_tensor(original_face_imgs, self.max_frames, dim=0) + selected_frame_index = torch.tensor(selected_frame_index) # torch.Size(([15, 13]) [N1] + selected_frame_index = pad_tensor(selected_frame_index, self.max_frames, dim=0) + else: + batch_index = self._generate_frame_indices(video_num_frames, self.max_num_frames, self.sample_stride, + self.skip_frames_start_percent, self.skip_frames_end_percent, + self.skip_frames_start, self.skip_frames_end) + + try: + frames = video_reader.get_batch(batch_index) # torch [T, H, W, C] + frames = self._short_resize_and_crop(frames, self.width, self.height) # [T, H, W, C] + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + # Apply training transforms in batch + frames = frames.float() + frames = train_transforms(frames) + pixel_values = frames.permute(0, 3, 1, 2).contiguous() # [T, C, H, W] + del video_reader + + # Random use no text generation + if random.random() < self.text_drop_ratio: + text = '' + + if self.is_train_face: + return pixel_values, text, 'video', video_dir, expand_face_imgs, dense_masks_tensor, selected_frame_index, reserve_face_imgs, original_face_imgs + else: + return pixel_values, text, 'video', video_dir + + def __len__(self): + return self.num_instance_videos + + def __getitem__(self, idx): + sample = {} + if self.is_train_face: + pixel_values, cap, data_type, video_dir, expand_face_imgs, dense_masks_tensor, selected_frame_index, reserve_face_imgs, original_face_imgs = self.get_batch(idx) + sample["instance_prompt"] = self.id_token + cap + sample["instance_video"] = pixel_values + sample["video_path"] = video_dir + if self.is_train_face: + sample["expand_face_imgs"] = expand_face_imgs + sample["dense_masks_tensor"] = dense_masks_tensor + sample["selected_frame_index"] = selected_frame_index + if reserve_face_imgs is not None: + sample["reserve_face_imgs"] = reserve_face_imgs + if original_face_imgs is not None: + sample["original_face_imgs"] = original_face_imgs + else: + pixel_values, cap, data_type, video_dir = self.get_batch(idx) + sample["instance_prompt"] = self.id_token + cap + sample["instance_video"] = pixel_values + sample["video_path"] = video_dir + return sample + + # while True: + # sample = {} + # try: + # if self.is_train_face: + # pixel_values, cap, data_type, video_dir, expand_face_imgs, dense_masks_tensor, selected_frame_index, reserve_face_imgs, original_face_imgs = self.get_batch(idx) + # sample["instance_prompt"] = self.id_token + cap + # sample["instance_video"] = pixel_values + # sample["video_path"] = video_dir + # if self.is_train_face: + # sample["expand_face_imgs"] = expand_face_imgs + # sample["dense_masks_tensor"] = dense_masks_tensor + # sample["selected_frame_index"] = selected_frame_index + # if reserve_face_imgs is not None: + # sample["reserve_face_imgs"] = reserve_face_imgs + # if original_face_imgs is not None: + # sample["original_face_imgs"] = original_face_imgs + # else: + # pixel_values, cap, data_type, video_dir, = self.get_batch(idx) + # sample["instance_prompt"] = self.id_token + cap + # sample["instance_video"] = pixel_values + # sample["video_path"] = video_dir + # break + # except Exception as e: + # error_message = str(e) + # video_path = self.instance_video_paths[idx % len(self.instance_video_paths)] + # print(error_message, video_path) + # log_error_to_file(error_message, video_path) + # idx = random.randint(0, self.num_instance_videos - 1) + # return sample \ No newline at end of file diff --git a/util/deepspeed_configs/accelerate_config_machine_multi.yaml b/util/deepspeed_configs/accelerate_config_machine_multi.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f6d49809cffbe8e48cea2acd928bf45a4aeebc30 --- /dev/null +++ b/util/deepspeed_configs/accelerate_config_machine_multi.yaml @@ -0,0 +1,18 @@ +compute_environment: LOCAL_MACHINE +distributed_type: DEEPSPEED +deepspeed_config: + deepspeed_config_file: util/deepspeed_configs/zero_stage2_config.json + deepspeed_hostfile: util/deepspeed_configs/hostfile.txt +fsdp_config: {} +machine_rank: 0 +main_process_ip: 100.64.24.6 +main_process_port: 12343 +main_training_function: main +num_machines: 2 +num_processes: 16 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/util/deepspeed_configs/accelerate_config_machine_single.yaml b/util/deepspeed_configs/accelerate_config_machine_single.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0e45c6d1f53b70b14a14b6c777fb04f9a308d300 --- /dev/null +++ b/util/deepspeed_configs/accelerate_config_machine_single.yaml @@ -0,0 +1,13 @@ +compute_environment: LOCAL_MACHINE +distributed_type: DEEPSPEED +deepspeed_config: + deepspeed_config_file: util/deepspeed_configs/zero_stage2_config.json +fsdp_config: {} +machine_rank: 0 +main_process_ip: null +main_process_port: 12345 +main_training_function: main +num_machines: 1 +num_processes: 8 +gpu_ids: 0,1,2,3,4,5,6,7 +use_cpu: false \ No newline at end of file diff --git a/util/deepspeed_configs/hostfile.txt b/util/deepspeed_configs/hostfile.txt new file mode 100644 index 0000000000000000000000000000000000000000..d2f4725adeac3c456b95c6411116e810f7ec835f --- /dev/null +++ b/util/deepspeed_configs/hostfile.txt @@ -0,0 +1,2 @@ +node-user@100.64.24.6 slots=8 +node-user@100.64.24.3 slots=8 \ No newline at end of file diff --git a/util/deepspeed_configs/zero_stage2_config.json b/util/deepspeed_configs/zero_stage2_config.json new file mode 100644 index 0000000000000000000000000000000000000000..4a544ed4385f7aea4ab15b1f058a08af9ffe42d0 --- /dev/null +++ b/util/deepspeed_configs/zero_stage2_config.json @@ -0,0 +1,17 @@ +{ + "bf16": { + "enabled": true + }, + "train_micro_batch_size_per_gpu": "auto", + "train_batch_size": "auto", + "gradient_clipping": 1.0, + "gradient_accumulation_steps": "auto", + "dump_state": true, + "zero_optimization": { + "stage": 2, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": 5e8 + } +} \ No newline at end of file diff --git a/util/rife/IFNet.py b/util/rife/IFNet.py new file mode 100644 index 0000000000000000000000000000000000000000..7b74fbf8cfaab947d8d18e0c81a5f40532541b05 --- /dev/null +++ b/util/rife/IFNet.py @@ -0,0 +1,123 @@ +from .refine import * + + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1), + nn.PReLU(out_planes), + ) + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=True, + ), + nn.PReLU(out_planes), + ) + + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64): + super(IFBlock, self).__init__() + self.conv0 = nn.Sequential( + conv(in_planes, c // 2, 3, 2, 1), + conv(c // 2, c, 3, 2, 1), + ) + self.convblock = nn.Sequential( + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + ) + self.lastconv = nn.ConvTranspose2d(c, 5, 4, 2, 1) + + def forward(self, x, flow, scale): + if scale != 1: + x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) + if flow != None: + flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale + x = torch.cat((x, flow), 1) + x = self.conv0(x) + x = self.convblock(x) + x + tmp = self.lastconv(x) + tmp = F.interpolate(tmp, scale_factor=scale * 2, mode="bilinear", align_corners=False) + flow = tmp[:, :4] * scale * 2 + mask = tmp[:, 4:5] + return flow, mask + + +class IFNet(nn.Module): + def __init__(self): + super(IFNet, self).__init__() + self.block0 = IFBlock(6, c=240) + self.block1 = IFBlock(13 + 4, c=150) + self.block2 = IFBlock(13 + 4, c=90) + self.block_tea = IFBlock(16 + 4, c=90) + self.contextnet = Contextnet() + self.unet = Unet() + + def forward(self, x, scale=[4, 2, 1], timestep=0.5): + img0 = x[:, :3] + img1 = x[:, 3:6] + gt = x[:, 6:] # In inference time, gt is None + flow_list = [] + merged = [] + mask_list = [] + warped_img0 = img0 + warped_img1 = img1 + flow = None + loss_distill = 0 + stu = [self.block0, self.block1, self.block2] + for i in range(3): + if flow != None: + flow_d, mask_d = stu[i]( + torch.cat((img0, img1, warped_img0, warped_img1, mask), 1), flow, scale=scale[i] + ) + flow = flow + flow_d + mask = mask + mask_d + else: + flow, mask = stu[i](torch.cat((img0, img1), 1), None, scale=scale[i]) + mask_list.append(torch.sigmoid(mask)) + flow_list.append(flow) + warped_img0 = warp(img0, flow[:, :2]) + warped_img1 = warp(img1, flow[:, 2:4]) + merged_student = (warped_img0, warped_img1) + merged.append(merged_student) + if gt.shape[1] == 3: + flow_d, mask_d = self.block_tea( + torch.cat((img0, img1, warped_img0, warped_img1, mask, gt), 1), flow, scale=1 + ) + flow_teacher = flow + flow_d + warped_img0_teacher = warp(img0, flow_teacher[:, :2]) + warped_img1_teacher = warp(img1, flow_teacher[:, 2:4]) + mask_teacher = torch.sigmoid(mask + mask_d) + merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher) + else: + flow_teacher = None + merged_teacher = None + for i in range(3): + merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i]) + if gt.shape[1] == 3: + loss_mask = ( + ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01) + .float() + .detach() + ) + loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean() + c0 = self.contextnet(img0, flow[:, :2]) + c1 = self.contextnet(img1, flow[:, 2:4]) + tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) + res = tmp[:, :3] * 2 - 1 + merged[2] = torch.clamp(merged[2] + res, 0, 1) + return flow_list, mask_list[2], merged, flow_teacher, merged_teacher, loss_distill diff --git a/util/rife/IFNet_2R.py b/util/rife/IFNet_2R.py new file mode 100644 index 0000000000000000000000000000000000000000..0317b86e56c3eb793de0def592616b145ca54915 --- /dev/null +++ b/util/rife/IFNet_2R.py @@ -0,0 +1,123 @@ +from .refine_2R import * + + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1), + nn.PReLU(out_planes), + ) + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=True, + ), + nn.PReLU(out_planes), + ) + + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64): + super(IFBlock, self).__init__() + self.conv0 = nn.Sequential( + conv(in_planes, c // 2, 3, 1, 1), + conv(c // 2, c, 3, 2, 1), + ) + self.convblock = nn.Sequential( + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + ) + self.lastconv = nn.ConvTranspose2d(c, 5, 4, 2, 1) + + def forward(self, x, flow, scale): + if scale != 1: + x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) + if flow != None: + flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale + x = torch.cat((x, flow), 1) + x = self.conv0(x) + x = self.convblock(x) + x + tmp = self.lastconv(x) + tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False) + flow = tmp[:, :4] * scale + mask = tmp[:, 4:5] + return flow, mask + + +class IFNet(nn.Module): + def __init__(self): + super(IFNet, self).__init__() + self.block0 = IFBlock(6, c=240) + self.block1 = IFBlock(13 + 4, c=150) + self.block2 = IFBlock(13 + 4, c=90) + self.block_tea = IFBlock(16 + 4, c=90) + self.contextnet = Contextnet() + self.unet = Unet() + + def forward(self, x, scale=[4, 2, 1], timestep=0.5): + img0 = x[:, :3] + img1 = x[:, 3:6] + gt = x[:, 6:] # In inference time, gt is None + flow_list = [] + merged = [] + mask_list = [] + warped_img0 = img0 + warped_img1 = img1 + flow = None + loss_distill = 0 + stu = [self.block0, self.block1, self.block2] + for i in range(3): + if flow != None: + flow_d, mask_d = stu[i]( + torch.cat((img0, img1, warped_img0, warped_img1, mask), 1), flow, scale=scale[i] + ) + flow = flow + flow_d + mask = mask + mask_d + else: + flow, mask = stu[i](torch.cat((img0, img1), 1), None, scale=scale[i]) + mask_list.append(torch.sigmoid(mask)) + flow_list.append(flow) + warped_img0 = warp(img0, flow[:, :2]) + warped_img1 = warp(img1, flow[:, 2:4]) + merged_student = (warped_img0, warped_img1) + merged.append(merged_student) + if gt.shape[1] == 3: + flow_d, mask_d = self.block_tea( + torch.cat((img0, img1, warped_img0, warped_img1, mask, gt), 1), flow, scale=1 + ) + flow_teacher = flow + flow_d + warped_img0_teacher = warp(img0, flow_teacher[:, :2]) + warped_img1_teacher = warp(img1, flow_teacher[:, 2:4]) + mask_teacher = torch.sigmoid(mask + mask_d) + merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher) + else: + flow_teacher = None + merged_teacher = None + for i in range(3): + merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i]) + if gt.shape[1] == 3: + loss_mask = ( + ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01) + .float() + .detach() + ) + loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean() + c0 = self.contextnet(img0, flow[:, :2]) + c1 = self.contextnet(img1, flow[:, 2:4]) + tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) + res = tmp[:, :3] * 2 - 1 + merged[2] = torch.clamp(merged[2] + res, 0, 1) + return flow_list, mask_list[2], merged, flow_teacher, merged_teacher, loss_distill diff --git a/util/rife/IFNet_HDv3.py b/util/rife/IFNet_HDv3.py new file mode 100644 index 0000000000000000000000000000000000000000..57f8003a21fc9315275d594afa714c44cdb8942e --- /dev/null +++ b/util/rife/IFNet_HDv3.py @@ -0,0 +1,138 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .warplayer import warp + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=True, + ), + nn.PReLU(out_planes), + ) + + +def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=False, + ), + nn.BatchNorm2d(out_planes), + nn.PReLU(out_planes), + ) + + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64): + super(IFBlock, self).__init__() + self.conv0 = nn.Sequential( + conv(in_planes, c // 2, 3, 2, 1), + conv(c // 2, c, 3, 2, 1), + ) + self.convblock0 = nn.Sequential(conv(c, c), conv(c, c)) + self.convblock1 = nn.Sequential(conv(c, c), conv(c, c)) + self.convblock2 = nn.Sequential(conv(c, c), conv(c, c)) + self.convblock3 = nn.Sequential(conv(c, c), conv(c, c)) + self.conv1 = nn.Sequential( + nn.ConvTranspose2d(c, c // 2, 4, 2, 1), + nn.PReLU(c // 2), + nn.ConvTranspose2d(c // 2, 4, 4, 2, 1), + ) + self.conv2 = nn.Sequential( + nn.ConvTranspose2d(c, c // 2, 4, 2, 1), + nn.PReLU(c // 2), + nn.ConvTranspose2d(c // 2, 1, 4, 2, 1), + ) + + def forward(self, x, flow, scale=1): + x = F.interpolate( + x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False + ) + flow = ( + F.interpolate( + flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False + ) + * 1.0 + / scale + ) + feat = self.conv0(torch.cat((x, flow), 1)) + feat = self.convblock0(feat) + feat + feat = self.convblock1(feat) + feat + feat = self.convblock2(feat) + feat + feat = self.convblock3(feat) + feat + flow = self.conv1(feat) + mask = self.conv2(feat) + flow = ( + F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) + * scale + ) + mask = F.interpolate( + mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False + ) + return flow, mask + + +class IFNet(nn.Module): + def __init__(self): + super(IFNet, self).__init__() + self.block0 = IFBlock(7 + 4, c=90) + self.block1 = IFBlock(7 + 4, c=90) + self.block2 = IFBlock(7 + 4, c=90) + self.block_tea = IFBlock(10 + 4, c=90) + # self.contextnet = Contextnet() + # self.unet = Unet() + + def forward(self, x, scale_list=[4, 2, 1], training=False): + if training == False: + channel = x.shape[1] // 2 + img0 = x[:, :channel] + img1 = x[:, channel:] + flow_list = [] + merged = [] + mask_list = [] + warped_img0 = img0 + warped_img1 = img1 + flow = (x[:, :4]).detach() * 0 + mask = (x[:, :1]).detach() * 0 + loss_cons = 0 + block = [self.block0, self.block1, self.block2] + for i in range(3): + f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i]) + f1, m1 = block[i]( + torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), + torch.cat((flow[:, 2:4], flow[:, :2]), 1), + scale=scale_list[i], + ) + flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2 + mask = mask + (m0 + (-m1)) / 2 + mask_list.append(mask) + flow_list.append(flow) + warped_img0 = warp(img0, flow[:, :2]) + warped_img1 = warp(img1, flow[:, 2:4]) + merged.append((warped_img0, warped_img1)) + """ + c0 = self.contextnet(img0, flow[:, :2]) + c1 = self.contextnet(img1, flow[:, 2:4]) + tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) + res = tmp[:, 1:4] * 2 - 1 + """ + for i in range(3): + mask_list[i] = torch.sigmoid(mask_list[i]) + merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i]) + # merged[i] = torch.clamp(merged[i] + res, 0, 1) + return flow_list, mask_list[2], merged diff --git a/util/rife/IFNet_m.py b/util/rife/IFNet_m.py new file mode 100644 index 0000000000000000000000000000000000000000..b28acd3b802752d9230afca07fd3b52a52adffe4 --- /dev/null +++ b/util/rife/IFNet_m.py @@ -0,0 +1,127 @@ +from .refine import * + + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1), + nn.PReLU(out_planes), + ) + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=True, + ), + nn.PReLU(out_planes), + ) + + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64): + super(IFBlock, self).__init__() + self.conv0 = nn.Sequential( + conv(in_planes, c // 2, 3, 2, 1), + conv(c // 2, c, 3, 2, 1), + ) + self.convblock = nn.Sequential( + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + ) + self.lastconv = nn.ConvTranspose2d(c, 5, 4, 2, 1) + + def forward(self, x, flow, scale): + if scale != 1: + x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) + if flow != None: + flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale + x = torch.cat((x, flow), 1) + x = self.conv0(x) + x = self.convblock(x) + x + tmp = self.lastconv(x) + tmp = F.interpolate(tmp, scale_factor=scale * 2, mode="bilinear", align_corners=False) + flow = tmp[:, :4] * scale * 2 + mask = tmp[:, 4:5] + return flow, mask + + +class IFNet_m(nn.Module): + def __init__(self): + super(IFNet_m, self).__init__() + self.block0 = IFBlock(6 + 1, c=240) + self.block1 = IFBlock(13 + 4 + 1, c=150) + self.block2 = IFBlock(13 + 4 + 1, c=90) + self.block_tea = IFBlock(16 + 4 + 1, c=90) + self.contextnet = Contextnet() + self.unet = Unet() + + def forward(self, x, scale=[4, 2, 1], timestep=0.5, returnflow=False): + timestep = (x[:, :1].clone() * 0 + 1) * timestep + img0 = x[:, :3] + img1 = x[:, 3:6] + gt = x[:, 6:] # In inference time, gt is None + flow_list = [] + merged = [] + mask_list = [] + warped_img0 = img0 + warped_img1 = img1 + flow = None + loss_distill = 0 + stu = [self.block0, self.block1, self.block2] + for i in range(3): + if flow != None: + flow_d, mask_d = stu[i]( + torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask), 1), flow, scale=scale[i] + ) + flow = flow + flow_d + mask = mask + mask_d + else: + flow, mask = stu[i](torch.cat((img0, img1, timestep), 1), None, scale=scale[i]) + mask_list.append(torch.sigmoid(mask)) + flow_list.append(flow) + warped_img0 = warp(img0, flow[:, :2]) + warped_img1 = warp(img1, flow[:, 2:4]) + merged_student = (warped_img0, warped_img1) + merged.append(merged_student) + if gt.shape[1] == 3: + flow_d, mask_d = self.block_tea( + torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask, gt), 1), flow, scale=1 + ) + flow_teacher = flow + flow_d + warped_img0_teacher = warp(img0, flow_teacher[:, :2]) + warped_img1_teacher = warp(img1, flow_teacher[:, 2:4]) + mask_teacher = torch.sigmoid(mask + mask_d) + merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher) + else: + flow_teacher = None + merged_teacher = None + for i in range(3): + merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i]) + if gt.shape[1] == 3: + loss_mask = ( + ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01) + .float() + .detach() + ) + loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean() + if returnflow: + return flow + else: + c0 = self.contextnet(img0, flow[:, :2]) + c1 = self.contextnet(img1, flow[:, 2:4]) + tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) + res = tmp[:, :3] * 2 - 1 + merged[2] = torch.clamp(merged[2] + res, 0, 1) + return flow_list, mask_list[2], merged, flow_teacher, merged_teacher, loss_distill diff --git a/util/rife/RIFE.py b/util/rife/RIFE.py new file mode 100644 index 0000000000000000000000000000000000000000..a7039ef093c3b20f971107c7a4145b86a09782fa --- /dev/null +++ b/util/rife/RIFE.py @@ -0,0 +1,95 @@ +from torch.optim import AdamW +from torch.nn.parallel import DistributedDataParallel as DDP +from .IFNet import * +from .IFNet_m import * +from .loss import * +from .laplacian import * +from .refine import * + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class Model: + def __init__(self, local_rank=-1, arbitrary=False): + if arbitrary == True: + self.flownet = IFNet_m() + else: + self.flownet = IFNet() + self.device() + self.optimG = AdamW( + self.flownet.parameters(), lr=1e-6, weight_decay=1e-3 + ) # use large weight decay may avoid NaN loss + self.epe = EPE() + self.lap = LapLoss() + self.sobel = SOBEL() + if local_rank != -1: + self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank) + + def train(self): + self.flownet.train() + + def eval(self): + self.flownet.eval() + + def device(self): + self.flownet.to(device) + + def load_model(self, path, rank=0): + def convert(param): + return {k.replace("module.", ""): v for k, v in param.items() if "module." in k} + + if rank <= 0: + self.flownet.load_state_dict(convert(torch.load("{}/flownet.pkl".format(path)))) + + def save_model(self, path, rank=0): + if rank == 0: + torch.save(self.flownet.state_dict(), "{}/flownet.pkl".format(path)) + + def inference(self, img0, img1, scale=1, scale_list=[4, 2, 1], TTA=False, timestep=0.5): + for i in range(3): + scale_list[i] = scale_list[i] * 1.0 / scale + imgs = torch.cat((img0, img1), 1) + flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet( + imgs, scale_list, timestep=timestep + ) + if TTA == False: + return merged[2] + else: + flow2, mask2, merged2, flow_teacher2, merged_teacher2, loss_distill2 = self.flownet( + imgs.flip(2).flip(3), scale_list, timestep=timestep + ) + return (merged[2] + merged2[2].flip(2).flip(3)) / 2 + + def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): + for param_group in self.optimG.param_groups: + param_group["lr"] = learning_rate + img0 = imgs[:, :3] + img1 = imgs[:, 3:] + if training: + self.train() + else: + self.eval() + flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet( + torch.cat((imgs, gt), 1), scale=[4, 2, 1] + ) + loss_l1 = (self.lap(merged[2], gt)).mean() + loss_tea = (self.lap(merged_teacher, gt)).mean() + if training: + self.optimG.zero_grad() + loss_G = ( + loss_l1 + loss_tea + loss_distill * 0.01 + ) # when training RIFEm, the weight of loss_distill should be 0.005 or 0.002 + loss_G.backward() + self.optimG.step() + else: + flow_teacher = flow[2] + return merged[2], { + "merged_tea": merged_teacher, + "mask": mask, + "mask_tea": mask, + "flow": flow[2][:, :2], + "flow_tea": flow_teacher, + "loss_l1": loss_l1, + "loss_tea": loss_tea, + "loss_distill": loss_distill, + } diff --git a/util/rife/RIFE_HDv3.py b/util/rife/RIFE_HDv3.py new file mode 100644 index 0000000000000000000000000000000000000000..6123e31f528c6dce58efc9c1b488c7d740646e77 --- /dev/null +++ b/util/rife/RIFE_HDv3.py @@ -0,0 +1,86 @@ +import torch +import torch.nn as nn +import numpy as np +from torch.optim import AdamW +import torch.optim as optim +import itertools +from .warplayer import warp +from torch.nn.parallel import DistributedDataParallel as DDP +from .IFNet_HDv3 import * +import torch.nn.functional as F +from .loss import * + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class Model: + def __init__(self, local_rank=-1): + self.flownet = IFNet() + self.device() + self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4) + self.epe = EPE() + # self.vgg = VGGPerceptualLoss().to(device) + self.sobel = SOBEL() + if local_rank != -1: + self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank) + + def train(self): + self.flownet.train() + + def eval(self): + self.flownet.eval() + + def device(self): + self.flownet.to(device) + + def load_model(self, path, rank=0): + def convert(param): + if rank == -1: + return {k.replace("module.", ""): v for k, v in param.items() if "module." in k} + else: + return param + + if rank <= 0: + if torch.cuda.is_available(): + self.flownet.load_state_dict(convert(torch.load("{}/flownet.pkl".format(path)))) + else: + self.flownet.load_state_dict(convert(torch.load("{}/flownet.pkl".format(path), map_location="cpu"))) + + def save_model(self, path, rank=0): + if rank == 0: + torch.save(self.flownet.state_dict(), "{}/flownet.pkl".format(path)) + + def inference(self, img0, img1, scale=1.0): + imgs = torch.cat((img0, img1), 1) + scale_list = [4 / scale, 2 / scale, 1 / scale] + flow, mask, merged = self.flownet(imgs, scale_list) + return merged[2] + + def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): + for param_group in self.optimG.param_groups: + param_group["lr"] = learning_rate + img0 = imgs[:, :3] + img1 = imgs[:, 3:] + if training: + self.train() + else: + self.eval() + scale = [4, 2, 1] + flow, mask, merged = self.flownet(torch.cat((imgs, gt), 1), scale=scale, training=training) + loss_l1 = (merged[2] - gt).abs().mean() + loss_smooth = self.sobel(flow[2], flow[2] * 0).mean() + # loss_vgg = self.vgg(merged[2], gt) + if training: + self.optimG.zero_grad() + loss_G = loss_cons + loss_smooth * 0.1 + loss_G.backward() + self.optimG.step() + else: + flow_teacher = flow[2] + return merged[2], { + "mask": mask, + "flow": flow[2][:, :2], + "loss_l1": loss_l1, + "loss_cons": loss_cons, + "loss_smooth": loss_smooth, + } diff --git a/util/rife/__init__.py b/util/rife/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/util/rife/laplacian.py b/util/rife/laplacian.py new file mode 100644 index 0000000000000000000000000000000000000000..6e72e517fb954e983cc25209a05ba8d8b3f9b49d --- /dev/null +++ b/util/rife/laplacian.py @@ -0,0 +1,69 @@ +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +import torch + + +def gauss_kernel(size=5, channels=3): + kernel = torch.tensor( + [ + [1.0, 4.0, 6.0, 4.0, 1], + [4.0, 16.0, 24.0, 16.0, 4.0], + [6.0, 24.0, 36.0, 24.0, 6.0], + [4.0, 16.0, 24.0, 16.0, 4.0], + [1.0, 4.0, 6.0, 4.0, 1.0], + ] + ) + kernel /= 256.0 + kernel = kernel.repeat(channels, 1, 1, 1) + kernel = kernel.to(device) + return kernel + + +def downsample(x): + return x[:, :, ::2, ::2] + + +def upsample(x): + cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3) + cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3]) + cc = cc.permute(0, 1, 3, 2) + cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2).to(device)], dim=3) + cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2) + x_up = cc.permute(0, 1, 3, 2) + return conv_gauss(x_up, 4 * gauss_kernel(channels=x.shape[1])) + + +def conv_gauss(img, kernel): + img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode="reflect") + out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1]) + return out + + +def laplacian_pyramid(img, kernel, max_levels=3): + current = img + pyr = [] + for level in range(max_levels): + filtered = conv_gauss(current, kernel) + down = downsample(filtered) + up = upsample(down) + diff = current - up + pyr.append(diff) + current = down + return pyr + + +class LapLoss(torch.nn.Module): + def __init__(self, max_levels=5, channels=3): + super(LapLoss, self).__init__() + self.max_levels = max_levels + self.gauss_kernel = gauss_kernel(channels=channels) + + def forward(self, input, target): + pyr_input = laplacian_pyramid(img=input, kernel=self.gauss_kernel, max_levels=self.max_levels) + pyr_target = laplacian_pyramid(img=target, kernel=self.gauss_kernel, max_levels=self.max_levels) + return sum(torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target)) diff --git a/util/rife/loss.py b/util/rife/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..8ed7564ba1e3d98e2d3e4f6a461aa4cddf4150bc --- /dev/null +++ b/util/rife/loss.py @@ -0,0 +1,130 @@ +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class EPE(nn.Module): + def __init__(self): + super(EPE, self).__init__() + + def forward(self, flow, gt, loss_mask): + loss_map = (flow - gt.detach()) ** 2 + loss_map = (loss_map.sum(1, True) + 1e-6) ** 0.5 + return loss_map * loss_mask + + +class Ternary(nn.Module): + def __init__(self): + super(Ternary, self).__init__() + patch_size = 7 + out_channels = patch_size * patch_size + self.w = np.eye(out_channels).reshape((patch_size, patch_size, 1, out_channels)) + self.w = np.transpose(self.w, (3, 2, 0, 1)) + self.w = torch.tensor(self.w).float().to(device) + + def transform(self, img): + patches = F.conv2d(img, self.w, padding=3, bias=None) + transf = patches - img + transf_norm = transf / torch.sqrt(0.81 + transf**2) + return transf_norm + + def rgb2gray(self, rgb): + r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :] + gray = 0.2989 * r + 0.5870 * g + 0.1140 * b + return gray + + def hamming(self, t1, t2): + dist = (t1 - t2) ** 2 + dist_norm = torch.mean(dist / (0.1 + dist), 1, True) + return dist_norm + + def valid_mask(self, t, padding): + n, _, h, w = t.size() + inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t) + mask = F.pad(inner, [padding] * 4) + return mask + + def forward(self, img0, img1): + img0 = self.transform(self.rgb2gray(img0)) + img1 = self.transform(self.rgb2gray(img1)) + return self.hamming(img0, img1) * self.valid_mask(img0, 1) + + +class SOBEL(nn.Module): + def __init__(self): + super(SOBEL, self).__init__() + self.kernelX = torch.tensor( + [ + [1, 0, -1], + [2, 0, -2], + [1, 0, -1], + ] + ).float() + self.kernelY = self.kernelX.clone().T + self.kernelX = self.kernelX.unsqueeze(0).unsqueeze(0).to(device) + self.kernelY = self.kernelY.unsqueeze(0).unsqueeze(0).to(device) + + def forward(self, pred, gt): + N, C, H, W = pred.shape[0], pred.shape[1], pred.shape[2], pred.shape[3] + img_stack = torch.cat([pred.reshape(N * C, 1, H, W), gt.reshape(N * C, 1, H, W)], 0) + sobel_stack_x = F.conv2d(img_stack, self.kernelX, padding=1) + sobel_stack_y = F.conv2d(img_stack, self.kernelY, padding=1) + pred_X, gt_X = sobel_stack_x[: N * C], sobel_stack_x[N * C :] + pred_Y, gt_Y = sobel_stack_y[: N * C], sobel_stack_y[N * C :] + + L1X, L1Y = torch.abs(pred_X - gt_X), torch.abs(pred_Y - gt_Y) + loss = L1X + L1Y + return loss + + +class MeanShift(nn.Conv2d): + def __init__(self, data_mean, data_std, data_range=1, norm=True): + c = len(data_mean) + super(MeanShift, self).__init__(c, c, kernel_size=1) + std = torch.Tensor(data_std) + self.weight.data = torch.eye(c).view(c, c, 1, 1) + if norm: + self.weight.data.div_(std.view(c, 1, 1, 1)) + self.bias.data = -1 * data_range * torch.Tensor(data_mean) + self.bias.data.div_(std) + else: + self.weight.data.mul_(std.view(c, 1, 1, 1)) + self.bias.data = data_range * torch.Tensor(data_mean) + self.requires_grad = False + + +class VGGPerceptualLoss(torch.nn.Module): + def __init__(self, rank=0): + super(VGGPerceptualLoss, self).__init__() + blocks = [] + pretrained = True + self.vgg_pretrained_features = models.vgg19(pretrained=pretrained).features + self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X, Y, indices=None): + X = self.normalize(X) + Y = self.normalize(Y) + indices = [2, 7, 12, 21, 30] + weights = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10 / 1.5] + k = 0 + loss = 0 + for i in range(indices[-1]): + X = self.vgg_pretrained_features[i](X) + Y = self.vgg_pretrained_features[i](Y) + if (i + 1) in indices: + loss += weights[k] * (X - Y.detach()).abs().mean() * 0.1 + k += 1 + return loss + + +if __name__ == "__main__": + img0 = torch.zeros(3, 3, 256, 256).float().to(device) + img1 = torch.tensor(np.random.normal(0, 1, (3, 3, 256, 256))).float().to(device) + ternary_loss = Ternary() + print(ternary_loss(img0, img1).shape) diff --git a/util/rife/pytorch_msssim/__init__.py b/util/rife/pytorch_msssim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3e2baafc993ba96c978335553612f036eec03985 --- /dev/null +++ b/util/rife/pytorch_msssim/__init__.py @@ -0,0 +1,203 @@ +import torch +import torch.nn.functional as F +from math import exp +import numpy as np + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)]) + return gauss / gauss.sum() + + +def create_window(window_size, channel=1): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device) + window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() + return window + + +def create_window_3d(window_size, channel=1): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()) + _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t()) + window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device) + return window + + +def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): + # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). + if val_range is None: + if torch.max(img1) > 128: + max_val = 255 + else: + max_val = 1 + + if torch.min(img1) < -0.5: + min_val = -1 + else: + min_val = 0 + L = max_val - min_val + else: + L = val_range + + padd = 0 + (_, channel, height, width) = img1.size() + if window is None: + real_size = min(window_size, height, width) + window = create_window(real_size, channel=channel).to(img1.device) + + # mu1 = F.conv2d(img1, window, padding=padd, groups=channel) + # mu2 = F.conv2d(img2, window, padding=padd, groups=channel) + mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel) + mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu2_sq + sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu1_mu2 + + C1 = (0.01 * L) ** 2 + C2 = (0.03 * L) ** 2 + + v1 = 2.0 * sigma12 + C2 + v2 = sigma1_sq + sigma2_sq + C2 + cs = torch.mean(v1 / v2) # contrast sensitivity + + ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) + + if size_average: + ret = ssim_map.mean() + else: + ret = ssim_map.mean(1).mean(1).mean(1) + + if full: + return ret, cs + return ret + + +def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): + # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). + if val_range is None: + if torch.max(img1) > 128: + max_val = 255 + else: + max_val = 1 + + if torch.min(img1) < -0.5: + min_val = -1 + else: + min_val = 0 + L = max_val - min_val + else: + L = val_range + + padd = 0 + (_, _, height, width) = img1.size() + if window is None: + real_size = min(window_size, height, width) + window = create_window_3d(real_size, channel=1).to(img1.device, dtype=img1.dtype) + # Channel is set to 1 since we consider color images as volumetric images + + img1 = img1.unsqueeze(1) + img2 = img2.unsqueeze(1) + + mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1) + mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu1_sq + sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu2_sq + sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu1_mu2 + + C1 = (0.01 * L) ** 2 + C2 = (0.03 * L) ** 2 + + v1 = 2.0 * sigma12 + C2 + v2 = sigma1_sq + sigma2_sq + C2 + cs = torch.mean(v1 / v2) # contrast sensitivity + + ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) + + if size_average: + ret = ssim_map.mean() + else: + ret = ssim_map.mean(1).mean(1).mean(1) + + if full: + return ret, cs + return ret + + +def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False): + device = img1.device + weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device) + levels = weights.size()[0] + mssim = [] + mcs = [] + for _ in range(levels): + sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range) + mssim.append(sim) + mcs.append(cs) + + img1 = F.avg_pool2d(img1, (2, 2)) + img2 = F.avg_pool2d(img2, (2, 2)) + + mssim = torch.stack(mssim) + mcs = torch.stack(mcs) + + # Normalize (to avoid NaNs during training unstable models, not compliant with original definition) + if normalize: + mssim = (mssim + 1) / 2 + mcs = (mcs + 1) / 2 + + pow1 = mcs**weights + pow2 = mssim**weights + # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/ + output = torch.prod(pow1[:-1] * pow2[-1]) + return output + + +# Classes to re-use window +class SSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True, val_range=None): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.val_range = val_range + + # Assume 3 channel for SSIM + self.channel = 3 + self.window = create_window(window_size, channel=self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.dtype == img1.dtype: + window = self.window + else: + window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) + self.window = window + self.channel = channel + + _ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) + dssim = (1 - _ssim) / 2 + return dssim + + +class MSSSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True, channel=3): + super(MSSSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = channel + + def forward(self, img1, img2): + return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average) diff --git a/util/rife/refine.py b/util/rife/refine.py new file mode 100644 index 0000000000000000000000000000000000000000..2f9becbe4b897a45e5e25d77b1cea0f0a657572f --- /dev/null +++ b/util/rife/refine.py @@ -0,0 +1,107 @@ +import torch +import torch.nn as nn +from .warplayer import warp +import torch.nn.functional as F + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=True, + ), + nn.PReLU(out_planes), + ) + + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return nn.Sequential( + torch.nn.ConvTranspose2d( + in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True + ), + nn.PReLU(out_planes), + ) + + +class Conv2(nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2, self).__init__() + self.conv1 = conv(in_planes, out_planes, 3, stride, 1) + self.conv2 = conv(out_planes, out_planes, 3, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +c = 16 + + +class Contextnet(nn.Module): + def __init__(self): + super(Contextnet, self).__init__() + self.conv1 = Conv2(3, c) + self.conv2 = Conv2(c, 2 * c) + self.conv3 = Conv2(2 * c, 4 * c) + self.conv4 = Conv2(4 * c, 8 * c) + + def forward(self, x, flow): + x = self.conv1(x) + flow = ( + F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) + * 0.5 + ) + f1 = warp(x, flow) + x = self.conv2(x) + flow = ( + F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) + * 0.5 + ) + f2 = warp(x, flow) + x = self.conv3(x) + flow = ( + F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) + * 0.5 + ) + f3 = warp(x, flow) + x = self.conv4(x) + flow = ( + F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) + * 0.5 + ) + f4 = warp(x, flow) + return [f1, f2, f3, f4] + + +class Unet(nn.Module): + def __init__(self): + super(Unet, self).__init__() + self.down0 = Conv2(17, 2 * c) + self.down1 = Conv2(4 * c, 4 * c) + self.down2 = Conv2(8 * c, 8 * c) + self.down3 = Conv2(16 * c, 16 * c) + self.up0 = deconv(32 * c, 8 * c) + self.up1 = deconv(16 * c, 4 * c) + self.up2 = deconv(8 * c, 2 * c) + self.up3 = deconv(4 * c, c) + self.conv = nn.Conv2d(c, 3, 3, 1, 1) + + def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1): + s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1)) + s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1)) + s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1)) + s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1)) + x = self.up0(torch.cat((s3, c0[3], c1[3]), 1)) + x = self.up1(torch.cat((x, s2), 1)) + x = self.up2(torch.cat((x, s1), 1)) + x = self.up3(torch.cat((x, s0), 1)) + x = self.conv(x) + return torch.sigmoid(x) diff --git a/util/rife/refine_2R.py b/util/rife/refine_2R.py new file mode 100644 index 0000000000000000000000000000000000000000..c6cc2c02f1700397488e09cf5718263e945138cb --- /dev/null +++ b/util/rife/refine_2R.py @@ -0,0 +1,104 @@ +import torch +import torch.nn as nn +from .warplayer import warp +import torch.nn.functional as F + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=True, + ), + nn.PReLU(out_planes), + ) + + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return nn.Sequential( + torch.nn.ConvTranspose2d( + in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True + ), + nn.PReLU(out_planes), + ) + + +class Conv2(nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2, self).__init__() + self.conv1 = conv(in_planes, out_planes, 3, stride, 1) + self.conv2 = conv(out_planes, out_planes, 3, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +c = 16 + + +class Contextnet(nn.Module): + def __init__(self): + super(Contextnet, self).__init__() + self.conv1 = Conv2(3, c, 1) + self.conv2 = Conv2(c, 2 * c) + self.conv3 = Conv2(2 * c, 4 * c) + self.conv4 = Conv2(4 * c, 8 * c) + + def forward(self, x, flow): + x = self.conv1(x) + # flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5 + f1 = warp(x, flow) + x = self.conv2(x) + flow = ( + F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) + * 0.5 + ) + f2 = warp(x, flow) + x = self.conv3(x) + flow = ( + F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) + * 0.5 + ) + f3 = warp(x, flow) + x = self.conv4(x) + flow = ( + F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) + * 0.5 + ) + f4 = warp(x, flow) + return [f1, f2, f3, f4] + + +class Unet(nn.Module): + def __init__(self): + super(Unet, self).__init__() + self.down0 = Conv2(17, 2 * c, 1) + self.down1 = Conv2(4 * c, 4 * c) + self.down2 = Conv2(8 * c, 8 * c) + self.down3 = Conv2(16 * c, 16 * c) + self.up0 = deconv(32 * c, 8 * c) + self.up1 = deconv(16 * c, 4 * c) + self.up2 = deconv(8 * c, 2 * c) + self.up3 = deconv(4 * c, c) + self.conv = nn.Conv2d(c, 3, 3, 2, 1) + + def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1): + s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1)) + s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1)) + s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1)) + s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1)) + x = self.up0(torch.cat((s3, c0[3], c1[3]), 1)) + x = self.up1(torch.cat((x, s2), 1)) + x = self.up2(torch.cat((x, s1), 1)) + x = self.up3(torch.cat((x, s0), 1)) + x = self.conv(x) + return torch.sigmoid(x) diff --git a/util/rife/warplayer.py b/util/rife/warplayer.py new file mode 100644 index 0000000000000000000000000000000000000000..ff796e897564961845ffb654f5006ae00c09f362 --- /dev/null +++ b/util/rife/warplayer.py @@ -0,0 +1,34 @@ +import torch +import torch.nn as nn + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +backwarp_tenGrid = {} + + +def warp(tenInput, tenFlow): + k = (str(tenFlow.device), str(tenFlow.size())) + if k not in backwarp_tenGrid: + tenHorizontal = ( + torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device) + .view(1, 1, 1, tenFlow.shape[3]) + .expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) + ) + tenVertical = ( + torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device) + .view(1, 1, tenFlow.shape[2], 1) + .expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) + ) + backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1).to(device) + + tenFlow = torch.cat( + [ + tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), + tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0), + ], + 1, + ) + + g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) + return torch.nn.functional.grid_sample( + input=tenInput, grid=g, mode="bilinear", padding_mode="border", align_corners=True + ) diff --git a/util/rife_model.py b/util/rife_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5ad85e15b2747aa176b5d2c62ea59953a6db5bb3 --- /dev/null +++ b/util/rife_model.py @@ -0,0 +1,133 @@ +import torch +from diffusers.image_processor import VaeImageProcessor +from torch.nn import functional as F +import cv2 +from util.utils import * +from util.rife.pytorch_msssim import ssim_matlab +import numpy as np +import logging +import skvideo.io +from util.rife.RIFE_HDv3 import Model + +logger = logging.getLogger(__name__) +device = "cuda" if torch.cuda.is_available() else "cpu" + + +def pad_image(img, scale): + _, _, h, w = img.shape + tmp = max(32, int(32 / scale)) + ph = ((h - 1) // tmp + 1) * tmp + pw = ((w - 1) // tmp + 1) * tmp + padding = (0, 0, pw - w, ph - h) + return F.pad(img, padding) + + +def make_inference(model, I0, I1, upscale_amount, n): + middle = model.inference(I0, I1, upscale_amount) + if n == 1: + return [middle] + first_half = make_inference(model, I0, middle, upscale_amount, n=n // 2) + second_half = make_inference(model, middle, I1, upscale_amount, n=n // 2) + if n % 2: + return [*first_half, middle, *second_half] + else: + return [*first_half, *second_half] + + +@torch.inference_mode() +def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_device="cpu"): + print(f"samples dtype:{samples.dtype}") + print(f"samples shape:{samples.shape}") + output = [] + pbar = ProgressBar(samples.shape[0], desc="frame interpolating") + + for b in range(samples.shape[0]): + frame = samples[b : b + 1] + _, _, h, w = frame.shape + + I0 = samples[b : b + 1] + I1 = samples[b + 1 : b + 2] if b + 2 < samples.shape[0] else samples[-1:] + + I0 = I0.to(torch.float) + I1 = I1.to(torch.float) + + I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False) + I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False) + + ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3]) + + if ssim > 0.996: + I1 = make_inference(model, I0, I1, upscale_amount, 1) + I1 = I1[0] + frame = I1 + + tmp_output = [] + if ssim < 0.2: + for _ in range((2**exp) - 1): + tmp_output.append(I0) + else: + tmp_output = make_inference(model, I0, I1, upscale_amount, 2**exp - 1) if exp else [] + + frame = F.interpolate(frame, size=(h, w)) + output.append(frame.to(output_device)) + + for tmp_frame in tmp_output: + tmp_frame = F.interpolate(tmp_frame, size=(h, w)) + output.append(tmp_frame.to(output_device)) + + pbar.update(1) + + return output + + +def load_rife_model(model_path): + model = Model() + model.load_model(model_path, -1) + model.eval() + return model + + +# Create a generator that yields each frame, similar to cv2.VideoCapture +def frame_generator(video_capture): + while True: + ret, frame = video_capture.read() + if not ret: + break + yield frame + video_capture.release() + + +def rife_inference_with_path(model, video_path): + video_capture = cv2.VideoCapture(video_path) + tot_frame = video_capture.get(cv2.CAP_PROP_FRAME_COUNT) + pt_frame_data = [] + pt_frame = skvideo.io.vreader(video_path) + for frame in pt_frame: + pt_frame_data.append( + torch.from_numpy(np.transpose(frame, (2, 0, 1))).to("cpu", non_blocking=True).float() / 255.0 + ) + + pt_frame = torch.from_numpy(np.stack(pt_frame_data)) + pt_frame = pt_frame.to(device) + pbar = ProgressBar(tot_frame, desc="RIFE inference") + frames = ssim_interpolation_rife(model, pt_frame) + pt_image = torch.stack([frames[i].squeeze(0) for i in range(len(frames))]) + image_np = VaeImageProcessor.pt_to_numpy(pt_image) # (to [49, 512, 480, 3]) + image_pil = VaeImageProcessor.numpy_to_pil(image_np) + video_path = save_video(image_pil, fps=16) + if pbar: + pbar.update(1) + return video_path + + +def rife_inference_with_latents(model, latents): + rife_results = [] + latents = latents.to(device) + for i in range(latents.size(0)): + # [f, c, w, h] + latent = latents[i] + frames = ssim_interpolation_rife(model, latent) + pt_image = torch.stack([frames[i].squeeze(0) for i in range(len(frames))]) # (to [f, c, w, h]) + rife_results.append(pt_image) + + return torch.stack(rife_results) diff --git a/util/utils.py b/util/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..43654aac34de26aac1f0bb82017544c8f740d3b9 --- /dev/null +++ b/util/utils.py @@ -0,0 +1,699 @@ +import os +import math +import tqdm +import logging +import argparse +import itertools +import PIL.Image +import numpy as np +from PIL import Image +import safetensors.torch +from datetime import datetime +from typing import Union, List +from spandrel import ModelLoader + +import torch +import torch.nn.functional as F +from diffusers.utils import export_to_video + +logger = logging.getLogger(__file__) +def get_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script for ConsisID.") + + # ConsisID information + parser.add_argument("--train_type", choices=['t2v', 'i2v'], help="t2v or i2v") + parser.add_argument("--is_train_face", action='store_true') + parser.add_argument("--is_diff_lr", action='store_true') + parser.add_argument("--is_train_lora", action='store_true') + parser.add_argument("--is_kps", action='store_true') + parser.add_argument("--is_shuffle_data", action='store_true') + parser.add_argument("--enable_mask_loss", action='store_true') + parser.add_argument("--is_single_face", action='store_true') + parser.add_argument("--is_cross_face", action='store_true') + parser.add_argument("--is_align_face", action='store_true') + parser.add_argument("--is_reserve_face", action='store_true') + parser.add_argument("--is_accelerator_state_dict", action='store_true') + parser.add_argument("--is_validation", action='store_true') + parser.add_argument("--config_path", type=str, default=None) + parser.add_argument("--mask_path", type=str, default=None) + parser.add_argument("--pretrained_weight", type=str, default=None) + parser.add_argument("--sample_stride", type=int, default=3, help=".") + parser.add_argument("--skip_frames_start_percent", type=float, default=0.0, help=".") + parser.add_argument("--skip_frames_end_percent", type=float, default=1.0, help=".") + parser.add_argument("--miss_tolerance", type=int, default=6) + parser.add_argument("--min_distance", type=int, default=3) + parser.add_argument("--min_frames", type=int, default=1) + parser.add_argument("--max_frames", type=int, default=5) + parser.add_argument("--LFE_num_tokens", type=int, default=32) + parser.add_argument("--LFE_output_dim", type=int, default=768) + parser.add_argument("--LFE_heads", type=int, default=12) + parser.add_argument("--cross_attn_interval", type=int, default=1) + + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." + ), + ) + + # Model information + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + # Dataset information + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_root", + type=str, + default=None, + help=("A folder containing the training data."), + ) + parser.add_argument( + "--video_column", + type=str, + default="video", + help="The column of the dataset containing videos. Or, the name of the file in `--instance_data_root` folder containing the line-separated path to video data.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--instance_data_root` folder containing the line-separated instance prompts.", + ) + parser.add_argument( + "--id_token", type=str, default=None, help="Identifier token appended to the start of each prompt if provided." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + + # Validation + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.", + ) + parser.add_argument( + "--validation_images", + type=str, + default=None, + help="One or more image path(s) that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.", + ) + parser.add_argument( + "--validation_prompt_separator", + type=str, + default=":::", + help="String that separates multiple validation prompts", + ) + parser.add_argument( + "--num_validation_videos", + type=int, + default=1, + help="Number of videos that should be generated during validation per `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run validation every X epochs. Validation consists of running the prompt `args.validation_prompt` multiple times: `args.num_validation_videos`." + ), + ) + parser.add_argument( + "--low_vram", action="store_true", help="Whether enable low_vram mode." + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=6, + help="The guidance scale to use while sampling validation videos.", + ) + parser.add_argument( + "--use_dynamic_cfg", + action="store_true", + default=False, + help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.", + ) + + # Training information + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--rank", + type=int, + default=128, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--lora_alpha", + type=float, + default=128, + help=("The scaling factor to scale LoRA weight update. The actual scaling factor is `lora_alpha / rank`"), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="cogvideox-i2v-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--height", + type=int, + default=480, + help="All input videos are resized to this height.", + ) + parser.add_argument( + "--width", + type=int, + default=720, + help="All input videos are resized to this width.", + ) + parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.") + parser.add_argument( + "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames." + ) + parser.add_argument( + "--skip_frames_start", + type=int, + default=0, + help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.", + ) + parser.add_argument( + "--skip_frames_end", + type=int, + default=0, + help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.", + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip videos horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=3e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="cosine_with_restarts", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--enable_slicing", + action="store_true", + default=False, + help="Whether or not to use VAE slicing for saving memory.", + ) + parser.add_argument( + "--enable_tiling", + action="store_true", + default=False, + help="Whether or not to use VAE tiling for saving memory.", + ) + parser.add_argument( + "--noised_image_dropout", + type=float, + default=0.05, + help="Image condition dropout probability.", + ) + + # Optimizer + parser.add_argument( + "--optimizer", + type=lambda s: s.lower(), + default="adam", + choices=["adam", "adamw", "prodigy"], + help=("The optimizer type to use."), + ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.95, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.", + ) + parser.add_argument("--prodigy_decouple", action="store_true", help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--prodigy_use_bias_correction", action="store_true", help="Turn on Adam's bias correction.") + parser.add_argument( + "--prodigy_safeguard_warmup", + action="store_true", + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.", + ) + + # Other information + parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help="Directory where logs are stored.", + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default=None, + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + '--trainable_modules', + nargs='+', + help='Enter a list of trainable modules' + ) + parser.add_argument("--nccl_timeout", type=int, default=600, help="NCCL backend timeout in seconds.") + + return parser.parse_args() + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + +def save_tensor_as_image(tensor, file_path): + """ + Saves a PyTorch tensor as an image file. + + Args: + tensor (torch.Tensor): The image tensor to save. + file_path (str): Path to save the image file. + """ + # Ensure the tensor is in CPU memory and detach it from the computation graph + tensor = tensor.cpu().detach() + + # Convert from PyTorch to NumPy format, and handle the scaling from [0, 1] to [0, 255] + tensor = tensor.squeeze() # Remove unnecessary dimensions if any + tensor = tensor.permute(1, 2, 0) # Change from (C, H, W) to (H, W, C) + tensor = tensor.numpy() * 255 # Scale from [0, 1] to [0, 255] + tensor = tensor.astype(np.uint8) # Convert to uint8 + + # Convert the NumPy array to a PIL Image and save it + image = Image.fromarray(tensor) + image.save(file_path) + +def pixel_values_to_pil(pixel_values, frame_index=0): + if pixel_values.is_cuda: + pixel_values = pixel_values.clone().cpu() + pixel_values = (pixel_values + 1.0) / 2.0 * 255.0 + pixel_values = pixel_values.clamp(0, 255).byte() + frame = pixel_values[frame_index] # [C, H, W] + frame = frame.permute(1, 2, 0) # [H, W, C] + frame_np = frame.numpy() + image = Image.fromarray(frame_np) + return image + +def load_torch_file(ckpt, device=None, dtype=torch.float16): + if device is None: + device = torch.device("cpu") + if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): + sd = safetensors.torch.load_file(ckpt, device=device.type) + else: + if not "weights_only" in torch.load.__code__.co_varnames: + logger.warning( + "Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely." + ) + + pl_sd = torch.load(ckpt, map_location=device, weights_only=True) + if "global_step" in pl_sd: + logger.debug(f"Global Step: {pl_sd['global_step']}") + if "state_dict" in pl_sd: + sd = pl_sd["state_dict"] + elif "params_ema" in pl_sd: + sd = pl_sd["params_ema"] + else: + sd = pl_sd + + sd = {k: v.to(dtype) for k, v in sd.items()} + return sd + + +def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False): + if filter_keys: + out = {} + else: + out = state_dict + for rp in replace_prefix: + replace = list( + map( + lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp) :])), + filter(lambda a: a.startswith(rp), state_dict.keys()), + ) + ) + for x in replace: + w = state_dict.pop(x[0]) + out[x[1]] = w + return out + + +def module_size(module): + module_mem = 0 + sd = module.state_dict() + for k in sd: + t = sd[k] + module_mem += t.nelement() * t.element_size() + return module_mem + + +def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): + return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap))) + + +@torch.inference_mode() +def tiled_scale_multidim( + samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None +): + dims = len(tile) + print(f"samples dtype:{samples.dtype}") + output = torch.empty( + [samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), + device=output_device, + ) + + for b in range(samples.shape[0]): + s = samples[b : b + 1] + out = torch.zeros( + [s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), + device=output_device, + ) + out_div = torch.zeros( + [s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), + device=output_device, + ) + + for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))): + s_in = s + upscaled = [] + + for d in range(dims): + pos = max(0, min(s.shape[d + 2] - overlap, it[d])) + l = min(tile[d], s.shape[d + 2] - pos) + s_in = s_in.narrow(d + 2, pos, l) + upscaled.append(round(pos * upscale_amount)) + + ps = function(s_in).to(output_device) + mask = torch.ones_like(ps) + feather = round(overlap * upscale_amount) + for t in range(feather): + for d in range(2, dims + 2): + m = mask.narrow(d, t, 1) + m *= (1.0 / feather) * (t + 1) + m = mask.narrow(d, mask.shape[d] - 1 - t, 1) + m *= (1.0 / feather) * (t + 1) + + o = out + o_d = out_div + for d in range(dims): + o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2]) + o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2]) + + o += ps * mask + o_d += mask + + if pbar is not None: + pbar.update(1) + + output[b : b + 1] = out / out_div + return output + + +def tiled_scale( + samples, + function, + tile_x=64, + tile_y=64, + overlap=8, + upscale_amount=4, + out_channels=3, + output_device="cpu", + pbar=None, +): + return tiled_scale_multidim( + samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device, pbar + ) + + +def load_sd_upscale(ckpt, inf_device): + sd = load_torch_file(ckpt, device=inf_device) + if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd: + sd = state_dict_prefix_replace(sd, {"module.": ""}) + out = ModelLoader().load_from_state_dict(sd).half() + return out + + +def upscale(upscale_model, tensor: torch.Tensor, inf_device, output_device="cpu") -> torch.Tensor: + memory_required = module_size(upscale_model.model) + memory_required += ( + (512 * 512 * 3) * tensor.element_size() * max(upscale_model.scale, 1.0) * 384.0 + ) # The 384.0 is an estimate of how much some of these models take, TODO: make it more accurate + memory_required += tensor.nelement() * tensor.element_size() + print(f"UPScaleMemory required: {memory_required / 1024 / 1024 / 1024} GB") + + upscale_model.to(inf_device) + tile = 512 + overlap = 32 + + steps = tensor.shape[0] * get_tiled_scale_steps( + tensor.shape[3], tensor.shape[2], tile_x=tile, tile_y=tile, overlap=overlap + ) + + pbar = ProgressBar(steps, desc="Tiling and Upscaling") + + s = tiled_scale( + samples=tensor.to(torch.float16), + function=lambda a: upscale_model(a), + tile_x=tile, + tile_y=tile, + overlap=overlap, + upscale_amount=upscale_model.scale, + pbar=pbar, + ) + + upscale_model.to(output_device) + return s + + +def upscale_batch_and_concatenate(upscale_model, latents, inf_device, output_device="cpu") -> torch.Tensor: + upscaled_latents = [] + for i in range(latents.size(0)): + latent = latents[i] + upscaled_latent = upscale(upscale_model, latent, inf_device, output_device) + upscaled_latents.append(upscaled_latent) + return torch.stack(upscaled_latents) + + +def save_video(tensor: Union[List[np.ndarray], List[PIL.Image.Image]], fps: int = 8): + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + video_path = f"./output/{timestamp}.mp4" + os.makedirs(os.path.dirname(video_path), exist_ok=True) + export_to_video(tensor, video_path, fps=fps) + return video_path + + +class ProgressBar: + def __init__(self, total, desc=None): + self.total = total + self.current = 0 + self.b_unit = tqdm.tqdm(total=total, desc="ProgressBar context index: 0" if desc is None else desc) + + def update(self, value): + if value > self.total: + value = self.total + self.current = value + if self.b_unit is not None: + self.b_unit.set_description("ProgressBar context index: {}".format(self.current)) + self.b_unit.refresh() + + self.b_unit.update(self.current) \ No newline at end of file