diff --git a/README.md b/README.md index 3c854f885762ef20890bb773fde87d769ea9b907..8043cb1c8033db9e737163808859e730bc4a17e6 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ --- -title: OLA VLM -emoji: πŸ’¬ -colorFrom: yellow +title: OLA-VLM +emoji: πŸ” +colorFrom: blue colorTo: purple sdk: gradio -sdk_version: 5.0.1 +sdk_version: 4.16.0 app_file: app.py pinned: false license: apache-2.0 diff --git a/app.py b/app.py index 0da0319a5b670dce5025888fde58916b96f19869..2ecc271832f7f04e69b6a0dfcf2e953db868ac15 100644 --- a/app.py +++ b/app.py @@ -1,64 +1,487 @@ import gradio as gr -from huggingface_hub import InferenceClient +import spaces +import torch +import numpy as np -""" -For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference -""" -client = InferenceClient("HuggingFaceH4/zephyr-7b-beta") +from ola_vlm.constants import DEFAULT_IMAGE_TOKEN + +from ola_vlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN +from ola_vlm.conversation import conv_templates, SeparatorStyle +from ola_vlm.model.builder import load_pretrained_model +from ola_vlm.mm_utils import tokenizer_image_token, get_model_name_from_path, process_images + +from diffusers import StableUnCLIPImg2ImgPipeline +from diffusers import DPMSolverMultistepScheduler +from transformers import OneFormerProcessor +from ola_vlm.model.aux_heads.oneformer_head import OneFormerHead +from ola_vlm.ola_utils import visualize_oneformer_masks_on_image, oneformer_prepare_panoptic_instance_prediction +import matplotlib +from PIL import Image, ImageDraw, ImageFont +import argparse +import math + +from transformers import TextIteratorStreamer +from threading import Thread + +def make_grid(pil_images, layer_indices=None): + new_images = [] + new_captions = [] + + # Resize images and prepare captions + for i, pil_image in enumerate(pil_images): + pil_image = pil_image.resize((256, 256)) + new_images.append(pil_image) + if layer_indices is not None: + new_captions.append(f"Layer: {layer_indices[i]}") + else: + new_captions.append(f"Layer: {i+1}") + + images = new_images + captions = new_captions + + width, height = images[0].size + font_size = 18 + + # Calculate the number of rows and columns for the grid + images_per_row = min(len(images), 4) # Max 4 images per row + row_count = math.ceil(len(images) / images_per_row) + total_width = width * images_per_row + total_height = height * row_count + + # Create a new blank image + new_image = Image.new("RGB", (total_width, total_height), "white") + draw = ImageDraw.Draw(new_image) + + # Load a default font + try: + font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", font_size) + except: + font = ImageFont.load_default() + + # Place images and captions in the grid + for i, (image, caption) in enumerate(zip(images, captions)): + row = i // images_per_row + col = i % images_per_row + x_offset = col * width + y_offset = row * height + + # Paste the image + new_image.paste(image, (x_offset, y_offset)) + + # Calculate text and background positions + text_width, text_height = draw.textsize(caption, font=font) + text_position = (x_offset + 10, y_offset + height - text_height - 10) + background_position = ( + text_position[0] - 5, + text_position[1] - 5, + text_position[0] + text_width + 5, + text_position[1] + text_height + 5, + ) + + # Draw background rectangle and text + draw.rectangle(background_position, fill="white", outline="black") + draw.text(text_position, caption, 
fill="black", font=font) + + return new_image + +def reload_from_ckpt(model_path, model, cache_dir=None): + import os + from safetensors import safe_open + from huggingface_hub import hf_hub_download, list_repo_files + + state_dict = {} + + # Check if the path is a local directory or HF Hub model + if os.path.isdir(model_path): + # Local directory: Load safetensors files + safetensors_paths = [os.path.join(model_path, f) for f in os.listdir(model_path) if f.endswith('.safetensors')] + else: + # HF Hub: Get list of safetensors files and download them + repo_files = list_repo_files(model_path) + safetensors_paths = [ + hf_hub_download(model_path, file_name, cache_dir=cache_dir) + for file_name in repo_files if file_name.endswith('.safetensors') + ] + + # Load safetensors files into the state_dict + for path in safetensors_paths: + with safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key) + + # Load the state dict into the model + model.load_state_dict(state_dict, strict=False) + return model + +# os.environ['GRADIO_TEMP_DIR'] = './gradio_tmp' +no_change_btn = gr.Button() +enable_btn = gr.Button(interactive=True) +disable_btn = gr.Button(interactive=False) + +argparser = argparse.ArgumentParser() +argparser.add_argument("--server_name", default="0.0.0.0", type=str) +argparser.add_argument("--port", default="6324", type=str) +argparser.add_argument("--model-path", default="shi-labs/pretrain_dsg_OLA-VLM-CLIP-ViT-Llama3-8b", type=str) +argparser.add_argument("--model-base", type=str, default=None) +argparser.add_argument("--num-gpus", type=int, default=1) +argparser.add_argument("--conv-mode", type=str, default="llava_llama_3") +argparser.add_argument("--temperature", type=float, default=0.2) +argparser.add_argument("--max-new-tokens", type=int, default=512) +argparser.add_argument("--num_frames", type=int, default=16) +argparser.add_argument("--load-8bit", action="store_true") +argparser.add_argument("--load-4bit", action="store_true") +argparser.add_argument("--debug", action="store_true") + +args = argparser.parse_args() +model_path = args.model_path +conv_mode = args.conv_mode +filt_invalid="cut" +model_name = get_model_name_from_path(args.model_path) +tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit) +model = reload_from_ckpt("shi-labs/OLA-VLM-CLIP-ViT-Llama3-8b", model) +our_chatbot = None + +pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(f"stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16") +pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) +pipe = pipe.to("cuda") + +oneformer_processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_coco_swin_large") +oneformer = OneFormerHead.from_pretrained("shi-labs/oneformer_coco_swin_large").to("cuda") + +gen_layer_indices = model.config.image_gen["img_layer_indices"].split("-") +seg_layer_indices = model.config.image_seg["seg_layer_indices"].split("-") +depth_layer_indices = model.config.image_depth["depth_layer_indices"].split("-") + + +def clear_history(): + state =conv_templates[conv_mode].copy() + return (state, state.to_gradio_chatbot(), "", None, None, None, None) + (disable_btn,) * 5 + +def add_text(state, imagebox, textbox, image_process_mode): + if state is None: + state = conv_templates[conv_mode].copy() + + if imagebox is not None: + textbox = DEFAULT_IMAGE_TOKEN + '\n' + textbox + image = 
Image.open(imagebox).convert('RGB') + + if imagebox is not None: + textbox = (textbox, image, image_process_mode) + + state.append_message(state.roles[0], textbox) + state.append_message(state.roles[1], None) + + yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + +def get_gen_images(out): + img_embeds = out.image_embs + if len(img_embeds) == 0: + return None + images = [] + for img_embed in img_embeds: + gen_image = pipe(image_embeds=img_embed.squeeze(1), + num_inference_steps=25, + ).images[0] + images.append(gen_image) + grid_image = make_grid(images, gen_layer_indices) + return grid_image + +def get_depth_images(out, org_size): + depth_preds = out.depth_preds + if len(depth_preds) == 0: + return None + depths = [] -def respond( - message, - history: list[tuple[str, str]], - system_message, - max_tokens, - temperature, - top_p, -): - messages = [{"role": "system", "content": system_message}] + for i, depth_pred in enumerate(depth_preds): + depth = (depth_pred - depth_pred.min()) / (depth_pred.max() - depth_pred.min()) * 255.0 + depth = depth.squeeze(0).cpu().numpy() + depth = depth.astype(np.uint8) + cmap = matplotlib.colormaps.get_cmap('Spectral_r') + depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8) + depth = Image.fromarray(depth) + depth = depth.resize(org_size) + depths.append(depth) + grid_image = make_grid(depths, depth_layer_indices) + return grid_image - for val in history: - if val[0]: - messages.append({"role": "user", "content": val[0]}) - if val[1]: - messages.append({"role": "assistant", "content": val[1]}) +def get_seg_images(out, image): + seg_embs = out.seg_embs + + if len(seg_embs) == 0: + return None + + seg_preds = [] + inputs = oneformer_processor(image, ["semantic"], return_tensors="pt") + inputs["pixel_values"] = inputs["pixel_values"].to(out.logits.device, out.logits.dtype) + inputs["task_inputs"] = inputs["task_inputs"].to(out.logits.device, out.logits.dtype) + backbone_features = oneformer.get_backbone_feats(**inputs) + for i, seg_emb in enumerate(seg_embs): + pred = oneformer.get_masks(**inputs, backbone_last_feature=seg_emb.float(), all_backbone_features=backbone_features) + pred = oneformer_processor.post_process_panoptic_segmentation( + pred, target_sizes=[image.size[::-1]] + )[0] + pred_msk, pred_cls = oneformer_prepare_panoptic_instance_prediction(**pred, oneformer=oneformer) + pred = visualize_oneformer_masks_on_image(image, pred_msk, pred_cls) + seg_preds.append(pred) + grid_image = make_grid(seg_preds, seg_layer_indices) + return grid_image - messages.append({"role": "user", "content": message}) +def delete_text(state, image_process_mode): + state.messages[-1][-1] = None + prev_human_msg = state.messages[-2] + if type(prev_human_msg[1]) in (tuple, list): + prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode) + yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) - response = "" +def regenerate(state, image_process_mode): + state.messages[-1][-1] = None + prev_human_msg = state.messages[-2] + if type(prev_human_msg[1]) in (tuple, list): + prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode) + state.skip_next = False + return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 - for message in client.chat_completion( - messages, - max_tokens=max_tokens, - stream=True, +@spaces.GPU +def get_interm_outs(state): + prompt = state.get_prompt() + images = state.get_images(return_pil=True) 
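    # get_interm_outs runs the prompt and image through the probe heads attached
    # to intermediate LLM layers; their embeddings are decoded below into depth
    # maps, OneFormer segmentation masks, and unCLIP-generated images.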
+ #prompt, image_args = process_image(prompt, images) + + if images is not None and len(images) > 0: + if len(images) > 0: + if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN): + raise ValueError("Number of images does not match number of tokens in prompt") + + #images = [load_image_from_base64(image) for image in images] + image_sizes = [image.size for image in images] + inp_images = process_images(images, image_processor, model.config) + + if type(inp_images) is list: + inp_images = [image.to(model.device, dtype=torch.float16) for image in images] + else: + inp_images = inp_images.to(model.device, dtype=torch.float16) + else: + inp_images = None + image_sizes = None + image_args = {"images": inp_images, "image_sizes": image_sizes} + else: + inp_images = None + image_args = {} + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) + + interm_outs = model.get_visual_interpretations( + input_ids, + **image_args + ) + + depth_outs = get_depth_images(interm_outs, image_sizes[0]) + seg_outs = get_seg_images(interm_outs, images[0]) + gen_outs = get_gen_images(interm_outs) + + return depth_outs, seg_outs, gen_outs + +@spaces.GPU +def generate(state, temperature, top_p, max_output_tokens): + prompt = state.get_prompt() + images = state.get_images(return_pil=True) + #prompt, image_args = process_image(prompt, images) + + ori_prompt = prompt + num_image_tokens = 0 + + if images is not None and len(images) > 0: + if len(images) > 0: + if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN): + raise ValueError("Number of images does not match number of tokens in prompt") + + #images = [load_image_from_base64(image) for image in images] + image_sizes = [image.size for image in images] + images = process_images(images, image_processor, model.config) + + if type(images) is list: + images = [image.to(model.device, dtype=torch.float16) for image in images] + else: + images = images.to(model.device, dtype=torch.float16) + else: + images = None + image_sizes = None + image_args = {"images": images, "image_sizes": image_sizes} + else: + images = None + image_args = {} + + max_context_length = getattr(model.config, 'max_position_embeddings', 2048) + max_new_tokens = max_output_tokens + do_sample = True if temperature > 0.001 else False + stop_str = state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2 + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15) + + max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens) + + if max_new_tokens < 1: + return + + thread = Thread(target=model.generate, kwargs=dict( + inputs=input_ids, + do_sample=do_sample, temperature=temperature, top_p=top_p, - ): - token = message.choices[0].delta.content + max_new_tokens=max_new_tokens, + streamer=streamer, + use_cache=True, + pad_token_id=tokenizer.eos_token_id, + **image_args + )) + thread.start() + generated_text = '' + for new_text in streamer: + generated_text += new_text + if generated_text.endswith(stop_str): + generated_text = generated_text[:-len(stop_str)] + state.messages[-1][-1] = generated_text + yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + + yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5 + + 
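    # While streaming, only the regenerate/clear buttons stay enabled; once the
    # stream ends all five feedback buttons are re-enabled and the CUDA cache is cleared.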
torch.cuda.empty_cache() + +txt = gr.Textbox( + scale=4, + show_label=False, + placeholder="Enter text and press enter.", + container=False, +) - response += token - yield response +title = "
OLA-VLM: Optimizing Language Model Representations for Enhanced Visual Quality and Alignment" +description = "Jitesh Jain    Zhengyuan Yang    Humphrey Shi*    Jianfeng Gao*    Jianwei Yang*" \ + + "*Equal Advising" \ + + "Project Page | Video | ArXiv | Github
" +tos_markdown = (""" +### Terms of use +By using this service, users are required to agree to the following terms: +The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. +""") + + +learn_more_markdown = (""" +### License +The service is a research preview intended for non-commercial use only, subject to the [License](https://huggingface.co/lmsys/vicuna-7b-v1.5) of Vicuna-v1.5, [License](https://github.com/haotian-liu/LLaVA/blob/main/LICENSE) of LLaVA, [Terms of Use](https://cocodataset.org/#termsofuse) of the COCO dataset, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation. +""") + +block_css = """ +#buttons button { + min-width: min(120px,100%); +} """ -For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface -""" -demo = gr.ChatInterface( - respond, - additional_inputs=[ - gr.Textbox(value="You are a friendly Chatbot.", label="System message"), - gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"), - gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"), - gr.Slider( - minimum=0.1, - maximum=1.0, - value=0.95, - step=0.05, - label="Top-p (nucleus sampling)", - ), - ], -) -if __name__ == "__main__": - demo.launch() +textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False) +with gr.Blocks(title="OLA-VLM", theme=gr.themes.Default(), css=block_css) as demo: + state = gr.State() + + gr.Markdown(title) + gr.Markdown(description) + + with gr.Row(): + with gr.Column(scale=4): + imagebox = gr.Image(label="Input Image", type="filepath") + image_process_mode = gr.Radio( + ["Crop", "Resize", "Pad", "Default"], + value="Default", + label="Preprocess for non-square image", visible=False) + + # with gr.Accordion("Parameters", open=False) as parameter_row: + with gr.Row(): + temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",) + top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",) + max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",) + + with gr.Column(scale=8): + chatbot = gr.Chatbot( + elem_id="chatbot", + label="OLA-VLM", + height=300, + layout="panel", + ) + textbox.render() + with gr.Row(elem_id="buttons") as button_row: + upvote_btn = gr.Button(value="πŸ‘ Upvote", interactive=False, visible=False) + downvote_btn = gr.Button(value="πŸ‘Ž Downvote", interactive=False, visible=False) + flag_btn = gr.Button(value="⚠️ Flag", interactive=False, visible=False) + #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False) + regenerate_btn = gr.Button(value="πŸ”„ Regenerate", interactive=False) + clear_btn = gr.Button(value="πŸ—‘οΈ Clear", interactive=False) + submit_btn = gr.Button(value="Send", variant="primary") + + with gr.Accordion("Representations from selected layers of the LLM (expects only a single image input)", open=False) as interm_out: + inter_vis_btn = gr.Button(value="✨ Visualize") + with gr.Row(): + depth_box = gr.Image(label="depth", type="pil", 
visible=True) + seg_box = gr.Image(label="seg", type="pil", visible=True) + gen_box = gr.Image(label="gen", type="pil", visible=True) + + gr.Examples(examples=[ + [f"assets/cars.jpg", "Which car is in front: the blue or the brown one?"], + [f"assets/pb.jpg", "Where is the bulding located with respect to the man?"], + ], inputs=[imagebox, textbox], cache_examples=False) + + # gr.Markdown(tos_markdown) + # gr.Markdown(learn_more_markdown) + # url_params = gr.JSON(visible=False) + + # Register listeners + btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn] + + inter_vis_btn.click( + get_interm_outs, + [state], + [depth_box, seg_box, gen_box], + ) + + clear_btn.click( + clear_history, + None, + [state, chatbot, textbox, imagebox, depth_box, gen_box, seg_box] + btn_list, + queue=False + ) + + regenerate_btn.click( + delete_text, + [state, image_process_mode], + [state, chatbot, textbox, imagebox] + btn_list, + ).then( + generate, + [state, temperature, top_p, max_output_tokens], + [state, chatbot, textbox, imagebox] + btn_list, + ) + textbox.submit( + add_text, + [state, imagebox, textbox, image_process_mode], + [state, chatbot, textbox, imagebox] + btn_list, + ).then( + generate, + [state, temperature, top_p, max_output_tokens], + [state, chatbot, textbox, imagebox] + btn_list, + ) + + submit_btn.click( + add_text, + [state, imagebox, textbox, image_process_mode], + [state, chatbot, textbox, imagebox] + btn_list, + ).then( + generate, + [state, temperature, top_p, max_output_tokens], + [state, chatbot, textbox, imagebox] + btn_list, + ) + +demo.queue( + status_update_rate=10, + api_open=False +).launch(share=False) +demo.queue() \ No newline at end of file diff --git a/demo.py b/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2c40fe5026538930746519aa0d600531ea73fb --- /dev/null +++ b/demo.py @@ -0,0 +1,486 @@ +import gradio as gr +import os +import torch +import numpy as np + +from ola_vlm.constants import DEFAULT_IMAGE_TOKEN + +from ola_vlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN +from ola_vlm.conversation import conv_templates, SeparatorStyle +from ola_vlm.model.builder import load_pretrained_model +from ola_vlm.mm_utils import tokenizer_image_token, get_model_name_from_path, process_images + +from diffusers import StableUnCLIPImg2ImgPipeline +from diffusers import DPMSolverMultistepScheduler +from transformers import OneFormerProcessor +from ola_vlm.model.aux_heads.oneformer_head import OneFormerHead +from ola_vlm.ola_utils import visualize_oneformer_masks_on_image, oneformer_prepare_panoptic_instance_prediction +import matplotlib +from PIL import Image, ImageDraw, ImageFont +import argparse +import math + +from transformers import TextIteratorStreamer +from threading import Thread + +def make_grid(pil_images, layer_indices=None): + new_images = [] + new_captions = [] + + # Resize images and prepare captions + for i, pil_image in enumerate(pil_images): + pil_image = pil_image.resize((256, 256)) + new_images.append(pil_image) + if layer_indices is not None: + new_captions.append(f"Layer: {layer_indices[i]}") + else: + new_captions.append(f"Layer: {i+1}") + + images = new_images + captions = new_captions + + width, height = images[0].size + font_size = 18 + + # Calculate the number of rows and columns for the grid + images_per_row = min(len(images), 4) # Max 4 images per row + row_count = math.ceil(len(images) / images_per_row) + total_width = width * images_per_row + total_height = height * row_count + + # Create a new 
blank image + new_image = Image.new("RGB", (total_width, total_height), "white") + draw = ImageDraw.Draw(new_image) + + # Load a default font + try: + font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", font_size) + except: + font = ImageFont.load_default() + + # Place images and captions in the grid + for i, (image, caption) in enumerate(zip(images, captions)): + row = i // images_per_row + col = i % images_per_row + x_offset = col * width + y_offset = row * height + + # Paste the image + new_image.paste(image, (x_offset, y_offset)) + + # Calculate text and background positions + text_width, text_height = draw.textsize(caption, font=font) + text_position = (x_offset + 10, y_offset + height - text_height - 10) + background_position = ( + text_position[0] - 5, + text_position[1] - 5, + text_position[0] + text_width + 5, + text_position[1] + text_height + 5, + ) + + # Draw background rectangle and text + draw.rectangle(background_position, fill="white", outline="black") + draw.text(text_position, caption, fill="black", font=font) + + return new_image + +def reload_from_ckpt(model_path, model, cache_dir=None): + import os + from safetensors import safe_open + from huggingface_hub import hf_hub_download, list_repo_files + + state_dict = {} + + # Check if the path is a local directory or HF Hub model + if os.path.isdir(model_path): + # Local directory: Load safetensors files + safetensors_paths = [os.path.join(model_path, f) for f in os.listdir(model_path) if f.endswith('.safetensors')] + else: + # HF Hub: Get list of safetensors files and download them + repo_files = list_repo_files(model_path) + safetensors_paths = [ + hf_hub_download(model_path, file_name, cache_dir=cache_dir) + for file_name in repo_files if file_name.endswith('.safetensors') + ] + + # Load safetensors files into the state_dict + for path in safetensors_paths: + with safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key) + + # Load the state dict into the model + model.load_state_dict(state_dict, strict=False) + return model + +# os.environ['GRADIO_TEMP_DIR'] = './gradio_tmp' +no_change_btn = gr.Button() +enable_btn = gr.Button(interactive=True) +disable_btn = gr.Button(interactive=False) + +argparser = argparse.ArgumentParser() +argparser.add_argument("--server_name", default="0.0.0.0", type=str) +argparser.add_argument("--port", default="6324", type=str) +argparser.add_argument("--model-path", default="shi-labs/pretrain_dsg_OLA-VLM-CLIP-ViT-Llama3-8b", type=str) +argparser.add_argument("--model-base", type=str, default=None) +argparser.add_argument("--num-gpus", type=int, default=1) +argparser.add_argument("--conv-mode", type=str, default="llava_llama_3") +argparser.add_argument("--temperature", type=float, default=0.2) +argparser.add_argument("--max-new-tokens", type=int, default=512) +argparser.add_argument("--num_frames", type=int, default=16) +argparser.add_argument("--load-8bit", action="store_true") +argparser.add_argument("--load-4bit", action="store_true") +argparser.add_argument("--debug", action="store_true") + +args = argparser.parse_args() +model_path = args.model_path +conv_mode = args.conv_mode +filt_invalid="cut" +model_name = get_model_name_from_path(args.model_path) +tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit) +model = reload_from_ckpt("shi-labs/OLA-VLM-CLIP-ViT-Llama3-8b", model) +our_chatbot = None + +pipe = 
StableUnCLIPImg2ImgPipeline.from_pretrained(f"stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16") +pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) +pipe = pipe.to("cuda") + +oneformer_processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_coco_swin_large") +oneformer = OneFormerHead.from_pretrained("shi-labs/oneformer_coco_swin_large").to("cuda") + +gen_layer_indices = model.config.image_gen["img_layer_indices"].split("-") +seg_layer_indices = model.config.image_seg["seg_layer_indices"].split("-") +depth_layer_indices = model.config.image_depth["depth_layer_indices"].split("-") + + +def clear_history(): + state =conv_templates[conv_mode].copy() + return (state, state.to_gradio_chatbot(), "", None, None, None, None) + (disable_btn,) * 5 + +def add_text(state, imagebox, textbox, image_process_mode): + if state is None: + state = conv_templates[conv_mode].copy() + + if imagebox is not None: + textbox = DEFAULT_IMAGE_TOKEN + '\n' + textbox + image = Image.open(imagebox).convert('RGB') + + if imagebox is not None: + textbox = (textbox, image, image_process_mode) + + state.append_message(state.roles[0], textbox) + state.append_message(state.roles[1], None) + + yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + +def get_gen_images(out): + img_embeds = out.image_embs + if len(img_embeds) == 0: + return None + images = [] + for img_embed in img_embeds: + gen_image = pipe(image_embeds=img_embed.squeeze(1), + num_inference_steps=25, + ).images[0] + images.append(gen_image) + grid_image = make_grid(images, gen_layer_indices) + return grid_image + +def get_depth_images(out, org_size): + depth_preds = out.depth_preds + + if len(depth_preds) == 0: + return None + depths = [] + + for i, depth_pred in enumerate(depth_preds): + depth = (depth_pred - depth_pred.min()) / (depth_pred.max() - depth_pred.min()) * 255.0 + depth = depth.squeeze(0).cpu().numpy() + depth = depth.astype(np.uint8) + cmap = matplotlib.colormaps.get_cmap('Spectral_r') + depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8) + depth = Image.fromarray(depth) + depth = depth.resize(org_size) + depths.append(depth) + grid_image = make_grid(depths, depth_layer_indices) + return grid_image + +def get_seg_images(out, image): + seg_embs = out.seg_embs + + if len(seg_embs) == 0: + return None + + seg_preds = [] + inputs = oneformer_processor(image, ["semantic"], return_tensors="pt") + inputs["pixel_values"] = inputs["pixel_values"].to(out.logits.device, out.logits.dtype) + inputs["task_inputs"] = inputs["task_inputs"].to(out.logits.device, out.logits.dtype) + backbone_features = oneformer.get_backbone_feats(**inputs) + for i, seg_emb in enumerate(seg_embs): + pred = oneformer.get_masks(**inputs, backbone_last_feature=seg_emb.float(), all_backbone_features=backbone_features) + pred = oneformer_processor.post_process_panoptic_segmentation( + pred, target_sizes=[image.size[::-1]] + )[0] + pred_msk, pred_cls = oneformer_prepare_panoptic_instance_prediction(**pred, oneformer=oneformer) + pred = visualize_oneformer_masks_on_image(image, pred_msk, pred_cls) + seg_preds.append(pred) + grid_image = make_grid(seg_preds, seg_layer_indices) + return grid_image + +def delete_text(state, image_process_mode): + state.messages[-1][-1] = None + prev_human_msg = state.messages[-2] + if type(prev_human_msg[1]) in (tuple, list): + prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode) + yield (state, 
state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + +def regenerate(state, image_process_mode): + state.messages[-1][-1] = None + prev_human_msg = state.messages[-2] + if type(prev_human_msg[1]) in (tuple, list): + prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode) + state.skip_next = False + return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 + +def get_interm_outs(state): + prompt = state.get_prompt() + images = state.get_images(return_pil=True) + #prompt, image_args = process_image(prompt, images) + + if images is not None and len(images) > 0: + if len(images) > 0: + if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN): + raise ValueError("Number of images does not match number of tokens in prompt") + + #images = [load_image_from_base64(image) for image in images] + image_sizes = [image.size for image in images] + inp_images = process_images(images, image_processor, model.config) + + if type(inp_images) is list: + inp_images = [image.to(model.device, dtype=torch.float16) for image in images] + else: + inp_images = inp_images.to(model.device, dtype=torch.float16) + else: + inp_images = None + image_sizes = None + image_args = {"images": inp_images, "image_sizes": image_sizes} + else: + inp_images = None + image_args = {} + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) + + interm_outs = model.get_visual_interpretations( + input_ids, + **image_args + ) + + depth_outs = get_depth_images(interm_outs, image_sizes[0]) + seg_outs = get_seg_images(interm_outs, images[0]) + gen_outs = get_gen_images(interm_outs) + + return depth_outs, seg_outs, gen_outs + +# @spaces.GPU +def generate(state, temperature, top_p, max_output_tokens): + prompt = state.get_prompt() + images = state.get_images(return_pil=True) + #prompt, image_args = process_image(prompt, images) + + ori_prompt = prompt + num_image_tokens = 0 + + if images is not None and len(images) > 0: + if len(images) > 0: + if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN): + raise ValueError("Number of images does not match number of tokens in prompt") + + #images = [load_image_from_base64(image) for image in images] + image_sizes = [image.size for image in images] + images = process_images(images, image_processor, model.config) + + if type(images) is list: + images = [image.to(model.device, dtype=torch.float16) for image in images] + else: + images = images.to(model.device, dtype=torch.float16) + else: + images = None + image_sizes = None + image_args = {"images": images, "image_sizes": image_sizes} + else: + images = None + image_args = {} + + max_context_length = getattr(model.config, 'max_position_embeddings', 2048) + max_new_tokens = max_output_tokens + do_sample = True if temperature > 0.001 else False + stop_str = state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2 + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15) + + max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens) + + if max_new_tokens < 1: + return + + thread = Thread(target=model.generate, kwargs=dict( + inputs=input_ids, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + max_new_tokens=max_new_tokens, + streamer=streamer, + use_cache=True, + 
pad_token_id=tokenizer.eos_token_id, + **image_args + )) + thread.start() + generated_text = '' + for new_text in streamer: + generated_text += new_text + if generated_text.endswith(stop_str): + generated_text = generated_text[:-len(stop_str)] + state.messages[-1][-1] = generated_text + yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + + yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5 + + torch.cuda.empty_cache() + +txt = gr.Textbox( + scale=4, + show_label=False, + placeholder="Enter text and press enter.", + container=False, +) + + +title = "
OLA-VLM: Optimizing Language Model Representations for Enhanced Visual Quality and Alignment" +description = "Jitesh Jain    Zhengyuan Yang    Humphrey Shi*    Jianfeng Gao*    Jianwei Yang*" \ + + "*Equal Advising" \ + + "Project Page | Video | ArXiv | Github
" + +tos_markdown = (""" +### Terms of use +By using this service, users are required to agree to the following terms: +The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. +""") + + +learn_more_markdown = (""" +### License +The service is a research preview intended for non-commercial use only, subject to the [License](https://huggingface.co/lmsys/vicuna-7b-v1.5) of Vicuna-v1.5, [License](https://github.com/haotian-liu/LLaVA/blob/main/LICENSE) of LLaVA, [Terms of Use](https://cocodataset.org/#termsofuse) of the COCO dataset, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation. +""") + +block_css = """ +#buttons button { + min-width: min(120px,100%); +} +""" + + +textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False) +with gr.Blocks(title="OLA-VLM", theme=gr.themes.Default(), css=block_css) as demo: + state = gr.State() + + gr.Markdown(title) + gr.Markdown(description) + + with gr.Row(): + with gr.Column(scale=4): + imagebox = gr.Image(label="Input Image", type="filepath") + image_process_mode = gr.Radio( + ["Crop", "Resize", "Pad", "Default"], + value="Default", + label="Preprocess for non-square image", visible=False) + + # with gr.Accordion("Parameters", open=False) as parameter_row: + with gr.Row(): + temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",) + top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",) + max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",) + + with gr.Column(scale=8): + chatbot = gr.Chatbot( + elem_id="chatbot", + label="OLA-VLM", + height=300, + layout="panel", + ) + textbox.render() + with gr.Row(elem_id="buttons") as button_row: + upvote_btn = gr.Button(value="πŸ‘ Upvote", interactive=False, visible=False) + downvote_btn = gr.Button(value="πŸ‘Ž Downvote", interactive=False, visible=False) + flag_btn = gr.Button(value="⚠️ Flag", interactive=False, visible=False) + #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False) + regenerate_btn = gr.Button(value="πŸ”„ Regenerate", interactive=False) + clear_btn = gr.Button(value="πŸ—‘οΈ Clear", interactive=False) + submit_btn = gr.Button(value="Send", variant="primary") + + with gr.Accordion("Representations from selected layers of the LLM (expects only a single image input)", open=False) as interm_out: + inter_vis_btn = gr.Button(value="✨ Visualize") + with gr.Row(): + depth_box = gr.Image(label="depth", type="pil", visible=True) + seg_box = gr.Image(label="seg", type="pil", visible=True) + gen_box = gr.Image(label="gen", type="pil", visible=True) + + gr.Examples(examples=[ + [f"assets/cars.jpg", "Which car is in front: the blue or the brown one?"], + [f"assets/pb.jpg", "Where is the bulding located with respect to the man?"], + ], inputs=[imagebox, textbox], cache_examples=False) + + # gr.Markdown(tos_markdown) + # gr.Markdown(learn_more_markdown) + # url_params = gr.JSON(visible=False) + + # Register listeners + btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn] + + 
inter_vis_btn.click( + get_interm_outs, + [state], + [depth_box, seg_box, gen_box], + ) + + clear_btn.click( + clear_history, + None, + [state, chatbot, textbox, imagebox, depth_box, gen_box, seg_box] + btn_list, + queue=False + ) + + regenerate_btn.click( + delete_text, + [state, image_process_mode], + [state, chatbot, textbox, imagebox] + btn_list, + ).then( + generate, + [state, temperature, top_p, max_output_tokens], + [state, chatbot, textbox, imagebox] + btn_list, + ) + textbox.submit( + add_text, + [state, imagebox, textbox, image_process_mode], + [state, chatbot, textbox, imagebox] + btn_list, + ).then( + generate, + [state, temperature, top_p, max_output_tokens], + [state, chatbot, textbox, imagebox] + btn_list, + ) + + submit_btn.click( + add_text, + [state, imagebox, textbox, image_process_mode], + [state, chatbot, textbox, imagebox] + btn_list, + ).then( + generate, + [state, temperature, top_p, max_output_tokens], + [state, chatbot, textbox, imagebox] + btn_list, + ) + +demo.queue( + status_update_rate=10, + api_open=False +).launch(share=True) +demo.queue() \ No newline at end of file diff --git a/ola_vlm/.DS_Store b/ola_vlm/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..555ff6b18feda42a03c3cd994630b6a093666069 Binary files /dev/null and b/ola_vlm/.DS_Store differ diff --git a/ola_vlm/__init__.py b/ola_vlm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2616af5982969d7a48c0111a97ca0d9876f130be --- /dev/null +++ b/ola_vlm/__init__.py @@ -0,0 +1,2 @@ +from .model import LlavaLlamaForCausalLM +from .model import LlavaPhi3ForCausalLM diff --git a/ola_vlm/constants.py b/ola_vlm/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..374be090510b302de9882d880c755787a8eafe11 --- /dev/null +++ b/ola_vlm/constants.py @@ -0,0 +1,13 @@ +CONTROLLER_HEART_BEAT_EXPIRATION = 30 +WORKER_HEART_BEAT_INTERVAL = 15 + +LOGDIR = "." 
+ +# Model Constants +IGNORE_INDEX = -100 +IMAGE_TOKEN_INDEX = -200 +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_IMAGE_PATCH_TOKEN = "" +DEFAULT_IM_START_TOKEN = "" +DEFAULT_IM_END_TOKEN = "" +IMAGE_PLACEHOLDER = "" diff --git a/ola_vlm/conversation.py b/ola_vlm/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..0c67b4a7e77cd8a575c5363b01835334833f8e73 --- /dev/null +++ b/ola_vlm/conversation.py @@ -0,0 +1,255 @@ +import dataclasses +from enum import auto, Enum +from typing import List, Tuple +import base64 +from io import BytesIO +from PIL import Image + + +class SeparatorStyle(Enum): + """Different separator style.""" + SINGLE = auto() + TWO = auto() + MPT = auto() + PLAIN = auto() + LLAMA_3 = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + sep: str = "###" + sep2: str = None + version: str = "Unknown" + + skip_next: bool = False + + def get_prompt(self): + messages = self.messages + if len(messages) > 0 and type(messages[0][1]) is tuple: + messages = self.messages.copy() + init_role, init_msg = messages[0].copy() + init_msg = init_msg[0].replace("", "").strip() + if 'mmtag' in self.version: + messages[0] = (init_role, init_msg) + messages.insert(0, (self.roles[0], "")) + messages.insert(1, (self.roles[1], "Received.")) + else: + messages[0] = (init_role, "\n" + init_msg) + + if self.sep_style == SeparatorStyle.SINGLE: + ret = self.system + self.sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + self.sep + else: + ret += role + ":" + elif self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + elif self.sep_style == SeparatorStyle.MPT: + ret = self.system + self.sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + self.sep + else: + ret += role + elif self.sep_style == SeparatorStyle.LLAMA_2: + wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" if len(msg) > 0 else msg + wrap_inst = lambda msg: f"[INST] {msg} [/INST]" + ret = "" + + for i, (role, message) in enumerate(messages): + if i == 0: + assert message, "first message should not be none" + assert role == self.roles[0], "first message should come from user" + if message: + if type(message) is tuple: + message, _, _ = message + if i == 0: message = wrap_sys(self.system) + message + if i % 2 == 0: + message = wrap_inst(message) + ret += self.sep + message + else: + ret += " " + message + " " + self.sep2 + else: + ret += "" + ret = ret.lstrip(self.sep) + elif self.sep_style == SeparatorStyle.CHATML: + ret = "" if self.system == "" else self.system + self.sep + "\n" + for role, message in messages: + if message: + if type(message) is tuple: + message, images, _ = message + message = "" * len(images) + message + ret += role + "\n" + message + self.sep + "\n" + else: + ret += role + "\n" + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + return ret + + def append_message(self, role, message): + if isinstance(self.messages, tuple): + self.messages = list(self.messages) + self.messages.append([role, message]) + + 
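    # process_image normalizes the attached image before encoding: "Pad" letterboxes
    # it to a square, "Resize" forces 336x336, "Default"/"Crop" leave it unchanged,
    # and anything whose long side exceeds max_len is downscaled with the aspect
    # ratio preserved; the result is returned as a PIL image or a base64 string.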
def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672): + if image_process_mode == "Pad": + def expand2square(pil_img, background_color=(122, 116, 104)): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + image = expand2square(image) + elif image_process_mode in ["Default", "Crop"]: + pass + elif image_process_mode == "Resize": + image = image.resize((336, 336)) + else: + raise ValueError(f"Invalid image_process_mode: {image_process_mode}") + if max(image.size) > max_len: + max_hw, min_hw = max(image.size), min(image.size) + aspect_ratio = max_hw / min_hw + shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + longest_edge = int(shortest_edge * aspect_ratio) + W, H = image.size + if H > W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + if return_pil: + return image + else: + buffered = BytesIO() + image.save(buffered, format=image_format) + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + return img_b64_str + + def get_images(self, return_pil=False): + images = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + if type(msg) is tuple: + msg, image, image_process_mode = msg + image = self.process_image(image, image_process_mode, return_pil=return_pil) + images.append(image) + return images + + def to_gradio_chatbot(self): + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + if type(msg) is tuple: + msg, image, image_process_mode = msg + img_b64_str = self.process_image( + image, "Default", return_pil=False, + image_format='JPEG') + img_str = f'user upload image' + msg = img_str + msg.replace('', '').strip() + ret.append([msg, None]) + else: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + version=self.version) + + def dict(self): + if len(self.get_images()) > 0: + return { + "system": self.system, + "roles": self.roles, + "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + return { + "system": self.system, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + +conv_vicuna_v1 = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_llava_llama_3 = Conversation( + system="""<|start_header_id|>system<|end_header_id|>\n\nA chat between a curious user and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the user's questions.""", + roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"), + version="llama3", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|eot_id|>", +) + +conv_llava_phi_3 = Conversation( + system="""<|system|>\nYou are a helpful AI assistant.""", + roles=("\n<|user|>\n", "\n<|assistant|>\n"), + version="phi3", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|end|>", +) + +default_conversation = conv_llava_phi_3 +conv_templates = { + "v1": conv_vicuna_v1, + "vicuna_v1": conv_vicuna_v1, + "llava_phi_3": conv_llava_phi_3, + "llava_llama_3": conv_llava_llama_3, +} + + +if __name__ == "__main__": + print(default_conversation.get_prompt()) diff --git a/ola_vlm/eval/.DS_Store b/ola_vlm/eval/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..e2a331e56586492f03d4e3352d6248a8c4862c5c Binary files /dev/null and b/ola_vlm/eval/.DS_Store differ diff --git a/ola_vlm/eval/eval_cv_bench.py b/ola_vlm/eval/eval_cv_bench.py new file mode 100644 index 0000000000000000000000000000000000000000..c9363884b9f58a78494f64567c427de912490297 --- /dev/null +++ b/ola_vlm/eval/eval_cv_bench.py @@ -0,0 +1,78 @@ +import pandas as pd +import json +import argparse + +def load_jsonl(f): + lines = open(f, encoding='utf-8').readlines() + lines = [x.strip() for x in lines] + if lines[-1] == '': + lines = lines[:-1] + data = [json.loads(x) for x in lines] + return data + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument("--results_file", type=str, default="cv-bench_answer.jsonl") + args = parser.parse_args() + + answers = load_jsonl(args.results_file) + + data = { + "source": [], + "result": [], + "task": [], + } + import re + for a in answers: + data["source"].append(a["source"][0]) + if "(" in a["prediction"]: + match = re.search(r'\(([A-Z])\)', a["prediction"]) + if match: + pred = "(" + match.group(1) + ")" + else: + pred = "(" + a["prediction"][0] + ")" + data["result"].append(pred == a["answer"][0]) + data["task"].append(a["task"][0]) + + df = pd.DataFrame(data) + + def calculate_accuracy(df, source): + source_df = df[df['source'] == source] + accuracy = (source_df['result']).mean() + return accuracy + + def calculate_task_accuracy(df, task): + source_df = df[df['task'] == task] + accuracy = (source_df['result']).mean() + return accuracy + + accuracy_2d_ade = calculate_accuracy(df, 'ADE20K') + accuracy_2d_coco = calculate_accuracy(df, 'COCO') + accuracy_3d_omni = calculate_accuracy(df, 'Omni3D') + + tasks = ["Count", "Depth", "Relation", "Distance"] + + scores = {} + + accuracy_2d = (accuracy_2d_ade + accuracy_2d_coco) / 2 + accuracy_3d = accuracy_3d_omni + + combined_accuracy = (accuracy_2d + accuracy_3d) / 2 + + scores["Overall"] = combined_accuracy + + scores["3D"] = accuracy_3d + scores["2D"] = accuracy_2d + + for t in tasks: + accuracy = calculate_task_accuracy(df, t) + scores[t] = accuracy + + print("\n=========================CV-Bench Scores===============================") + for key, value in scores.items(): + print(f"{key} -> {value}") + print("================================================================") + + with open(args.results_file.replace('.jsonl', '_score.json'), "w") as f: + json.dump(scores, f, indent=2) \ No newline at end of file diff --git a/ola_vlm/eval/eval_mmstar.py b/ola_vlm/eval/eval_mmstar.py new file mode 100644 index 
0000000000000000000000000000000000000000..0c1ab1b9414d40a8f773b5e1c59dbe7ec4fde96a --- /dev/null +++ b/ola_vlm/eval/eval_mmstar.py @@ -0,0 +1,17 @@ +import os +import argparse +import json + +from ola_vlm.eval.mmstar.evaluate import MMStar_eval + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--results_file', type=str, default="./playground/data/eval/mmstar_results.jsonl") + return parser.parse_args() + + +if __name__ == '__main__': + + args = parse_args() + MMStar_eval(args.results_file) diff --git a/ola_vlm/eval/eval_probe_task.py b/ola_vlm/eval/eval_probe_task.py new file mode 100644 index 0000000000000000000000000000000000000000..e18e61e8cd3142daa0b67a04a4fbc5a5acbf8d9b --- /dev/null +++ b/ola_vlm/eval/eval_probe_task.py @@ -0,0 +1,223 @@ +import argparse +import torch + +from ola_vlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from ola_vlm.conversation import conv_templates +from ola_vlm.model.builder import load_pretrained_model +from ola_vlm.utils import disable_torch_init +from ola_vlm.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path +from ola_vlm.model.aux_heads.oneformer_head import OneFormerHead +from transformers import OneFormerProcessor + +from PIL import Image +import json +import os +from tqdm import tqdm +from icecream import ic +import warnings +warnings.filterwarnings("ignore") +import random +import numpy as np +from analyze.analyze_utils import prepare_coco, prepare_da2k +import math +from diffusers import StableUnCLIPImg2ImgPipeline +from diffusers import DPMSolverMultistepScheduler + + +def split_list(lst, n): + """Split a list into n (roughly) equal-sized chunks""" + chunk_size = math.ceil(len(lst) / n) # integer division + return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] + + +def get_chunk(lst, n, k): + chunks = split_list(lst, n) + return chunks[k] + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +def load_image(image_file): + image = Image.open(image_file).convert('RGB') + return image + +import glob + +def list_image_files(directory): + image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.gif', '*.bmp', '*.tiff'] + image_files = [] + for extension in image_extensions: + image_files.extend(glob.glob(os.path.join(directory, extension))) + return image_files + +def prep_seginw(dir): + image_files = list_image_files(dir) + prompts = [] + for image_file in image_files: + prompts.append("Describe the image") + return image_files, prompts, prompts + +def predict(args): + + mode = args.mode + + name = args.model_path.split("/")[-1] + os.makedirs(f"plots/probes_task/{name}/", exist_ok=True) + + # Model + disable_torch_init() + + if mode == 'gen' or mode == 'seg': + images, prompts, answers = prepare_coco(args.json_file) + elif mode == 'depth': + images, prompts, answers = prepare_da2k("/mnt/vlpdatasets/sherlock/eval/DA-2K/DA-2K/images", is_eval=True) + + images = get_chunk(images, args.num_chunks, args.chunk_idx) + prompts = get_chunk(prompts, args.num_chunks, args.chunk_idx) + answers = get_chunk(answers, args.num_chunks, args.chunk_idx) + + model_name = get_model_name_from_path(args.model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) + + if mode == "gen": + pipe = 
StableUnCLIPImg2ImgPipeline.from_pretrained(f"playground/jiteshjain_sherlock/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16") + pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + pipe = pipe.to("cuda") + + elif mode == "seg": + oneformer_processor = OneFormerProcessor.from_pretrained("/mnt/projects4jw/jiteshjain_sherlock/oneformer_coco_swin_large") + oneformer = OneFormerHead.from_pretrained("/mnt/projects4jw/jiteshjain_sherlock/oneformer_coco_swin_large") + oneformer = oneformer.to("cuda") + + if "mistral" in model_name.lower(): + conv_mode = "mistral_instruct" + elif "v1.6-34b" in model_name.lower(): + conv_mode = "chatml_direct" + elif "llama3" in model_name.lower(): + conv_mode = "llava_llama_3" + elif "qwen" in model_name.lower(): + conv_mode = "qwen_1_5" + elif "v1" in model_name.lower(): + conv_mode = "llava_v1" + elif "phi" in model_name.lower(): + conv_mode = "llava_phi_3" + + set_seed(42) + + if mode == "gen": + try: + layers = model.config.image_gen["layer_indices"] + except: + layers = [i+1 for i in range(32)] + elif mode == "depth": + try: + layers = model.config.image_depth["layer_indices"] + except: + layers = [i+1 for i in range(32)] + elif mode == "seg": + try: + layers = model.config.image_seg["layer_indices"] + except: + layers = [i+1 for i in range(32)] + + from tqdm import tqdm + for fname, prompt, answer in tqdm(zip(images, prompts, answers), total=len(prompts)): + + conv = conv_templates[conv_mode].copy() + im = fname.split("/")[-1].split(".")[0] + + image = load_image(fname) + + image_size = image.size + image_tensor = process_images([image], image_processor, model.config) + if type(image_tensor) is list: + image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor] + else: + image_tensor = image_tensor.to(model.device, dtype=torch.float16) + + inp = prompt + if image is not None: + if model.config.mm_use_im_start_end: + inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp + else: + inp = DEFAULT_IMAGE_TOKEN + '\n' + inp + + conv.append_message(conv.roles[0], inp) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) + + with torch.inference_mode(): + out = model.get_visual_interpretations( + input_ids, + images=image_tensor, + image_sizes=image_size, + ) + + if mode == "seg": + seg_embs = out.seg_embs + inputs = oneformer_processor(image, ["semantic"], return_tensors="pt") + inputs["pixel_values"] = inputs["pixel_values"].to(out.logits.device, out.logits.dtype) + inputs["task_inputs"] = inputs["task_inputs"].to(out.logits.device, out.logits.dtype) + backbone_features = oneformer.get_backbone_feats(**inputs) + for i, seg_emb in enumerate(seg_embs): + pred = oneformer.get_masks(**inputs, backbone_last_feature=seg_emb.float(), all_backbone_features=backbone_features) + pred = oneformer_processor.post_process_semantic_segmentation( + pred, target_sizes=[image.size[::-1]] + )[0] + pred = pred.squeeze().cpu().numpy().astype(np.uint8) + pred = Image.fromarray(pred) + if not os.path.exists(f"plots/probes_task/{name}/seg/layer_{layers[i]}"): + os.makedirs(f"plots/probes_task/{name}/seg/layer_{layers[i]}", exist_ok=True) + save_path = os.path.join(f"plots/probes_task/{name}/seg/layer_{layers[i]}", fname.split("/")[-1].replace("jpg", "png")) + pred.save(save_path) + + + elif mode == "gen": + img_embeds = out.image_embs 
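        # Each probed layer yields one image embedding; the unCLIP pipeline
        # decodes it into an image that is saved per layer for inspection.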
+ images = [] + + for img_emb in img_embeds: + gen_image = pipe(image_embeds=img_emb.squeeze(1), + num_inference_steps=25, + ).images[0] + images.append(gen_image) + + for i, image in enumerate(images): + image = image.resize((256, 256), Image.LANCZOS) + if not os.path.exists(f"plots/probes_task/{name}/gen/layer_{layers[i]}"): + os.makedirs(f"plots/probes_task/{name}/gen/layer_{layers[i]}", exist_ok=True) + save_path = os.path.join(f"plots/probes_task/{name}/gen/layer_{layers[i]}", fname.split("/")[-1]) + image.save(save_path) + + elif mode == "depth": + depth_preds = out.depth_preds + + for i, depth_pred in enumerate(depth_preds): + if not os.path.exists(f"plots/probes_task/{name}/depth/layer_{layers[i]}"): + os.makedirs(f"plots/probes_task/{name}/depth/layer_{layers[i]}", exist_ok=True) + depth = depth_pred.squeeze(0).cpu().numpy() * 255.0 + depth = depth.astype(np.uint8) + depth = Image.fromarray(depth) + save_path = os.path.join(f"plots/probes_task/{name}/depth/layer_{layers[i]}", fname.split("/")[-1]) + depth.save(save_path) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="/mnt/projects4jw/jiteshjain_sherlock/llava-v1.5-7b") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--json-file", type=str, default="/mnt/projects4jw/jiteshjain_sherlock/datasets/coco/annotations/captions_val2017.json") + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--max-new-tokens", type=int, default=10) + parser.add_argument("--load-8bit", action="store_true") + parser.add_argument("--load-4bit", action="store_true") + parser.add_argument("--mode", type=str, default="gen") + parser.add_argument("--num-chunks", type=int, default=1) + parser.add_argument("--chunk-idx", type=int, default=0) + args = parser.parse_args() + predict(args) diff --git a/ola_vlm/eval/eval_sherlock_dsg.py b/ola_vlm/eval/eval_sherlock_dsg.py new file mode 100644 index 0000000000000000000000000000000000000000..44e55e712815384ab75a04db773fb7d7572c502f --- /dev/null +++ b/ola_vlm/eval/eval_sherlock_dsg.py @@ -0,0 +1,282 @@ +import argparse +import torch + +from ola_vlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from ola_vlm.conversation import conv_templates +from ola_vlm.model.builder import load_pretrained_model +from ola_vlm.utils import disable_torch_init +from ola_vlm.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path +from ola_vlm.model.aux_heads.sam_utils.build_sam import sam_model_registry +from ola_vlm.model.aux_heads.sam_utils.automatic_mask_generator import SamAutomaticMaskGenerator +from ola_vlm.model.aux_heads.oneformer_head import OneFormerHead, OneFormerSegHead, OneFormerTaskTokenSegHead +from ola_vlm.model.aux_heads.depth_anything_v2.dpt import DepthAnythingV2 +from transformers import OneFormerProcessor + +from diffusers import ( + DPMSolverMultistepScheduler, + StableUnCLIPImg2ImgPipeline, +) + +from PIL import Image +import json +import os +from tqdm import tqdm +from icecream import ic +import warnings +warnings.filterwarnings("ignore") +import random +import numpy as np +from analyze.analyze_utils import prepare_coco +import math + +def split_list(lst, n): + """Split a list into n (roughly) equal-sized chunks""" + chunk_size = math.ceil(len(lst) / n) # integer division + return [lst[i:i+chunk_size] for i in range(0, len(lst), 
chunk_size)] + + +def get_chunk(lst, n, k): + chunks = split_list(lst, n) + return chunks[k] + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +def load_image(image_file): + image = Image.open(image_file).convert('RGB') + return image + +import glob + +def list_image_files(directory): + image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.gif', '*.bmp', '*.tiff'] + image_files = [] + for extension in image_extensions: + image_files.extend(glob.glob(os.path.join(directory, extension))) + return image_files + +def get_gen_feats(pipe, image): + with torch.no_grad(): + clip_ims = pipe.feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda") + feat = pipe.image_encoder(clip_ims).image_embeds + return feat + +def get_dav2_feats(dav2, image): + image = image.resize((336, 336)) + image = np.array(image) + with torch.no_grad(): + feat = dav2.infer_image(image, is_dsg=True) + return feat[-1][0] + +def get_seg_feats(mask_generator, oneformer, oneformer_processor, seg_teacher, image): + if seg_teacher == "oneformer": + img = image.resize((768, 768)) + inputs = oneformer_processor(img, ["panoptic"], return_tensors="pt") + inputs["pixel_values"] = inputs["pixel_values"].to("cuda") + with torch.no_grad(): + feats = oneformer.forward_features(**inputs) + else: + img = np.array(image) + with torch.no_grad(): + mask_generator.predictor.set_image(img) + feats = mask_generator.predictor.features + mask_generator.predictor.reset_image() + return feats + + +def predict(args): + + mode = args.mode + + name = args.model_path.split("/")[-1] + os.makedirs(f"plots/probe_scores/{name}/", exist_ok=True) + + if "cambrian" in name: + from ola_vlm.cambrian.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN + from ola_vlm.cambrian.conversation import conv_templates, SeparatorStyle + from ola_vlm.cambrian.model.builder import load_pretrained_model + from ola_vlm.cambrian.utils import disable_torch_init + from ola_vlm.cambrian.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria + + disable_torch_init() + model_name = get_model_name_from_path(args.model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) + + if 'llama-2' in model_name.lower(): + conv_mode = "cambrian_llama_2" + elif "v1" in model_name.lower(): + conv_mode = "cambrian_v1" + elif "mpt" in model_name.lower(): + conv_mode = "mpt" + else: + conv_mode = "cambrian_v0" + + else: + from ola_vlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN + from ola_vlm.conversation import conv_templates + from ola_vlm.model.builder import load_pretrained_model + from ola_vlm.utils import disable_torch_init + from ola_vlm.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path + + disable_torch_init() + model_name = get_model_name_from_path(args.model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) + if "mistral" in model_name.lower(): + conv_mode = "mistral_instruct" + elif "v1.6-34b" in model_name.lower(): + conv_mode = "chatml_direct" + elif "llama3" in model_name.lower(): + conv_mode = "llava_llama_3" + elif "qwen" in model_name.lower(): + conv_mode = 
"llava_qwen" + elif "v1" in model_name.lower(): + conv_mode = "llava_v1" + elif "phi" in model_name.lower(): + conv_mode = "llava_phi_3" + + images, prompts, answers = prepare_coco(args.json_file) + + images = get_chunk(images, args.num_chunks, args.chunk_idx) + prompts = get_chunk(prompts, args.num_chunks, args.chunk_idx) + answers = get_chunk(answers, args.num_chunks, args.chunk_idx) + + if mode == "gen": + pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(f"playground/jiteshjain_sherlock/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16") + pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + pipe = pipe.to("cuda") + + elif mode == "seg": + oneformer_processor, oneformer, mask_generator = None, None, None + seg_teacher = model.config.image_seg.get("seg_teacher", "sam") + if seg_teacher == "sam": + sam = sam_model_registry["vit_l"](checkpoint="/mnt/projects4jw/jiteshjain_sherlock/oneformer_coco_swin_large") + sam = sam.to("cuda") + mask_generator = SamAutomaticMaskGenerator(sam.float()) + else: + oneformer_processor = OneFormerProcessor.from_pretrained("/mnt/projects4jw/jiteshjain_sherlock/oneformer_coco_swin_large") + oneformer = OneFormerHead.from_pretrained("/mnt/projects4jw/jiteshjain_sherlock/oneformer_coco_swin_large") + oneformer = oneformer.to("cuda") + + elif mode == "depth": + dav2_cfg = {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]} + dav2_backbone = DepthAnythingV2(**dav2_cfg) + dav2_backbone.load_state_dict(torch.load("/mnt/projects4jw/jiteshjain_sherlock/depth_anything_v2_vitl.pth", map_location='cpu')) + dav2_backbone = dav2_backbone.to("cuda") + + + set_seed(42) + + if mode == "gen": + try: + layers = model.config.image_gen["layer_indices"] + except: + layers = [i+1 for i in range(32)] + elif mode == "depth": + try: + layers = model.config.image_depth["layer_indices"] + except: + layers = [i+1 for i in range(32)] + elif mode == "seg": + try: + layers = model.config.image_seg["layer_indices"] + except: + layers = [i+1 for i in range(32)] + + + os.makedirs(f"plots/probe_scores/{name}/{mode}/", exist_ok=True) + + if os.path.exists(f"plots/probe_scores/{name}/{mode}/{args.num_chunks}_{args.chunk_idx}.json"): + with open(f"plots/probe_scores/{name}/{mode}/{args.num_chunks}_{args.chunk_idx}.json", 'r') as f: + diff_dict = json.load(f) + else: + diff_dict = {} + + i = 0 + from tqdm import tqdm + for fname, prompt, answer in tqdm(zip(images, prompts, answers), total=len(prompts)): + + # if fname.split("/")[-1] in diff_dict.keys(): + # continue + + conv = conv_templates[conv_mode].copy() + image = load_image(fname) + image = image.resize((640, 640)) + + image_size = image.size + + image_tensor = process_images([image], image_processor, model.config) + if type(image_tensor) is list: + image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor] + else: + image_tensor = image_tensor.to(model.device, dtype=torch.float16) + + inp = prompt + if image is not None: + if model.config.mm_use_im_start_end: + inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp + else: + inp = DEFAULT_IMAGE_TOKEN + '\n' + inp + + conv.append_message(conv.roles[0], inp) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) + + with torch.inference_mode(): + out = model.get_visual_interpretations( + input_ids, + 
images=image_tensor, + image_sizes=[image_size], + ) + + if mode == "gen": + embeds = out.image_embs + feats = get_gen_feats(pipe, image) + elif mode == "depth": + embeds = out.depth_embs + embeds = [emb[0][0] for emb in embeds] + feats = get_dav2_feats(dav2_backbone, image) + elif mode == "seg": + embeds = out.seg_embs + feats = get_seg_feats(mask_generator, oneformer, oneformer_processor, seg_teacher, image) + + layer_diff = {} + for i, emb in enumerate(embeds): + emb = emb.to("cuda") + layer_diff[layers[i]] = torch.nn.CosineEmbeddingLoss(reduction="mean")( + emb.reshape(1, -1).float(), feats.reshape(1, -1).float(), + torch.ones(len(emb)).to(feats.device) + ).cpu().item() + from icecream import ic + ic(layer_diff[layers[i]]) + diff_dict[fname.split("/")[-1]] = layer_diff + + if i % 200 == 0: + # Save progress intermittently + with open(f"plots/probe_scores/{name}/{mode}/{args.num_chunks}_{args.chunk_idx}.json", 'w') as f: + json.dump(diff_dict, f, indent=2) + + i += 1 + + with open(f"plots/probe_scores/{name}/{mode}/{args.num_chunks}_{args.chunk_idx}.json", 'w') as f: + json.dump(diff_dict, f, indent=2) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="/mnt/projects4jw/jiteshjain_sherlock/llava-v1.5-7b") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--json-file", type=str, default="/mnt/projects4jw/jiteshjain_sherlock/datasets/coco/annotations/captions_val2017.json") + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--max-new-tokens", type=int, default=10) + parser.add_argument("--load-8bit", action="store_true") + parser.add_argument("--load-4bit", action="store_true") + parser.add_argument("--mode", type=str, default="gen") + parser.add_argument("--num-chunks", type=int, default=1) + parser.add_argument("--chunk-idx", type=int, default=0) + args = parser.parse_args() + predict(args) diff --git a/ola_vlm/eval/get_all_stats.py b/ola_vlm/eval/get_all_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..7739fb1254c2a8107ba03bb9fcef696e07b30ef1 --- /dev/null +++ b/ola_vlm/eval/get_all_stats.py @@ -0,0 +1,132 @@ +import json +import argparse +from icecream import ic +import os +import numpy as np + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--results_folder", type=str, default="./playground/data/eval/results") + parser.add_argument("--ckpt", type=str) + args = parser.parse_args() + + scores = {} + + dirs = os.listdir(f"{args.results_folder}/{args.ckpt}") + for dir in dirs: + if args.ckpt in dir and dir not in args.ckpt: + break + + + try: + with open(f"{args.results_folder}/{args.ckpt}/mmstar/merge_score.json", "r") as f: + data = json.load(f) + scores["MMStar"] = round(data.get("final score", 0)*100, 1) if data.get("final score") is not None else None + except: + scores["MMStar"] = None + + cv_scores = {} + + with open(f"{args.results_folder}/{args.ckpt}/cv-bench/merge_score.json", "r") as f: + data = json.load(f) + scores["CV-Bench"] = round(data.get("Overall", 0)*100, 1) if data.get("Overall") is not None else None + cv_scores["CV-Bench (2D)"] = round(data.get("2D", 0)*100, 1) if data.get("2D") is not None else None + cv_scores["CV-Bench (3D)"] = round(data.get("3D", 0)*100, 1) if data.get("3D") is not None else None + cv_scores["CV-Bench (Count)"] = round(data.get("Count", 0)*100, 1) if data.get("Count") is not 
None else None + cv_scores["CV-Bench (Depth)"] = round(data.get("Depth", 0)*100, 1) if data.get("Depth") is not None else None + cv_scores["CV-Bench (Relation)"] = round(data.get("Relation", 0)*100, 1) if data.get("Relation") is not None else None + cv_scores["CV-Bench (Distance)"] = round(data.get("Distance", 0)*100, 1) if data.get("Distance") is not None else None + + + with open(f"{args.results_folder}/{args.ckpt}/{dir}/results.json", "r") as f: + results = json.load(f).get("results", {}) + # scores["MME-Cognition"] = round(results.get("mme", {}).get("mme_cognition_score,none", 0), 1) if results.get("mme", {}).get("mme_cognition_score,none") is not None else None + # scores["MME-Perception"] = round(results.get("mme", {}).get("mme_percetion_score,none", 0), 1) if results.get("mme", {}).get("mme_percetion_score,none") is not None else None + + scores["Realworld-QA"] = round(results.get("realworldqa", {}).get("exact_match,flexible-extract", 0)*100, 1) if results.get("realworldqa", {}).get("exact_match,flexible-extract") is not None else None + scores["VizWiz-VQA-Val"] = round(results.get("vizwiz_vqa_val", {}).get("exact_match,none", 0)*100, 1) if results.get("vizwiz_vqa_val", {}).get("exact_match,none") is not None else None + # scores["SEEDBench-Image"] = round(results.get("seedbench", {}).get("seed_image,none", 0)*100, 1) if results.get("seedbench", {}).get("seed_image,none") is not None else None + # scores["VQAv2-Val"] = round(results.get("vqav2_val", {}).get("exact_match,none", 0)*100, 1) if results.get("vqav2_val", {}).get("exact_match,none") is not None else None + + # scores["Science-QA-Img"] = round(results.get("scienceqa_img", {}).get("exact_match,none", 0)*100, 1) if results.get("scienceqa_img", {}).get("exact_match,none") is not None else None + scores["MMMU-Val"] = round(results.get("mmmu_val", {}).get("mmmu_acc,none", 0)*100, 1) if results.get("mmmu_val", {}).get("mmmu_acc,none") is not None else None + # scores["MMBench"] = round(results.get("mmbench_en_dev", {}).get("gpt_eval_score,none", 0), 1) if results.get("mmbench_en_dev", {}).get("gpt_eval_score,none") is not None else None + + # scores["NaturalBench"] = round(results.get("naturalbench", {}).get("mme_score,none", 0)*100, 1) if results.get("naturalbench", {}).get("mme_score,none") is not None else None + + # scores["GQA"] = round(results.get("gqa", {}).get("exact_match,none", 0)*100, 1) if results.get("gqa", {}).get("exact_match,none") is not None else None + scores["POPE"] = round(results.get("pope", {}).get("pope_accuracy,none", 0)*100, 1) if results.get("pope", {}).get("pope_accuracy,none") is not None else None + scores["MMVet"] = round(results.get("mmvet", {}).get("gpt_eval_score", 0)*100, 1) if results.get("mmvet", {}).get("gpt_eval_score") is not None else None + scores["OK-VQA"] = round(results.get("ok_vqa", {}).get("exact_match,none", 0)*100, 1) if results.get("ok_vqa", {}).get("exact_match,none") is not None else None + # scores["ChartQA"] = round(results.get("chartqa", {}).get("relaxed_overall,none", 0)*100, 1) if results.get("chartqa", {}).get("relaxed_overall,none") is not None else None + # scores["DocVQA"] = round(results.get("docvqa_val", {}).get("anls,none", 0)*100, 1) if results.get("docvqa_val", {}).get("anls,none") is not None else None + # scores["TextVQA"] = round(results.get("textvqa_val", {}).get("exact_match,none", 0)*100, 1) if results.get("textvqa_val", {}).get("exact_match,none") is not None else None + + try: + with open(f"{args.results_folder}/{args.ckpt}/mmvp/merge_score.json", "r") as 
f: + data = json.load(f) + scores["MMVP"] = round(data.get("mmvp", 0)*100, 1) if data.get("mmvp") is not None else None + except: + scores["MMVP"] = None + + keys = list(scores.keys()) + str_scores = [str(scores[key]) if scores[key] is not None else 'None' for key in keys] + + abl_keys = ["CV-Bench", "MMStar", "VizWiz-VQA-Val", "MMVet", "MMVP", "MMMU-Val"] + + abl_scores = [scores[key] for key in abl_keys if scores[key] is not None] + + small_abl_keys = ["CV-Bench", "MMStar", "OK-VQA", "MMMU-Val"] + small_abl_scores = [scores[key] for key in small_abl_keys if scores[key] is not None] + + cv_bench_keys = ["CV-Bench (2D)", "CV-Bench (3D)", "CV-Bench (Count)", "CV-Bench (Depth)", "CV-Bench (Relation)", "CV-Bench (Distance)"] + cv_bench_scores = [cv_scores[key] for key in cv_bench_keys if cv_scores[key] is not None] + + # cat_scores = {} + # if os.path.exists(f"{args.results_folder}/{args.ckpt}/categorized_scores.json"): + # with open(f"{args.results_folder}/{args.ckpt}/categorized_scores.json", "r") as f: + # cat_scores = json.load(f) + # cat_scores.pop("Both") + + print("\n====================All-Scores===========================================") + print(" & ".join(keys)) + print(" & ".join(str_scores)) + if abl_scores: + print("\n====================Abl-Scores===========================================") + print(" & ".join(abl_keys)) + print(" & ".join([str(a) for a in abl_scores])) + print(f"Ablation Avg: {round(np.mean(abl_scores), 1)}") + else: + print("Ablation Avg: None") + + if small_abl_scores: + print("\n====================Small-Abl-Scores===========================================") + print(" & ".join(small_abl_keys)) + print(" & ".join([str(a) for a in small_abl_scores])) + print(f"Small-Ablation Avg: {round(np.mean(small_abl_scores), 1)}") + else: + print("Small-Ablation Avg: None") + + if cv_bench_scores: + print("\n====================CV-Bench-Scores===========================================") + print(" & ".join(cv_bench_keys)) + print(" & ".join([str(c) for c in cv_bench_scores])) + print(f"CV-Bench Overall: {round(np.mean(cv_bench_scores[:2]), 1)}") + else: + print("CV-Bench Avg: None") + + # if cat_scores is not None: + # print("\n====================Categorized-Scores===========================================") + # cats = [] + # class_scores = [] + # benches = [] + # for k, v in cat_scores.items(): + # cats.append(k) + # for bench, score in v.items(): + # benches.append(bench) + # class_scores.append(round(score*100, 1)) + # print(" & ".join(cats)) + # print(" & ".join(benches)) + # print(" & ".join([str(c) for c in class_scores])) + # print("================================================================") diff --git a/ola_vlm/eval/get_probe_task_scores.py b/ola_vlm/eval/get_probe_task_scores.py new file mode 100644 index 0000000000000000000000000000000000000000..a9030345f358de7f0b54c6347ef264768eae032e --- /dev/null +++ b/ola_vlm/eval/get_probe_task_scores.py @@ -0,0 +1,197 @@ +import argparse +import torch +from PIL import Image +import json +import os +from tqdm import tqdm +import warnings +import random +import numpy as np +import multiprocessing as mp +from ola_vlm.eval.probe_metrics.fid_score import compute_fid +from analyze.analyze_utils import prepare_coco, prepare_da2k, parse_json +from multiprocessing import Pool +warnings.filterwarnings("ignore") + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +def load_image(image_file): + image = Image.open(image_file) + return image + +def 
mask_iou(gt, pred): + gt = np.array(gt).astype(np.uint8) + pred = np.array(pred).astype(np.uint8) + + iou_scores = [] + for category in np.unique(gt): + if category == 255: + continue + gt_mask = (gt == category) + pred_mask = (pred == category) + + intersection = np.logical_and(gt_mask, pred_mask) + union = np.logical_or(gt_mask, pred_mask) + if np.sum(union) == 0: + iou_scores.append(1.0) + else: + iou_scores.append(np.sum(intersection) / np.sum(union)) + + return np.mean(iou_scores) + +def load_json(path): + with open(path) as f: + data = json.load(f) + return data + +# Helper function for multiprocessing in evaluate_seg +def process_iou(args): + gt_path, layer_folder, dir, fname = args + gt_data = load_image(os.path.join(gt_path, fname.replace("jpg", "png"))) + pred = load_image(os.path.join(layer_folder, dir, fname)) + return mask_iou(gt_data, pred) + +def evaluate_seg(args): + images, _, _ = prepare_coco("/mnt/vlpdatasets/coco/annotations/captions_val2017.json") + fnames = [img.split("/")[-1] for img in images][:8] + + name = args.ckpt + gt_path = "/mnt/vlpdatasets/sherlock/eval/coco/annotations/panoptic_semseg_val2017" + layer_folder = f"plots/probes_task/{name}/seg" + + scores = {"m_iou": []} + dirs = os.listdir(layer_folder) + + with mp.Pool() as pool: + for dir in dirs: + print(f"Evaluating mask iou for {dir}") + args_list = [(gt_path, layer_folder, dir, fname) for fname in fnames] + m_iou = list(tqdm(pool.imap(process_iou, args_list), total=len(args_list), desc=f"Processing {dir}")) + scores["m_iou"].append({dir: round(np.mean(m_iou) * 100, 2)}) + + return scores + +# Helper function for multiprocessing in evaluate_depth +def process_depth(args): + depth_map, point_1, point_2, answer = args + return score_points(depth_map, point_1, point_2, answer) + +def score_points(depth_map, point_1, point_2, answer): + pt1_depth = depth_map[point_1[0], point_1[1]] + pt2_depth = depth_map[point_2[0], point_2[1]] + + if isinstance(pt1_depth, np.ndarray): + pt1_depth = pt1_depth.mean() + if isinstance(pt2_depth, np.ndarray): + pt2_depth = pt2_depth.mean() + + return (answer == "point2") if pt1_depth < pt2_depth else (answer == "point1") + +def load_and_process_image(args): + folder, fname, entry = args + gt_path = os.path.join("/mnt/vlpdatasets/sherlock/plots/dav2_da2k", fname.split("/")[-1].split(".")[0] + ".jpg") + pred_path = os.path.join(folder, fname.split("/")[-1]) + + gt = load_image(gt_path) + pred = load_image(pred_path) + pred = pred.resize(gt.size) + pred = np.array(pred) / 255.0 + + # Process depth for each entry within the image + return [process_depth((pred, entry["point1"], entry["point2"], entry["closer_point"])) for entry in entry["entries"]] + +def score_da2k_parallel(folder, anns): + pred_scores = [] + tasks = [(folder, fname, {"entries": entries}) for fname, entries in anns.items()] + + with Pool() as pool: + results = list(tqdm(pool.imap(load_and_process_image, tasks), total=len(tasks), desc="Processing images")) + for res in results: + if res is not None: + pred_scores.extend(res) + + return np.mean(pred_scores) if pred_scores else 0 + +def evaluate_depth(args): + anns = parse_json("/mnt/vlpdatasets/sherlock/eval/DA-2K/DA-2K/annotations.json") + + name = args.ckpt + layer_folder = f"plots/probes_task/{name}/depth" + + scores = {"da2k_acc": []} + dirs = os.listdir(layer_folder) + + for dir in dirs: + print(f"Evaluating da2k_acc for {dir}") + pred_scores = score_da2k_parallel(os.path.join(layer_folder, dir), anns) + scores["da2k_acc"].append({dir: round(pred_scores * 100, 
2)}) + + return scores + +def evaluate_fid(args): + name = args.ckpt + gt_path = os.path.join("plots/coco_gt") + layer_folder = f"plots/probes_task/{name}/gen" + + scores = {"fid": []} + dirs = os.listdir(layer_folder) + + for dir in dirs: + print(f"Evaluating fid for {dir}") + paths = [gt_path, os.path.join(layer_folder, dir)] + fid_score = compute_fid(paths) + scores["fid"].append({dir.replace("_", "-"): round(fid_score, 2)}) + + return scores + +import re + +def print_sorted_scores(scores, metric_name): + # Extract numeric part from layer names for sorting + sorted_scores = sorted(scores[metric_name], key=lambda x: int(re.search(r'\d+', list(x.keys())[0]).group())) + + layers = [list(score.keys())[0] for score in sorted_scores] + values = [list(score.values())[0] for score in sorted_scores] + + # Print sorted layers and scores in the requested format + print("\n=========================Results===============================") + print(" & ".join(layers)) + print(" & ".join([f"{value}" for value in values])) + print(f"Average score: {round(np.mean(values), 2)}") + print("================================================================") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt", type=str, default="llava-1.5-7b") + parser.add_argument("--mode", type=str, default="gen") + args = parser.parse_args() + + mode = args.mode + + if mode == "gen": + scores = evaluate_fid(args) + + print("\n=========================FID-Scores===============================") + for score in scores["fid"]: + for key, value in score.items(): + print(f"{key} -> {value}") + print("================================================================") + + elif mode == "seg": + scores = evaluate_seg(args) + + print("\n=========================Mask-IOU===============================") + print_sorted_scores(scores, "m_iou") + + elif mode == "depth": + scores = evaluate_depth(args) + + print("\n=========================DA2K-Acc===============================") + print_sorted_scores(scores, "da2k_acc") + + else: + print("Invalid mode. 
Choose from [gen, seg, depth]") diff --git a/ola_vlm/eval/get_sherlock_dsg_scores.py b/ola_vlm/eval/get_sherlock_dsg_scores.py new file mode 100644 index 0000000000000000000000000000000000000000..dea68ca0eb168f33f4bf24b7dccdac5b4b0db362 --- /dev/null +++ b/ola_vlm/eval/get_sherlock_dsg_scores.py @@ -0,0 +1,49 @@ +import argparse +import torch + +import json +import os +from tqdm import tqdm +from icecream import ic +import warnings +warnings.filterwarnings("ignore") +import random +import numpy as np + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt", type=str, default="llava-1.5-7b") + parser.add_argument("--mode", type=str, default="gen") + args = parser.parse_args() + + mode = args.mode + name = args.ckpt.split("/")[-1] + + with open(f'plots/probe_scores/{name}/{args.mode}.json') as file: + scores = json.load(file) + + layer_scores = {} + + for img, v in tqdm(scores.items()): + for layer, score in v.items(): + if layer not in layer_scores: + layer_scores[layer] = [] + layer_scores[layer].append(score) + + for layer, scores in layer_scores.items(): + layer_scores[layer] = np.mean(scores) + + with open(f"plots/probe_scores/{name}/{mode}_scores.json", "w") as f: + json.dump(layer_scores, f, indent=2) + + print(f"================Scores: {mode}===============") + for layer, score in layer_scores.items(): + print(f"Layer: {layer}, Score: {score}") + print("===========================================") \ No newline at end of file diff --git a/ola_vlm/eval/merge_json.py b/ola_vlm/eval/merge_json.py new file mode 100644 index 0000000000000000000000000000000000000000..ad1a1b64b41cc6ae1f7ccd053ceae5b59a20350b --- /dev/null +++ b/ola_vlm/eval/merge_json.py @@ -0,0 +1,30 @@ +import os +import json +import argparse + +parser = argparse.ArgumentParser( + description='Probe eval') +parser.add_argument('--ckpt', + help='ckpt', + default='probe_llava-1.5-vicuna-7b-lr-1e-3') +parser.add_argument('--mode', + help='mode', + default='gen') +parser.add_argument("--num-chunks", type=int, default=1) + + +def save_merged_json(data, output_file): + with open(output_file, 'w') as file: + json.dump(data, file, indent=4) + +if __name__ == "__main__": + args = parser.parse_args() + merge_data = {} + name = args.ckpt.split("/")[-1] + + for i in range(args.num_chunks): + with open(f'plots/probe_scores/{name}/{args.mode}/{args.num_chunks}_{i}.json', 'r') as file: + data = json.load(file) + merge_data.update(data) + + save_merged_json(merge_data, f'plots/probe_scores/{name}/{args.mode}.json') \ No newline at end of file diff --git a/ola_vlm/eval/mmstar/evaluate/__init__.py b/ola_vlm/eval/mmstar/evaluate/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d666890a65281af4b3c3128147bdc026978ca86 --- /dev/null +++ b/ola_vlm/eval/mmstar/evaluate/__init__.py @@ -0,0 +1 @@ +from .mmstar import MMStar_eval \ No newline at end of file diff --git a/ola_vlm/eval/mmstar/evaluate/__pycache__/__init__.cpython-310.pyc b/ola_vlm/eval/mmstar/evaluate/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cafa575527eca7dde70f723e4de17ee1ddfb1485 Binary files /dev/null and b/ola_vlm/eval/mmstar/evaluate/__pycache__/__init__.cpython-310.pyc differ diff --git a/ola_vlm/eval/mmstar/evaluate/__pycache__/mmstar.cpython-310.pyc b/ola_vlm/eval/mmstar/evaluate/__pycache__/mmstar.cpython-310.pyc 
new file mode 100644 index 0000000000000000000000000000000000000000..76de69b50be6313d408872f30121b22cc2ef9f53 Binary files /dev/null and b/ola_vlm/eval/mmstar/evaluate/__pycache__/mmstar.cpython-310.pyc differ diff --git a/ola_vlm/eval/mmstar/evaluate/mmstar.py b/ola_vlm/eval/mmstar/evaluate/mmstar.py new file mode 100644 index 0000000000000000000000000000000000000000..0f289617677df823755864d077ddad7b70b02058 --- /dev/null +++ b/ola_vlm/eval/mmstar/evaluate/mmstar.py @@ -0,0 +1,87 @@ +from ola_vlm.eval.mmstar.smp import * +from copy import deepcopy + + +def MMStar_eval(eval_file): + MMStar_score_l2 = { + 'coarse perception': { + 'image scene and topic': 0, + 'image style & quality': 0, + 'image emotion': 0 + }, + 'fine-grained perception': { + 'object counting': 0, + 'recognition': 0, + 'localization': 0 + }, + 'instance reasoning': { + 'single-instance reasoning': 0, + 'cross-instance attribute reasoning': 0, + 'cross-instance relation reasoning': 0 + }, + 'logical reasoning': { + 'code & sequence reasoning': 0, + 'diagram reasoning': 0, + 'common reasoning': 0 + }, + 'science & technology': { + 'biology & chemistry & physics': 0, + 'electronics & energy & mechanical eng.': 0, + 'geography & earth science & agriculture': 0 + }, + 'math': { + 'geometry': 0, + 'numeric commonsense and calculation': 0, + 'statistical reasoning': 0 + }, + } + MMStar_counter = deepcopy(MMStar_score_l2) + logger = get_logger('Evaluation') + + data = load(eval_file) + lt = len(data) + lines = [data[i] for i in range(lt)] + for i in tqdm(range(len(lines))): + line = lines[i] + predict = str(line['prediction']) + answers = str(line['answer']) + category = str(line['category']) + l2_category = str(line['l2_category']) + MMStar_counter[category][l2_category] += 1 + + answer = answers.lower().strip().replace('\n', ' ') + predict = predict.lower().strip().replace('\n', ' ') + + try: + if answer == predict[0]: + MMStar_score_l2[category][l2_category] += 1 + elif predict[0] == '(' and answer == predict[1]: + MMStar_score_l2[category][l2_category] += 1 + elif predict[0:7] == 'option ' and answer == predict[7]: + MMStar_score_l2[category][l2_category] += 1 + elif predict[0:14] == 'the answer is ' and answer == predict[14]: + MMStar_score_l2[category][l2_category] += 1 + except Exception as e: + pass + + MMStar_score = {} + MMStar_score['final score'] = 0 + for k, v in MMStar_score_l2.items(): + MMStar_score[k] = 0 + for l2_k, l2_v in v.items(): + MMStar_score[f'{k}({l2_k})'] = float(l2_v) / \ + float(MMStar_counter[k][l2_k]) + MMStar_score[k] += l2_v + MMStar_score['final score'] += MMStar_score[k] + MMStar_score[k] = float(MMStar_score[k]) / 250.0 + MMStar_score['final score'] = float(MMStar_score['final score']) / 1500.0 + + score_pth = eval_file.replace('.jsonl', '_score.json') + dump(MMStar_score, score_pth) + logger.info( + f'MMStar_eval successfully finished evaluating {eval_file}, results saved in {score_pth}') + logger.info('Score: ') + for key, value in MMStar_score.items(): + logger.info('{}:{}'.format(key, value)) + + return MMStar_score diff --git a/ola_vlm/eval/mmstar/smp/__init__.py b/ola_vlm/eval/mmstar/smp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0dd0b4b63be558b255e1b0d755cd3b70d7423157 --- /dev/null +++ b/ola_vlm/eval/mmstar/smp/__init__.py @@ -0,0 +1,3 @@ +from .file import * +from .misc import * +from .log import * \ No newline at end of file diff --git a/ola_vlm/eval/mmstar/smp/__pycache__/__init__.cpython-310.pyc 
b/ola_vlm/eval/mmstar/smp/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..426196d0abe503ba7657471fd34b0071232a7040 Binary files /dev/null and b/ola_vlm/eval/mmstar/smp/__pycache__/__init__.cpython-310.pyc differ diff --git a/ola_vlm/eval/mmstar/smp/__pycache__/file.cpython-310.pyc b/ola_vlm/eval/mmstar/smp/__pycache__/file.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..398c8b58709c0b4df36aaf335e39c244ee20f47c Binary files /dev/null and b/ola_vlm/eval/mmstar/smp/__pycache__/file.cpython-310.pyc differ diff --git a/ola_vlm/eval/mmstar/smp/__pycache__/log.cpython-310.pyc b/ola_vlm/eval/mmstar/smp/__pycache__/log.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e6c2a9113fa1ecf1bf91d9db4ea796f6fc3f084 Binary files /dev/null and b/ola_vlm/eval/mmstar/smp/__pycache__/log.cpython-310.pyc differ diff --git a/ola_vlm/eval/mmstar/smp/__pycache__/misc.cpython-310.pyc b/ola_vlm/eval/mmstar/smp/__pycache__/misc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16a70f874e38d0a54af8e64e6294033c59bdbc3d Binary files /dev/null and b/ola_vlm/eval/mmstar/smp/__pycache__/misc.cpython-310.pyc differ diff --git a/ola_vlm/eval/mmstar/smp/__pycache__/vlm.cpython-310.pyc b/ola_vlm/eval/mmstar/smp/__pycache__/vlm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bd1aba92b034b3756bfcc74644d46837e74d598 Binary files /dev/null and b/ola_vlm/eval/mmstar/smp/__pycache__/vlm.cpython-310.pyc differ diff --git a/ola_vlm/eval/mmstar/smp/file.py b/ola_vlm/eval/mmstar/smp/file.py new file mode 100644 index 0000000000000000000000000000000000000000..f7e67ea790659a167417e9d85794c4bc5b99d95b --- /dev/null +++ b/ola_vlm/eval/mmstar/smp/file.py @@ -0,0 +1,147 @@ +import csv +import hashlib +import json +import os +import os.path as osp +import pickle +import time + +import numpy as np +import pandas as pd + + +class NumpyEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, (np.int_, np.intc, np.intp, np.int8, + np.int16, np.int32, np.int64, np.uint8, + np.uint16, np.uint32, np.uint64)): + return int(obj) + elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)): + return float(obj) + elif isinstance(obj, (np.complex_, np.complex64, np.complex128)): + return {'real': obj.real, 'imag': obj.imag} + elif isinstance(obj, (np.ndarray,)): + return obj.tolist() + elif isinstance(obj, (np.bool_)): + return bool(obj) + elif isinstance(obj, (np.void)): + return None + return json.JSONEncoder.default(self, obj) + +# LOAD & DUMP +def dump(data, f, **kwargs): + def dump_pkl(data, pth, **kwargs): + pickle.dump(data, open(pth, 'wb')) + + def dump_json(data, pth, **kwargs): + json.dump(data, open(pth, 'w'), indent=4, ensure_ascii=False, cls=NumpyEncoder) + + def dump_jsonl(data, f, **kwargs): + lines = [json.dumps(x, ensure_ascii=False, cls=NumpyEncoder) for x in data] + with open(f, 'w', encoding='utf8') as fout: + fout.write('\n'.join(lines)) + + def dump_xlsx(data, f, **kwargs): + data.to_excel(f, index=False, engine='xlsxwriter') + + def dump_csv(data, f, quoting=csv.QUOTE_ALL): + data.to_csv(f, index=False, encoding='utf-8', quoting=quoting) + + def dump_tsv(data, f, quoting=csv.QUOTE_ALL): + data.to_csv(f, sep='\t', index=False, encoding='utf-8', quoting=quoting) + + handlers = dict(pkl=dump_pkl, json=dump_json, jsonl=dump_jsonl, xlsx=dump_xlsx, csv=dump_csv, tsv=dump_tsv) + suffix = f.split('.')[-1] + return 
handlers[suffix](data, f, **kwargs) + +def load(f): + def load_pkl(pth): + return pickle.load(open(pth, 'rb')) + + def load_json(pth): + return json.load(open(pth, 'r', encoding='utf-8')) + + def load_jsonl(f): + lines = open(f, encoding='utf-8').readlines() + lines = [x.strip() for x in lines] + if lines[-1] == '': + lines = lines[:-1] + data = [json.loads(x) for x in lines] + return data + + def load_xlsx(f): + return pd.read_excel(f) + + def load_csv(f): + return pd.read_csv(f) + + def load_tsv(f): + return pd.read_csv(f, sep='\t') + + handlers = dict(pkl=load_pkl, json=load_json, jsonl=load_jsonl, xlsx=load_xlsx, csv=load_csv, tsv=load_tsv) + suffix = f.split('.')[-1] + return handlers[suffix](f) + +def download_file(url, filename=None): + import urllib.request + + from tqdm import tqdm + + class DownloadProgressBar(tqdm): + def update_to(self, b=1, bsize=1, tsize=None): + if tsize is not None: + self.total = tsize + self.update(b * bsize - self.n) + + if filename is None: + filename = url.split('/')[-1] + + with DownloadProgressBar(unit='B', unit_scale=True, + miniters=1, desc=url.split('/')[-1]) as t: + urllib.request.urlretrieve(url, filename=filename, reporthook=t.update_to) + return filename + +def ls(dirname='.', match='', mode='all', level=1): + if dirname == '.': + ans = os.listdir(dirname) + else: + ans = [osp.join(dirname, x) for x in os.listdir(dirname)] + assert mode in ['all', 'dir', 'file'] + assert level >= 1 and isinstance(level, int) + if level == 1: + ans = [x for x in ans if match in x] + if mode == 'dir': + ans = [x for x in ans if osp.isdir(x)] + elif mode == 'file': + ans = [x for x in ans if not osp.isdir(x)] + else: + ans = [x for x in ans if osp.isdir(x)] + res = [] + for d in ans: + res.extend(ls(d, match=match, mode=mode, level=level-1)) + ans = res + return ans + +def mrlines(fname, sp='\n'): + f = open(fname).read().split(sp) + while f != [] and f[-1] == '': + f = f[:-1] + return f + +def mwlines(lines, fname): + with open(fname, 'w') as fout: + fout.write('\n'.join(lines)) + +def md5(file_pth): + with open(file_pth, 'rb') as f: + hash = hashlib.new('md5') + for chunk in iter(lambda: f.read(2**20), b''): + hash.update(chunk) + return str(hash.hexdigest()) + +def last_modified(pth): + stamp = osp.getmtime(pth) + m_ti = time.ctime(stamp) + t_obj = time.strptime(m_ti) + t = time.strftime('%Y%m%d%H%M%S', t_obj)[2:] + return t diff --git a/ola_vlm/eval/mmstar/smp/log.py b/ola_vlm/eval/mmstar/smp/log.py new file mode 100644 index 0000000000000000000000000000000000000000..53194dd995965455320d4c271713a9bd4465f43a --- /dev/null +++ b/ola_vlm/eval/mmstar/smp/log.py @@ -0,0 +1,43 @@ +import logging + +logger_initialized = {} + +def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'): + logger = logging.getLogger(name) + if name in logger_initialized: + return logger + + for logger_name in logger_initialized: + if name.startswith(logger_name): + return logger + + stream_handler = logging.StreamHandler() + handlers = [stream_handler] + + try: + import torch.distributed as dist + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + else: + rank = 0 + except ImportError: + rank = 0 + + if rank == 0 and log_file is not None: + file_handler = logging.FileHandler(log_file, file_mode) + handlers.append(file_handler) + + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s') + for handler in handlers: + handler.setFormatter(formatter) + handler.setLevel(log_level) + logger.addHandler(handler) + + if 
rank == 0: + logger.setLevel(log_level) + else: + logger.setLevel(logging.ERROR) + + logger_initialized[name] = True + return logger \ No newline at end of file diff --git a/ola_vlm/eval/mmstar/smp/misc.py b/ola_vlm/eval/mmstar/smp/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..c775c6f3f4725dd09cc624c5b7792d5f4758b2b4 --- /dev/null +++ b/ola_vlm/eval/mmstar/smp/misc.py @@ -0,0 +1,174 @@ +# flake8: noqa: F401, F403 +import abc +import argparse +import copy as cp +import csv +import datetime +import multiprocessing as mp +import os +import os.path as osp +import random as rd +import shutil +import subprocess +import warnings +from collections import OrderedDict, defaultdict +from multiprocessing import Pool, current_process + +import matplotlib.pyplot as plt +import pandas as pd +import requests +import seaborn as sns +from huggingface_hub import scan_cache_dir +from sty import bg, ef, fg, rs +from tabulate import tabulate, tabulate_formats +from tqdm import tqdm + + +def process_punctuation(inText): + import re + outText = inText + punct = [ + ';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-', + '>', '<', '@', '`', ',', '?', '!' + ] + commaStrip = re.compile('(\d)(,)(\d)') # noqa: W605 + periodStrip = re.compile('(?!<=\d)(\.)(?!\d)') # noqa: W605 + for p in punct: + if (p + ' ' in inText or ' ' + p in inText) or (re.search( + commaStrip, inText) is not None): + outText = outText.replace(p, '') + else: + outText = outText.replace(p, ' ') + outText = periodStrip.sub('', outText, re.UNICODE) + return outText + + +def h2r(value): + if value[0] == '#': + value = value[1:] + assert len(value) == 6 + return tuple(int(value[i:i + 2], 16) for i in range(0, 6, 2)) + + +def r2h(rgb): + return '#%02x%02x%02x' % rgb + + +def colored(s, color): + if isinstance(color, str): + if hasattr(fg, color): + return getattr(fg, color) + s + fg.rs + color = h2r(color) + return fg(*color) + s + fg.rs + + +def istype(s, type): + if isinstance(s, type): + return True + try: + return isinstance(eval(s), type) + except Exception as _: + return False + + +def bincount(lst): + bins = defaultdict(lambda: 0) + for item in lst: + bins[item] += 1 + return bins + + +def get_cache_path(repo_id): + hf_cache_info = scan_cache_dir() + repos = list(hf_cache_info.repos) + repo = None + for r in repos: + if r.repo_id == repo_id: + repo = r + break + if repo is None: + return None + revs = list(repo.revisions) + rev2keep, last_modified = None, 0 + for rev in revs: + if rev.last_modified > last_modified: + rev2keep, last_modified = rev, rev.last_modified + if rev2keep is None: + return None + return str(rev2keep.snapshot_path) + + +def proxy_set(s): + import os + for key in ['http_proxy', 'HTTP_PROXY', 'https_proxy', 'HTTPS_PROXY']: + os.environ[key] = s + + +def get_rank_and_world_size(): + local_rank = int(os.environ.get("RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + return local_rank, world_size + + +def get_local_rank_and_world_size(): + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + return local_rank, world_size + + +def splitlen(s, sym='/'): + return len(s.split(sym)) + + +def listinstr(lst, s): + assert isinstance(lst, list) + for item in lst: + if item in s: + return True + return False + + +def d2df(D): + return pd.DataFrame({x: [D[x]] for x in D}) + + +def cn_string(s): + import re + if re.search(u'[\u4e00-\u9fff]', s): + return True + return False + + +try: + import decord +except ImportError: + pass 
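+# NOTE: decord (a video-decoding library) is treated as optional here; the import is skipped silently if it is not installed.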
+ + +def timestr(second=True, minute=False): + s = datetime.datetime.now().strftime('%Y%m%d%H%M%S')[2:] + if second: + return s + elif minute: + return s[:-2] + else: + return s[:-4] + + +def dict_merge(dct, merge_dct): + for k, _ in merge_dct.items(): + if (k in dct and isinstance(dct[k], dict) and isinstance(merge_dct[k], dict)): # noqa + dict_merge(dct[k], merge_dct[k]) + else: + dct[k] = merge_dct[k] + + +def youtube_dl(idx): + cmd = f'youtube-dl -f best -f mp4 "{idx}" -o {idx}.mp4' + os.system(cmd) + + +def run_command(cmd): + if isinstance(cmd, str): + cmd = cmd.split() + return subprocess.check_output(cmd) diff --git a/ola_vlm/eval/model_cvbench_loader.py b/ola_vlm/eval/model_cvbench_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..fe0e729996baa932691df8ee0efa3d0aa0a4be2f --- /dev/null +++ b/ola_vlm/eval/model_cvbench_loader.py @@ -0,0 +1,166 @@ +import argparse +import torch +import os +import json +from tqdm import tqdm +import shortuuid + +from ola_vlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from ola_vlm.conversation import conv_templates, SeparatorStyle +from ola_vlm.model.builder import load_pretrained_model +from ola_vlm.utils import disable_torch_init +from ola_vlm.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path +from torch.utils.data import Dataset, DataLoader +from datasets import load_dataset +from PIL import Image +import math + + +def split_list(lst, n): + """Split a list into n (roughly) equal-sized chunks""" + chunk_size = math.ceil(len(lst) / n) # integer division + return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] + + +def get_chunk(lst, n, k): + chunks = split_list(lst, n) + return chunks[k] + +def load_jsonl(f): + lines = open(f, encoding='utf-8').readlines() + lines = [x.strip() for x in lines] + if lines[-1] == '': + lines = lines[:-1] + data = [json.loads(x) for x in lines] + return data + +def prepare_CVBench(path): + dataset = load_jsonl(os.path.join(path, 'test.jsonl')) + data = [] + for i in range(len(dataset)): + d = { + "image": os.path.join(path, dataset[i]["filename"]), + "question": dataset[i]["prompt"] + "\nOnly answer the option as the output. 
For example, if your answer is the option A, answer (A).", + "answer": dataset[i]["answer"], + "task": dataset[i]["task"], + "source": dataset[i]["source"] + } + data.append(d) + return data + + +# Custom dataset class +class CustomDataset(Dataset): + def __init__(self, data, tokenizer, image_processor, model_config): + self.questions = data + self.tokenizer = tokenizer + self.image_processor = image_processor + self.model_config = model_config + + def __getitem__(self, index): + d = self.questions[index] + qs = d["question"] + image_file = d["image"] + ans = d["answer"] + task = d["task"] + source = d["source"] + + if self.model_config.mm_use_im_start_end: + qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs + else: + qs = DEFAULT_IMAGE_TOKEN + '\n' + qs + + conv = conv_templates[args.conv_mode].copy() + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + image = Image.open(image_file).convert('RGB') + image_tensor = process_images([image], self.image_processor, self.model_config)[0] + + input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') + + return input_ids, image_tensor, image.size, ans, task, source + + def __len__(self): + return len(self.questions) + + +def collate_fn(batch): + input_ids, image_tensors, image_sizes, answers, cats, cats_l2 = zip(*batch) + input_ids = torch.stack(input_ids, dim=0) + image_tensors = torch.stack(image_tensors, dim=0) + return input_ids, image_tensors, image_sizes, answers, cats, cats_l2 + + +# DataLoader +def create_data_loader(questions, tokenizer, image_processor, model_config, batch_size=1, num_workers=4): + assert batch_size == 1, "batch_size must be 1" + dataset = CustomDataset(questions, tokenizer, image_processor, model_config) + data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, collate_fn=collate_fn) + return data_loader + + +def eval_model(args): + # Model + disable_torch_init() + model_path = os.path.expanduser(args.model_path) + model_name = get_model_name_from_path(model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) + + questions = prepare_CVBench(args.path) + questions = get_chunk(questions, args.num_chunks, args.chunk_idx) + answers_file = os.path.expanduser(args.answers_file) + os.makedirs(os.path.dirname(answers_file), exist_ok=True) + ans_file = open(answers_file, "w") + + if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode: + args.conv_mode = args.conv_mode + '_mmtag' + print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.') + + data_loader = create_data_loader(questions, tokenizer, image_processor, model.config) + + for (input_ids, image_tensor, image_sizes, answer, task, source), line in tqdm(zip(data_loader, questions), total=len(questions)): + input_ids = input_ids.to(device='cuda', non_blocking=True) + + with torch.inference_mode(): + output_ids = model.generate( + input_ids, + images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True), + image_sizes=image_sizes, + do_sample=True if args.temperature > 0 else False, + temperature=args.temperature, + top_p=args.top_p, + num_beams=args.num_beams, + max_new_tokens=args.max_new_tokens, + use_cache=True) + + if not isinstance(output_ids, torch.Tensor): + output_ids = output_ids.sequences + + outputs 
= tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() + + ans_file.write(json.dumps({"prediction": outputs, + "answer": answer, + "question": line, + "source": source, + "task": task}) + "\n") + # ans_file.flush() + ans_file.close() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="facebook/opt-350m") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--path", type=str, default="CV-Bench") + parser.add_argument("--answers-file", type=str, default="cv-bench_answer.jsonl") + parser.add_argument("--conv-mode", type=str, default="llava_phi_3") + parser.add_argument("--num-chunks", type=int, default=1) + parser.add_argument("--chunk-idx", type=int, default=0) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--top_p", type=float, default=None) + parser.add_argument("--num_beams", type=int, default=1) + parser.add_argument("--max_new_tokens", type=int, default=128) + args = parser.parse_args() + + eval_model(args) diff --git a/ola_vlm/eval/model_mmstar_loader.py b/ola_vlm/eval/model_mmstar_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..1da982e4f3aea5683ca775a1cebe8bc86b65c117 --- /dev/null +++ b/ola_vlm/eval/model_mmstar_loader.py @@ -0,0 +1,164 @@ +import argparse +import torch +import os +import json +from tqdm import tqdm +import shortuuid + +from ola_vlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from ola_vlm.conversation import conv_templates, SeparatorStyle +from ola_vlm.model.builder import load_pretrained_model +from ola_vlm.utils import disable_torch_init +from ola_vlm.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path +from torch.utils.data import Dataset, DataLoader +from datasets import load_dataset +from PIL import Image +import math + + +def split_list(lst, n): + """Split a list into n (roughly) equal-sized chunks""" + chunk_size = math.ceil(len(lst) / n) # integer division + return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] + + +def get_chunk(lst, n, k): + chunks = split_list(lst, n) + return chunks[k] + + +def prepare_MMStar(path): + os.makedirs(f"{path}/images", exist_ok=True) + dataset = load_dataset(path, "val") + dataset = dataset["val"] + data = [] + for i in range(len(dataset)): + if not os.path.exists(f"{path}/images/{i}.jpeg"): + dataset[i]["image"].save(f"{path}/images/{i}.jpeg") + prompt = dataset[i]["question"] + "\n" + prompt += "Answer with the option's letter from the given choices directly, such as answer letter 'A' only. 
\n" + + d = { + "image": f"{path}/images/{i}.jpeg", + "question": prompt, + "answer": dataset[i]["answer"], + "category": dataset[i]["category"], + "l2_category": dataset[i]["l2_category"] + } + data.append(d) + return data + + +# Custom dataset class +class CustomDataset(Dataset): + def __init__(self, data, tokenizer, image_processor, model_config): + self.questions = data + self.tokenizer = tokenizer + self.image_processor = image_processor + self.model_config = model_config + + def __getitem__(self, index): + d = self.questions[index] + qs = d["question"] + image_file = d["image"] + ans = d["answer"] + + if self.model_config.mm_use_im_start_end: + qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs + else: + qs = DEFAULT_IMAGE_TOKEN + '\n' + qs + + conv = conv_templates[args.conv_mode].copy() + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + image = Image.open(image_file).convert('RGB') + image_tensor = process_images([image], self.image_processor, self.model_config)[0] + + input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') + + return input_ids, image_tensor, image.size, ans, d["category"], d["l2_category"] + + def __len__(self): + return len(self.questions) + + +def collate_fn(batch): + input_ids, image_tensors, image_sizes, answers, cats, cats_l2 = zip(*batch) + input_ids = torch.stack(input_ids, dim=0) + image_tensors = torch.stack(image_tensors, dim=0) + return input_ids, image_tensors, image_sizes, answers, cats, cats_l2 + + +# DataLoader +def create_data_loader(questions, tokenizer, image_processor, model_config, batch_size=1, num_workers=4): + assert batch_size == 1, "batch_size must be 1" + dataset = CustomDataset(questions, tokenizer, image_processor, model_config) + data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, collate_fn=collate_fn) + return data_loader + + +def eval_model(args): + # Model + disable_torch_init() + model_path = os.path.expanduser(args.model_path) + model_name = get_model_name_from_path(model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) + + questions = prepare_MMStar(args.path) + questions = get_chunk(questions, args.num_chunks, args.chunk_idx) + answers_file = os.path.expanduser(args.answers_file) + os.makedirs(os.path.dirname(answers_file), exist_ok=True) + ans_file = open(answers_file, "w") + + if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode: + args.conv_mode = args.conv_mode + '_mmtag' + print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.') + + data_loader = create_data_loader(questions, tokenizer, image_processor, model.config) + + for (input_ids, image_tensor, image_sizes, answer, cat, cat_l2), line in tqdm(zip(data_loader, questions), total=len(questions)): + input_ids = input_ids.to(device='cuda', non_blocking=True) + + with torch.inference_mode(): + output_ids = model.generate( + input_ids, + images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True), + image_sizes=image_sizes, + do_sample=True if args.temperature > 0 else False, + temperature=args.temperature, + top_p=args.top_p, + num_beams=args.num_beams, + max_new_tokens=args.max_new_tokens, + use_cache=True) + + if not isinstance(output_ids, torch.Tensor): + output_ids = output_ids.sequences + + 
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() + + ans_file.write(json.dumps({"prediction": outputs, + "answer": answer[0], + "question": line, + "category": cat[0], + "l2_category": cat_l2[0]}) + "\n") + # ans_file.flush() + ans_file.close() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="facebook/opt-350m") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--path", type=str, default="MMStar") + parser.add_argument("--answers-file", type=str, default="mmstar_answer.jsonl") + parser.add_argument("--conv-mode", type=str, default="llava_phi_3") + parser.add_argument("--num-chunks", type=int, default=1) + parser.add_argument("--chunk-idx", type=int, default=0) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--top_p", type=float, default=None) + parser.add_argument("--num_beams", type=int, default=1) + parser.add_argument("--max_new_tokens", type=int, default=128) + args = parser.parse_args() + + eval_model(args) diff --git a/ola_vlm/mm_utils.py b/ola_vlm/mm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9aab52fa29fe746490738f083770a21ba045268e --- /dev/null +++ b/ola_vlm/mm_utils.py @@ -0,0 +1,398 @@ +from PIL import Image +from io import BytesIO +import base64 +import torch +import math +import ast +import re +from transformers import StoppingCriteria +from ola_vlm.constants import IMAGE_TOKEN_INDEX + +########################################### + +def resize_and_center_crop(image, shortest_edge_length): + # Calculate new dimensions and resize + aspect_ratio = float(image.width) / float(image.height) + if aspect_ratio > 1: + new_width = int(shortest_edge_length * aspect_ratio) + new_height = shortest_edge_length + else: + new_width = shortest_edge_length + new_height = int(shortest_edge_length / aspect_ratio) + resized_image = image.resize((new_width, new_height), Image.ANTIALIAS) + + # Calculate the position and perform the center crop + left = (new_width - shortest_edge_length) / 2 + top = (new_height - shortest_edge_length) / 2 + right = (new_width + shortest_edge_length) / 2 + bottom = (new_height + shortest_edge_length) / 2 + cropped_image = resized_image.crop((left, top, right, bottom)) + + return cropped_image + + +def auto_pad_images(image, grid_params): + assert isinstance(image, Image.Image), "Input should be a Pillow Image" + assert len(grid_params) > 0, "Grid parameters should not be empty" + + # Step 1: Calculate and find the closest aspect ratio + input_width, input_height = image.size + input_aspect_ratio = input_width / input_height + candidate_resolutions = [(w / h, w, h) for w in grid_params for h in grid_params] + closest_aspect_ratio = min(candidate_resolutions, key=lambda x: abs(input_aspect_ratio - x[0])) + + candidate_resolutions = [(x[1], x[2]) for x in candidate_resolutions if abs(x[0] - closest_aspect_ratio[0]) < 1e-3] + + target_resolution = min(candidate_resolutions, key=lambda res: abs(max(input_width, input_height) / max(res) - 1)) + + resize_width, resize_height = target_resolution + if input_width > input_height: + resize_height = int(resize_width / input_aspect_ratio) + else: + resize_width = int(resize_height * input_aspect_ratio) + resized_image = image.resize((resize_width, resize_height), Image.ANTIALIAS) + + # Step 5: Pad the resized image if necessary to match the target resolution + pad_width = target_resolution[0] - resize_width + 
pad_height = target_resolution[1] - resize_height + padded_image = Image.new("RGB", target_resolution, color=(0, 0, 0)) + padded_image.paste(resized_image, (pad_width // 2, pad_height // 2)) + + return padded_image + + +def extract_patches(image, patch_size, overlap_ratio): + assert isinstance(image, Image.Image), "Input should be a Pillow Image" + assert patch_size > 0, "Patch size should be greater than 0" + assert 0 <= overlap_ratio < 1, "Overlap ratio should be between 0 and 1" + + W, H = image.size + patches = [] + + stride = int(patch_size * (1 - overlap_ratio)) + + num_patches_y = (H - patch_size) // stride + 1 + num_patches_x = (W - patch_size) // stride + 1 + + y_start = (H - (num_patches_y - 1) * stride - patch_size) // 2 + x_start = (W - (num_patches_x - 1) * stride - patch_size) // 2 + + for y in range(y_start, y_start + num_patches_y * stride, stride): + for x in range(x_start, x_start + num_patches_x * stride, stride): + patch = image.crop((x, y, x + patch_size, y + patch_size)) + patches.append(patch) + + return patches + + +def process_highres_image_crop_split(image, data_args, processor=None): + crop_resolution = data_args.image_crop_resolution + split_resolution = data_args.image_split_resolution + if processor is None: + processor = data_args.image_processor + image_crop = resize_and_center_crop(image, crop_resolution) + image_patches = extract_patches(image_crop, patch_size=split_resolution, overlap_ratio=0) + image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches] + return torch.stack(image_patches, dim=0) + + +def process_highres_image(image, processor, grid_pinpoints): + grid_params = [int(x) for x in grid_pinpoints.split(",")] + width_height = max(image.size) + fit_grid_params = [x for x in grid_params if x >= width_height] + if len(fit_grid_params) == 0: + select_size = max(grid_params) + else: + select_size = min(fit_grid_params) + # FIXME: always select the 448 + select_size = max(grid_params) + image_padded = expand2square(image, tuple(int(x * 255) for x in processor.image_mean)) + + # FIXME: this seems to be a bug that it always resizes instead of padding + image_original_resize = image.resize((processor.size["shortest_edge"], processor.size["shortest_edge"])) + image_padded = image_padded.resize((select_size, select_size)) + image_patches = extract_patches(image_padded, patch_size=processor.size["shortest_edge"], overlap_ratio=0) + image_patches = [image_original_resize] + image_patches + image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches] + return torch.stack(image_patches, dim=0) + +######################################## + +def select_best_resolution(original_size, possible_resolutions): + """ + Selects the best resolution from a list of possible resolutions based on the original size. + + Args: + original_size (tuple): The original size of the image in the format (width, height). + possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. + + Returns: + tuple: The best fit resolution in the format (width, height). 
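+
+    Example (illustrative candidate set, matching a 336-px patch grid):
+        >>> select_best_resolution((800, 600), [(672, 672), (1344, 336), (336, 1344)])
+        (672, 672)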
+ """ + original_width, original_height = original_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float('inf') + + for width, height in possible_resolutions: + scale = min(width / original_width, height / original_height) + downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) + effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) + wasted_resolution = (width * height) - effective_resolution + + if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (width, height) + + return best_fit + + +def resize_and_pad_image(image, target_resolution): + """ + Resize and pad an image to a target resolution while maintaining aspect ratio. + + Args: + image (PIL.Image.Image): The input image. + target_resolution (tuple): The target resolution (width, height) of the image. + + Returns: + PIL.Image.Image: The resized and padded image. + """ + original_width, original_height = image.size + target_width, target_height = target_resolution + + scale_w = target_width / original_width + scale_h = target_height / original_height + + if scale_w < scale_h: + new_width = target_width + new_height = min(math.ceil(original_height * scale_w), target_height) + else: + new_height = target_height + new_width = min(math.ceil(original_width * scale_h), target_width) + + # Resize the image + resized_image = image.resize((new_width, new_height)) + + new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0)) + paste_x = (target_width - new_width) // 2 + paste_y = (target_height - new_height) // 2 + new_image.paste(resized_image, (paste_x, paste_y)) + + return new_image + + +def divide_to_patches(image, patch_size): + """ + Divides an image into patches of a specified size. + + Args: + image (PIL.Image.Image): The input image. + patch_size (int): The size of each patch. + + Returns: + list: A list of PIL.Image.Image objects representing the patches. + """ + patches = [] + width, height = image.size + for i in range(0, height, patch_size): + for j in range(0, width, patch_size): + box = (j, i, j + patch_size, i + patch_size) + patch = image.crop(box) + patches.append(patch) + + return patches + + +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + + Args: + image_size (tuple): The size of the input image in the format (width, height). + grid_pinpoints (str): A string representation of a list of possible resolutions. + patch_size (int): The size of each image patch. + + Returns: + tuple: The shape of the image patch grid in the format (width, height). 
+ """ + if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints: + assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]" + # Use regex to extract the range from the input string + matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints) + range_start = tuple(map(int, matches[0])) + range_end = tuple(map(int, matches[-1])) + # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1]) + grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)] + # Multiply all elements by patch_size + grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints] + if type(grid_pinpoints) is list: + possible_resolutions = grid_pinpoints + else: + possible_resolutions = ast.literal_eval(grid_pinpoints) + width, height = select_best_resolution(image_size, possible_resolutions) + return width // patch_size, height // patch_size + + +def process_anyres_image(image, processor, grid_pinpoints): + """ + Process an image with variable resolutions. + + Args: + image (PIL.Image.Image): The input image to be processed. + processor: The image processor object. + grid_pinpoints (str): A string representation of a list of possible resolutions. + + Returns: + torch.Tensor: A tensor containing the processed image patches. + """ + # Convert grid_pinpoints from string to list + if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints: + try: + patch_size = processor.size[0] + except Exception as e: + patch_size = processor.size["shortest_edge"] + assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]" + # Use regex to extract the range from the input string + matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints) + range_start = tuple(map(int, matches[0])) + range_end = tuple(map(int, matches[-1])) + # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1]) + grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)] + # Multiply all elements by patch_size + grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints] + + if type(grid_pinpoints) is list: + possible_resolutions = grid_pinpoints + else: + possible_resolutions = ast.literal_eval(grid_pinpoints) + best_resolution = select_best_resolution(image.size, possible_resolutions) + image_padded = resize_and_pad_image(image, best_resolution) + + patches = divide_to_patches(image_padded, processor.crop_size["height"]) + + # FIXME: this seems to be a bug that it resizes instead of pad. 
+ # but to keep it consistent with previous, i will keep it as it is + # TODO: uncomment below to ablate with the padding + if isinstance(processor.size, dict): + shortest_edge = processor.size["shortest_edge"] + else: + shortest_edge = min(processor.size) + image_original_resize = image.resize((shortest_edge, shortest_edge)) + # image_padded_square = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) + # image_original_resize = image_padded_square.resize((processor.size['shortest_edge'], processor.size['shortest_edge'])) + + image_patches = [image_original_resize] + patches + image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches] + return torch.stack(image_patches, dim=0) + + +def load_image_from_base64(image): + return Image.open(BytesIO(base64.b64decode(image))) + + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +def process_images(images, image_processor, model_cfg): + image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) + new_images = [] + if image_aspect_ratio == "highres": + for image in images: + image = process_highres_image(image, image_processor, model_cfg.image_grid_pinpoints) + new_images.append(image) + elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio: + for image in images: + image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints) + new_images.append(image) + elif image_aspect_ratio == "crop_split": + for image in images: + image = process_highres_image_crop_split(image, model_cfg, image_processor) + new_images.append(image) + elif image_aspect_ratio == "pad": + for image in images: + image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean)) + image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0] + new_images.append(image) + else: + return image_processor.preprocess(images, return_tensors="pt")["pixel_values"] + if all(x.shape == new_images[0].shape for x in new_images): + new_images = torch.stack(new_images, dim=0) + return new_images + + +def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] + + def insert_separator(X, sep): + return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): + input_ids.extend(x[offset:]) + + if return_tensors is not None: + if return_tensors == 'pt': + return torch.tensor(input_ids, dtype=torch.long) + raise ValueError(f'Unsupported tensor type: {return_tensors}') + return input_ids + + +def get_model_name_from_path(model_path): + model_path = model_path.strip("/") + model_paths = model_path.split("/") + if model_paths[-1].startswith('checkpoint-'): + return model_paths[-2] + "_" + model_paths[-1] + else: + return model_paths[-1] + +class 
KeywordsStoppingCriteria(StoppingCriteria): + def __init__(self, keywords, tokenizer, input_ids): + self.keywords = keywords + self.keyword_ids = [] + self.max_keyword_len = 0 + for keyword in keywords: + cur_keyword_ids = tokenizer(keyword).input_ids + if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: + cur_keyword_ids = cur_keyword_ids[1:] + if len(cur_keyword_ids) > self.max_keyword_len: + self.max_keyword_len = len(cur_keyword_ids) + self.keyword_ids.append(torch.tensor(cur_keyword_ids)) + self.tokenizer = tokenizer + self.start_len = input_ids.shape[1] + + def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) + self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] + for keyword_id in self.keyword_ids: + truncated_output_ids = output_ids[0, -keyword_id.shape[0]:] + if torch.equal(truncated_output_ids, keyword_id): + return True + outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] + for keyword in self.keywords: + if keyword in outputs: + return True + return False + + def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + outputs = [] + for i in range(output_ids.shape[0]): + outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) + return all(outputs) diff --git a/ola_vlm/model/.DS_Store b/ola_vlm/model/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..29a0c2ff5f6400bac7d946f75f64df469cb2b7f7 Binary files /dev/null and b/ola_vlm/model/.DS_Store differ diff --git a/ola_vlm/model/__init__.py b/ola_vlm/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1b028ced93f9ff8f2e8fabcd284080ffbc4f0aad --- /dev/null +++ b/ola_vlm/model/__init__.py @@ -0,0 +1,5 @@ +from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig +from .language_model.llava_phi3 import LlavaPhi3ForCausalLM, LlavaPhi3Config +from .language_model.ola_llama import OlaLlavaLlamaForCausalLM, OlaLlavaLlamaConfig +from .language_model.ola_phi3 import OlaLlavaPhi3ForCausalLM, OlaLlavaPhi3Config +from .language_model.probe_llava_llama import ProbeDSGLlavaLlamaForCausalLM, ProbeDSGLlavaLlamaConfig \ No newline at end of file diff --git a/ola_vlm/model/apply_delta.py b/ola_vlm/model/apply_delta.py new file mode 100644 index 0000000000000000000000000000000000000000..666dd9691bde7d54ddf2871e311d6f621e29f099 --- /dev/null +++ b/ola_vlm/model/apply_delta.py @@ -0,0 +1,48 @@ +""" +Usage: +python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta +""" +import argparse + +import torch +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM +from llava import LlavaLlamaForCausalLM + + +def apply_delta(base_model_path, target_model_path, delta_path): + print("Loading base model") + base = AutoModelForCausalLM.from_pretrained( + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + + print("Loading delta") + delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) + + print("Applying delta") + for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): + if name not in base.state_dict(): + assert name in ['model.mm_projector.weight', 
'model.mm_projector.bias'], f'{name} not in base model' + continue + if param.data.shape == base.state_dict()[name].shape: + param.data += base.state_dict()[name] + else: + assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ + f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' + bparam = base.state_dict()[name] + param.data[:bparam.shape[0], :bparam.shape[1]] += bparam + + print("Saving target model") + delta.save_pretrained(target_model_path) + delta_tokenizer.save_pretrained(target_model_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--delta-path", type=str, required=True) + + args = parser.parse_args() + + apply_delta(args.base_model_path, args.target_model_path, args.delta_path) diff --git a/ola_vlm/model/aux_heads/.DS_Store b/ola_vlm/model/aux_heads/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/ola_vlm/model/aux_heads/.DS_Store differ diff --git a/ola_vlm/model/aux_heads/__init__.py b/ola_vlm/model/aux_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c01247a23eae57a4d09560be5d4fc1e97ea86c70 --- /dev/null +++ b/ola_vlm/model/aux_heads/__init__.py @@ -0,0 +1,3 @@ +from .da_v2_head import DepthHead, DAv2_Head, DepthProbeHead, TaskTokenDepthHead +from .oneformer_head import OneFormerSegHead, OneFormerTaskTokenSegHead +from .gen_head import GenHead, TaskTokenGenHead \ No newline at end of file diff --git a/ola_vlm/model/aux_heads/da_v2_head.py b/ola_vlm/model/aux_heads/da_v2_head.py new file mode 100644 index 0000000000000000000000000000000000000000..8d97c59437390ef1a0e3a4aab186a9fe4cf69e15 --- /dev/null +++ b/ola_vlm/model/aux_heads/da_v2_head.py @@ -0,0 +1,457 @@ +import cv2 +import torch +import torch.nn as nn +import torch.nn.functional as F +from ola_vlm.model.multimodal_projector.resampler import Resampler, TaskTokenResampler + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. 
+ + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + if self.bn == True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn == True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn == True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None + ): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size=size + + def forward(self, *xs, size=None): + """Forward pass. 
+ + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) + + output = self.out_conv(output) + + return output + + +def _make_fusion_block(features, use_bn, size=None): + return FeatureFusionBlock( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class ConvBlock(nn.Module): + def __init__(self, in_feature, out_feature): + super().__init__() + + self.conv_block = nn.Sequential( + nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(out_feature), + nn.ReLU(True) + ) + + def forward(self, x): + return self.conv_block(x) + + +class DPTHead(nn.Module): + def __init__( + self, + in_channels, + features=256, + use_bn=False, + out_channels=[256, 512, 1024, 1024], + use_clstoken=False + ): + super(DPTHead, self).__init__() + + self.use_clstoken = use_clstoken + + self.projects = nn.ModuleList([ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + stride=1, + padding=0, + ) for out_channel in out_channels + ]) + + self.resize_layers = nn.ModuleList([ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1) + ]) + + if use_clstoken: + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential( + nn.Linear(2 * in_channels, in_channels), + nn.GELU())) + + self.scratch = _make_scratch( + out_channels, + features, + groups=1, + expand=False, + ) + + self.scratch.stem_transpose = None + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + head_features_1 = features + head_features_2 = 32 + + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1) + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True), + nn.Identity(), + ) + + def forward(self, out_features, patch_h, patch_w): + out = [] + for i, x in enumerate(out_features): + if self.use_clstoken: + x, cls_token = x[0], x[1] + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + else: + x = x[0] + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[i](x) + x = self.resize_layers[i](x) + + out.append(x) + + layer_1, layer_2, layer_3, layer_4 = out + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = 
self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv1(path_1) + out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True) + out = self.scratch.output_conv2(out) + + return out + + +class DAv2_Head(nn.Module): + def __init__( + self, + encoder='vitl', + features=256, + out_channels=[256, 512, 1024, 1024], + use_bn=False, + use_clstoken=False + ): + super(DAv2_Head, self).__init__() + + self.embd_dims = { + 'vits': 1024, + 'vitb': 1024, + 'vitl': 1024, + 'vitg': 1024, + } + + self.depth_head = DPTHead(self.embd_dims[encoder], features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken) + + def forward(self, features): + patch_h, patch_w = 336 // 14, 336 // 14 + depth = self.depth_head(features, patch_h, patch_w) + depth = F.relu(depth) + + return depth.squeeze(1) + + @torch.no_grad() + def infer_feats(self, feats, image_size=(336, 336)): + h, w = image_size + depth = self.forward(feats) + + depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0] + return depth.cpu().numpy() + +def build_mlp(in_hidden_size, hidden_size): + modules = [nn.Linear(in_hidden_size, hidden_size)] + modules.append(nn.ReLU()) + modules.append(nn.Linear(hidden_size, hidden_size)) + return nn.Sequential(*modules) + +def build_expand_mlp(in_hidden_size, hidden_size, out_size): + modules = [nn.Linear(in_hidden_size, hidden_size)] + modules.append(nn.ReLU()) + modules.append(nn.Linear(hidden_size, hidden_size)) + modules.append(nn.ReLU()) + modules.append(nn.Linear(hidden_size, out_size)) + return nn.Sequential(*modules) + +class DepthProbeHead(nn.Module): + def __init__( + self, + llm_hidden_size=4096, + proj_config=None, + ): + super(DepthProbeHead, self).__init__() + + self.linear_1 = build_mlp(llm_hidden_size, proj_config["output_dim"]) + self.linear_2 = build_mlp(llm_hidden_size, proj_config["output_dim"]) + self.linear_3 = build_mlp(llm_hidden_size, proj_config["output_dim"]) + self.linear_4 = build_mlp(llm_hidden_size, proj_config["output_dim"]) + + # self._init_weights() + + # def _init_weights(self): + # for m in self.modules(): + # if isinstance(m, nn.Linear): + # nn.init.xavier_uniform_(m.weight) + # if m.bias is not None: + # nn.init.constant_(m.bias, 0) + + def forward(self, llm_feats): + + features = [(self.linear_1(llm_feats), None), + (self.linear_1(llm_feats), None), + (self.linear_2(llm_feats), None), + (self.linear_3(llm_feats), None) + ] + + return features + +class DepthHead(nn.Module): + def __init__( + self, + llm_hidden_size=4096, + proj_config=None, + use_intermediate_depth=False, + ): + super(DepthHead, self).__init__() + + self.projector = Resampler( + dim=proj_config["output_dim"], + depth=proj_config["depth"], + dim_head=proj_config["dim_head"], + heads=proj_config["num_heads"], + num_queries=proj_config["num_tokens"], + embedding_dim=llm_hidden_size, + output_dim=proj_config["output_dim"], + ff_mult=proj_config["ff_mult"], + ) + + self.use_intermediate_depth = use_intermediate_depth + + if self.use_intermediate_depth: + self.linear_1 = build_mlp(proj_config["output_dim"], proj_config["output_dim"]) + self.linear_2 = build_mlp(proj_config["output_dim"], 
proj_config["output_dim"]) + self.linear_3 = build_mlp(proj_config["output_dim"], proj_config["output_dim"]) + + def forward(self, llm_feats): + visual_feats = self.projector(llm_feats) + + features = [] + + if self.use_intermediate_depth: + features.append((self.linear_1(visual_feats), None)) + features.append((self.linear_2(visual_feats), None)) + features.append((self.linear_3(visual_feats), None)) + + features.append((visual_feats, None)) + + return features + +class TaskTokenDepthHead(nn.Module): + def __init__( + self, + proj_config=None, + llm_hidden_size=4096, + use_intermediate_depth=False, + ): + super(TaskTokenDepthHead, self).__init__() + + self.projector = TaskTokenResampler( + dim=llm_hidden_size, + depth=proj_config["depth"], + dim_head=proj_config["dim_head"], + heads=proj_config["num_heads"], + num_queries=proj_config["num_tokens"], + embedding_dim=llm_hidden_size, + output_dim=proj_config["output_dim"], + ff_mult=proj_config["ff_mult"], + ) + self.use_intermediate_depth = use_intermediate_depth + + if self.use_intermediate_depth: + self.linear_1 = build_mlp(proj_config["output_dim"], proj_config["output_dim"]) + self.linear_2 = build_mlp(proj_config["output_dim"], proj_config["output_dim"]) + self.linear_3 = build_mlp(proj_config["output_dim"], proj_config["output_dim"]) + + def forward(self, llm_feats, latents): + + visual_feats = self.projector(llm_feats, latents) + + features = [] + + if self.use_intermediate_depth: + features.append((self.linear_1(visual_feats), None)) + features.append((self.linear_2(visual_feats), None)) + features.append((self.linear_3(visual_feats), None)) + + features.append((visual_feats, None)) + + return features \ No newline at end of file diff --git a/ola_vlm/model/aux_heads/depth_anything_v2/dinov2.py b/ola_vlm/model/aux_heads/depth_anything_v2/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..83d250818c721c6df3b30d3f4352945527701615 --- /dev/null +++ b/ola_vlm/model/aux_heads/depth_anything_v2/dinov2.py @@ -0,0 +1,415 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = 
nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0 + w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset + # w0, h0 = w0 + 0.1, h0 + 0.1 + + sqrt_N = math.sqrt(N) + sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), + scale_factor=(sx, sy), + # (int(w0), int(h0)), # to solve the upsampling shape issue + mode="bicubic", + antialias=self.interpolate_antialias + ) + + assert int(w0) == patch_pos_embed.shape[-2] + assert int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + 
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. 
If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def DINOv2(model_name): + model_zoo = { + "vits": vit_small, + "vitb": vit_base, + "vitl": vit_large, + "vitg": vit_giant2 + } + + return model_zoo[model_name]( + img_size=518, + patch_size=14, + init_values=1.0, + ffn_layer="mlp" if model_name != "vitg" else "swiglufused", + block_chunks=0, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1 + ) diff --git a/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/__init__.py 
b/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8120f4bc83066cb3f825ce32daa3b437f88486f1 --- /dev/null +++ b/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/attention.py b/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..815a2bf53dbec496f6a184ed7d03bcecb7124262 --- /dev/null +++ b/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/attention.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import memory_efficient_attention, unbind, fmha + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + assert attn_bias is None, "xFormers is required for nested tensors usage" + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + \ No newline at end of file diff --git a/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/block.py b/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..25488f57cc0ad3c692f86b62555f6668e2a66db1 --- /dev/null +++ b/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/block.py @@ -0,0 
+1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +from typing import Callable, List, Any, Tuple, Dict + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import fmha + from xformers.ops import scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual 
+ residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = 
drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/drop_path.py b/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..af05625984dd14682cc96a63bf0c97bab1f123b1 --- /dev/null +++ b/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/drop_path.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/layer_scale.py b/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5daa52bd81d3581adeb2198ea5b7dba2a3aea1 --- /dev/null +++ b/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/layer_scale.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/mlp.py b/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4b315f972f9a9f54aef1e4ef4e81b52976f018 --- /dev/null +++ b/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/mlp.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/patch_embed.py b/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..d0881b3533fb0c74d46d0f5da9afee5c09ca8a9e --- /dev/null +++ b/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/patch_embed.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = x.to(self.proj.bias.dtype) + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/swiglu_ffn.py b/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..b3324b266fb0a50ccf8c3a0ede2ae10ac4dfa03e --- /dev/null +++ b/ola_vlm/model/aux_heads/depth_anything_v2/dinov2_layers/swiglu_ffn.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
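As a quick shape check for the `PatchEmbed` module added in `patch_embed.py` above (illustrative sketch, not part of the patch): a 224x224 RGB input with 16x16 patches yields 14*14 = 196 tokens of width `embed_dim`, and `flatten_embedding=False` keeps the spatial grid instead. This assumes the module path introduced by this diff is importable.

```python
# Illustrative shape check for the PatchEmbed module added above.
import torch

from ola_vlm.model.aux_heads.depth_anything_v2.dinov2_layers.patch_embed import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
x = torch.randn(2, 3, 224, 224)

tokens = embed(x)
print(tokens.shape)     # torch.Size([2, 196, 768]) == (B, (224/16)**2, embed_dim)

# With flatten_embedding=False the (H/16, W/16) grid is kept instead.
grid = PatchEmbed(flatten_embedding=False)(x)
print(grid.shape)       # torch.Size([2, 14, 14, 768])
```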
+ +from typing import Callable, Optional + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/ola_vlm/model/aux_heads/depth_anything_v2/dpt.py b/ola_vlm/model/aux_heads/depth_anything_v2/dpt.py new file mode 100644 index 0000000000000000000000000000000000000000..84402c544c4b1ad3afd12595c7ad529d6e3923d0 --- /dev/null +++ b/ola_vlm/model/aux_heads/depth_anything_v2/dpt.py @@ -0,0 +1,219 @@ +import cv2 +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Compose + +from .dinov2 import DINOv2 +from .util.blocks import FeatureFusionBlock, _make_scratch +from .util.transform import Resize, NormalizeImage, PrepareForNet + + +def _make_fusion_block(features, use_bn, size=None): + return FeatureFusionBlock( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class ConvBlock(nn.Module): + def __init__(self, in_feature, out_feature): + super().__init__() + + self.conv_block = nn.Sequential( + nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(out_feature), + nn.ReLU(True) + ) + + def forward(self, x): + return self.conv_block(x) + + +class DPTHead(nn.Module): + def __init__( + self, + in_channels, + features=256, + use_bn=False, + out_channels=[256, 512, 1024, 1024], + use_clstoken=False + ): + super(DPTHead, self).__init__() + + self.use_clstoken = use_clstoken + + self.projects = nn.ModuleList([ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + stride=1, + padding=0, + ) for out_channel in out_channels + ]) + + self.resize_layers = nn.ModuleList([ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1) + ]) + + if use_clstoken: + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + 
self.readout_projects.append( + nn.Sequential( + nn.Linear(2 * in_channels, in_channels), + nn.GELU())) + + self.scratch = _make_scratch( + out_channels, + features, + groups=1, + expand=False, + ) + + self.scratch.stem_transpose = None + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + head_features_1 = features + head_features_2 = 32 + + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1) + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True), + nn.Identity(), + ) + + def forward(self, out_features, patch_h, patch_w): + out = [] + for i, x in enumerate(out_features): + if self.use_clstoken: + x, cls_token = x[0], x[1] + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + else: + x = x[0] + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[i](x) + x = self.resize_layers[i](x) + + out.append(x) + + layer_1, layer_2, layer_3, layer_4 = out + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv1(path_1) + out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True) + out = self.scratch.output_conv2(out) + + return out + + +class DepthAnythingV2(nn.Module): + def __init__( + self, + encoder='vitl', + features=256, + out_channels=[256, 512, 1024, 1024], + use_bn=False, + use_clstoken=False + ): + super(DepthAnythingV2, self).__init__() + + self.intermediate_layer_idx = { + 'vits': [2, 5, 8, 11], + 'vitb': [2, 5, 8, 11], + 'vitl': [4, 11, 17, 23], + 'vitg': [9, 19, 29, 39] + } + + self.encoder = encoder + self.pretrained = DINOv2(model_name=encoder) + + self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken) + + def forward(self, x): + patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14 + features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True) + + return features + + @torch.no_grad() + def infer_image(self, raw_image, input_size=336, is_dsg=False): + image, (h, w) = self.image2tensor(raw_image, input_size) + + features = self.forward(image) + if is_dsg: + return features + # feats = torch.cat([f[0] for f in features], dim=2) + feats = features[-1][0] + + return feats + + def image2tensor(self, raw_image, input_size=518): + transform = Compose([ + Resize( + width=input_size, + height=input_size, + resize_target=False, + keep_aspect_ratio=True, + ensure_multiple_of=14, + resize_method='lower_bound', + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 
0.225]), + PrepareForNet(), + ]) + + h, w = raw_image.shape[:2] + + image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0 + + image = transform({'image': image})['image'] + image = torch.from_numpy(image).unsqueeze(0) + + DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' + image = image.to(DEVICE) + + return image, (h, w) diff --git a/ola_vlm/model/aux_heads/depth_anything_v2/util/blocks.py b/ola_vlm/model/aux_heads/depth_anything_v2/util/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..382ea183a40264056142afffc201c992a2b01d37 --- /dev/null +++ b/ola_vlm/model/aux_heads/depth_anything_v2/util/blocks.py @@ -0,0 +1,148 @@ +import torch.nn as nn + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + if self.bn == True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn == True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn == True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None + ): + """Init. 
+ + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size=size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) + + output = self.out_conv(output) + + return output diff --git a/ola_vlm/model/aux_heads/depth_anything_v2/util/transform.py b/ola_vlm/model/aux_heads/depth_anything_v2/util/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..b14aacd44ea086b01725a9ca68bb49eadcf37d73 --- /dev/null +++ b/ola_vlm/model/aux_heads/depth_anything_v2/util/transform.py @@ -0,0 +1,158 @@ +import numpy as np +import cv2 + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". 
+ """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0]) + + # resize sample + sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method) + + if self.__resize_target: + if "depth" in sample: + sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST) + + if "mask" in sample: + sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. 
+ """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + return sample \ No newline at end of file diff --git a/ola_vlm/model/aux_heads/gen_head.py b/ola_vlm/model/aux_heads/gen_head.py new file mode 100644 index 0000000000000000000000000000000000000000..9df59fa596c0a22cc710e4118fe75c36361f8342 --- /dev/null +++ b/ola_vlm/model/aux_heads/gen_head.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from ola_vlm.model.multimodal_projector.resampler import Resampler, TaskTokenResampler + + +class GenHead(nn.Module): + + def __init__( + self, + proj_config: dict = None, + llm_hidden_size: int = 4096, + ) -> None: + super().__init__() + + self.projector = Resampler( + dim=proj_config["output_dim"], + depth=proj_config["depth"], + dim_head=proj_config["dim_head"], + heads=proj_config["num_heads"], + num_queries=proj_config["num_tokens"], + embedding_dim=llm_hidden_size, + output_dim=proj_config["output_dim"], + ff_mult=proj_config["ff_mult"], + ) + + def forward( + self, + llm_feats: torch.Tensor, + ): + gen_feats = self.projector(llm_feats) + return gen_feats + +class TaskTokenGenHead(nn.Module): + + def __init__( + self, + proj_config: dict = None, + llm_hidden_size: int = 4096, + ) -> None: + super().__init__() + + self.projector = TaskTokenResampler( + dim=proj_config["output_dim"], + depth=proj_config["depth"], + dim_head=proj_config["dim_head"], + heads=proj_config["num_heads"], + num_queries=proj_config["num_tokens"], + embedding_dim=llm_hidden_size, + output_dim=proj_config["output_dim"], + ff_mult=proj_config["ff_mult"], + ) + + def forward( + self, + llm_feats: torch.Tensor, + latents: torch.Tensor + ): + gen_feats = self.projector(llm_feats, latents) + return gen_feats \ No newline at end of file diff --git a/ola_vlm/model/aux_heads/oneformer_head.py b/ola_vlm/model/aux_heads/oneformer_head.py new file mode 100644 index 0000000000000000000000000000000000000000..8b2a3d5ec0ea3ce4ce53721602313d6ac3d98af9 --- /dev/null +++ b/ola_vlm/model/aux_heads/oneformer_head.py @@ -0,0 +1,264 @@ +import torch +from typing import Optional +from torch import Tensor, nn +from ola_vlm.model.multimodal_projector.resampler import Resampler, TaskTokenResampler +import math +from torch.nn import functional as F +from transformers import OneFormerModel +from transformers.models.oneformer.modeling_oneformer import OneFormerForUniversalSegmentationOutput, OneFormerModelOutput, OneFormerPixelLevelModule, OneFormerPixelLevelModuleOutput + + +class AuxOneFormerPixelLevelModule(OneFormerPixelLevelModule): + def __init__(self, config): + super().__init__(config) + + def forward(self, pixel_values: Tensor, output_hidden_states: bool = False, last_backbone_feats: Tensor = None, all_backbone_features: Tensor = None, return_features: bool = False, return_all_features: bool = False): + if all_backbone_features is None: + features = self.encoder(pixel_values).feature_maps + if return_all_features: + return features + else: + features = 
all_backbone_features + if last_backbone_feats is not None: + features = list(features) + last_backbone_feats = F.interpolate(last_backbone_feats, size=features[-1].shape[-2:], mode='bilinear', align_corners=False) + features[-1] = last_backbone_feats + for i in range(3): + features[i] = F.interpolate(features[i], size=features[-1].shape[-2:], mode='bilinear', align_corners=False) + features = tuple(features) + elif return_features: + return F.interpolate(features[-1], size=(24, 24), mode='bilinear', align_corners=False) + decoder_output = self.decoder(features, output_hidden_states=output_hidden_states) + return OneFormerPixelLevelModuleOutput( + encoder_features=tuple(features), + decoder_features=decoder_output.multi_scale_features, + decoder_last_feature=decoder_output.mask_features, + ) + +class OneFormerHead(OneFormerModel): + def __init__(self, config): + super().__init__(config) + self.pixel_level_module = AuxOneFormerPixelLevelModule(config) + + def forward_features( + self, + pixel_values: Tensor, + task_inputs: Tensor, + text_inputs: Tensor = None, + pixel_mask: Tensor = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, _, height, width = pixel_values.shape + + if pixel_mask is None: + pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device) + + backbone_last_feature = self.pixel_level_module(pixel_values, output_hidden_states, return_features=True) + + return backbone_last_feature + + def get_backbone_feats( + self, + pixel_values: Tensor, + task_inputs: Tensor, + text_inputs: Tensor = None, + pixel_mask: Tensor = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, _, height, width = pixel_values.shape + + if pixel_mask is None: + pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device) + + backbone_last_feature = self.pixel_level_module(pixel_values, output_hidden_states, return_all_features=True) + + return backbone_last_feature + + def get_masks( + self, + pixel_values: Tensor, + task_inputs: Tensor, + text_inputs: Tensor = None, + pixel_mask: Tensor = None, + backbone_last_feature: Tensor = None, + all_backbone_features: Tensor = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if 
output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, _, height, width = pixel_values.shape + + if pixel_mask is None: + pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device) + + pixel_level_module_output = self.pixel_level_module(pixel_values, output_hidden_states, backbone_last_feature, all_backbone_features) + + multi_scale_features = pixel_level_module_output.decoder_features + mask_features = pixel_level_module_output.decoder_last_feature + + task_token = self.task_encoder(task_inputs.to(self.dtype)) + + if self.is_training: + text_queries = self.text_mapper(text_inputs) + else: + text_queries = None + + transformer_module_output = self.transformer_module( + multi_scale_features=multi_scale_features, + mask_features=mask_features, + task_token=task_token, + output_attentions=output_attentions, + ) + + queries = transformer_module_output.object_queries + + encoder_hidden_states = None + pixel_decoder_hidden_states = None + transformer_decoder_hidden_states = None + + if output_hidden_states: + encoder_hidden_states = pixel_level_module_output.encoder_features + pixel_decoder_hidden_states = (pixel_level_module_output.decoder_last_feature,) + for f in pixel_level_module_output.decoder_features: + pixel_decoder_hidden_states += (f,) + transformer_decoder_hidden_states = transformer_module_output.auxiliary_predictions + + outputs = OneFormerModelOutput( + encoder_hidden_states=encoder_hidden_states, + pixel_decoder_hidden_states=pixel_decoder_hidden_states, + transformer_decoder_hidden_states=transformer_decoder_hidden_states, + transformer_decoder_object_queries=queries, + transformer_decoder_contrastive_queries=transformer_module_output.contrastive_logits, + transformer_decoder_mask_predictions=transformer_module_output.prediction_masks, + transformer_decoder_class_predictions=transformer_module_output.prediction_class, + transformer_decoder_auxiliary_predictions=transformer_module_output.auxiliary_predictions, + text_queries=text_queries, + task_token=task_token, + attentions=transformer_module_output.attentions, + ) + + class_queries_logits = outputs.transformer_decoder_class_predictions + masks_queries_logits = outputs.transformer_decoder_mask_predictions + contrastive_queries_logits = outputs.transformer_decoder_contrastive_queries + auxiliary_predictions = outputs.transformer_decoder_auxiliary_predictions + text_queries = outputs.text_queries + + output = OneFormerForUniversalSegmentationOutput( + class_queries_logits=class_queries_logits, + masks_queries_logits=masks_queries_logits, + auxiliary_predictions=auxiliary_predictions, + loss=None, + **outputs, + ) + + return output + +class OneFormerSegHead(nn.Module): + + def __init__( + self, + proj_config: dict = None, + llm_hidden_size: int = 4096, + ) -> None: + super().__init__() + + self.projector = Resampler( + dim=proj_config["output_dim"], + depth=proj_config["depth"], + dim_head=proj_config["dim_head"], + heads=proj_config["num_heads"], + num_queries=proj_config["num_tokens"], + embedding_dim=llm_hidden_size, + output_dim=proj_config["output_dim"], + ff_mult=proj_config["ff_mult"], + ) + + + def forward( + self, + llm_feats: torch.Tensor, + ): + visual_feats = self.projector(llm_feats) + b, n, c = visual_feats.shape + b = int(b) + c = int(c) + h = w = int(math.sqrt(int(n))) + visual_feats = visual_feats.permute(0, 2, 1) + image_embeddings = visual_feats.reshape(b, c, h, w) + + 
return image_embeddings + + +class OneFormerTaskTokenSegHead(nn.Module): + + def __init__( + self, + proj_config: dict = None, + llm_hidden_size: int = 4096, + ) -> None: + super().__init__() + + self.projector = TaskTokenResampler( + dim=proj_config["output_dim"], + depth=proj_config["depth"], + dim_head=proj_config["dim_head"], + heads=proj_config["num_heads"], + num_queries=proj_config["num_tokens"], + embedding_dim=llm_hidden_size, + output_dim=proj_config["output_dim"], + ff_mult=proj_config["ff_mult"], + ) + + + def forward( + self, + llm_feats: torch.Tensor, + latents: torch.Tensor, + ): + visual_feats = self.projector(llm_feats, latents) + b, n, c = visual_feats.shape + b = int(b) + c = int(c) + h = w = int(math.sqrt(int(n))) + visual_feats = visual_feats.permute(0, 2, 1) + image_embeddings = visual_feats.reshape(b, c, h, w) + + return image_embeddings + +def build_mlp(in_hidden_size, hidden_size): + modules = [nn.Linear(in_hidden_size, hidden_size)] + modules.append(nn.GELU()) + modules.append(nn.Linear(hidden_size, hidden_size)) + return nn.Sequential(*modules) \ No newline at end of file diff --git a/ola_vlm/model/builder.py b/ola_vlm/model/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..68bc023b6994ed614c85db6cd3c2237a14c8e749 --- /dev/null +++ b/ola_vlm/model/builder.py @@ -0,0 +1,287 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import warnings +import shutil + +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig +import torch +from ola_vlm.model import * +from ola_vlm.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN + + +def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs): + kwargs = {"device_map": device_map, **kwargs} + + if device != "cuda": + kwargs['device_map'] = {"": device} + + if load_8bit: + kwargs['load_in_8bit'] = True + elif load_4bit: + kwargs['load_in_4bit'] = True + kwargs['quantization_config'] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4' + ) + else: + kwargs['torch_dtype'] = torch.float16 + + if use_flash_attn: + kwargs['attn_implementation'] = 'flash_attention_2' + + if 'llava' in model_name.lower() or 'clip' in model_name.lower() or 'sherlock' in model_name.lower() or 'dino' in model_name.lower(): + # Load LLaVA model + if 'lora' in model_name.lower() and model_base is None: + warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. 
Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.') + if 'lora' in model_name.lower() and model_base is not None: + from ola_vlm.model.language_model.llava_llama import LlavaConfig + lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + print('Loading LLaVA from base model...') + if "phi" in model_name.lower(): + model = LlavaPhi3ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) + elif "qwen" in model_name.lower(): + model = LlavaQwenForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) + else: + model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) + token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features + if model.lm_head.weight.shape[0] != token_num: + model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) + model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) + + print('Loading additional LLaVA weights...') + if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): + non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') + else: + # this is probably from HF Hub + from huggingface_hub import hf_hub_download + def load_from_hf(repo_id, filename, subfolder=None): + cache_file = hf_hub_download( + repo_id=repo_id, + filename=filename, + subfolder=subfolder) + return torch.load(cache_file, map_location='cpu') + non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') + non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} + if any(k.startswith('model.model.') for k in non_lora_trainables): + non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} + model.load_state_dict(non_lora_trainables, strict=False) + + from peft import PeftModel + print('Loading LoRA weights...') + model = PeftModel.from_pretrained(model, model_path) + print('Merging LoRA weights...') + model = model.merge_and_unload() + print('Model is loaded...') + elif model_base is not None: + # this may be mm projector only + print('Loading LLaVA from base model...') + if "probe" in model_name.lower(): + cfg_pretrained = AutoConfig.from_pretrained(model_path) + if "phi" in model_name.lower(): + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + model = ProbeDSGLlavaPhi3ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) + elif "qwen" in model_name.lower(): + tokenizer = AutoTokenizer.from_pretrained(model_base) + model = ProbeDSGLlavaQwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) + else: + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + model = ProbeDSGLlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) + elif 'mpt' in model_name.lower(): + if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')): + shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py')) + tokenizer = AutoTokenizer.from_pretrained(model_base, 
use_fast=True) + cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) + elif "phi" in model_name.lower(): + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + cfg_pretrained = AutoConfig.from_pretrained(model_path) + if "dsg" in model_name.lower(): + model = OlaLlavaPhi3ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) + elif "multi_enc" in model_name.lower(): + model = MultiEncLlavaPhi3ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) + else: + model = LlavaPhi3ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) + elif "qwen" in model_name.lower(): + tokenizer = AutoTokenizer.from_pretrained(model_base) + cfg_pretrained = AutoConfig.from_pretrained(model_path) + if "dsg" in model_name.lower(): + model = OlaLlavaQwenForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) + else: + model = LlavaQwenForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) + else: + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + cfg_pretrained = AutoConfig.from_pretrained(model_path) + if "dsg" in model_name.lower(): + model = OlaLlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) + elif "multi_enc" in model_name.lower(): + model = MultiEncLlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) + else: + model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) + + mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') + mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} + model.load_state_dict(mm_projector_weights, strict=False) + else: + if 'probe' in model_name.lower(): + if 'phi' in model_name.lower(): + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + model = ProbeDSGLlavaPhi3ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + elif 'qwen' in model_name.lower(): + tokenizer = AutoTokenizer.from_pretrained(model_path) + model = ProbeDSGLlavaQwen2ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + else: + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + model = ProbeDSGLlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + elif 'mpt' in model_name.lower(): + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) + model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + elif "phi" in model_name.lower(): + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + if "dsg" in model_name.lower(): + model = OlaLlavaPhi3ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + elif "multi_enc" in model_name.lower(): + model = MultiEncLlavaPhi3ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + else: + model = LlavaPhi3ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + elif "qwen" in model_name.lower(): + tokenizer = AutoTokenizer.from_pretrained(model_path) + if "dsg" in model_name.lower(): + model = 
OlaLlavaQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + else: + model = LlavaQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + elif 'mistral' in model_name.lower(): + tokenizer = AutoTokenizer.from_pretrained(model_path) + model = LlavaMistralForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **kwargs + ) + else: + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + if "dsg" in model_name.lower(): + model = OlaLlavaLlamaForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **kwargs + ) + elif "multi_enc" in model_name.lower(): + model = MultiEncLlavaLlamaForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **kwargs + ) + else: + model = LlavaLlamaForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **kwargs + ) + else: + # Load language model + if model_base is not None: + # PEFT model + from peft import PeftModel + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs) + print(f"Loading LoRA weights from {model_path}") + model = PeftModel.from_pretrained(model, model_path) + print(f"Merging weights") + model = model.merge_and_unload() + print('Convert to FP16...') + model.to(torch.float16) + else: + use_fast = False + if 'mpt' in model_name.lower(): + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) + model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs) + else: + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + + image_processor = None + + if 'llava' in model_name.lower() or 'sherlock' in model_name.lower() or 'probe' in model_name.lower() or 'clip' in model_name.lower() or 'dino' in model_name.lower(): + + if "convnext" in model_name.lower(): + model = reload_from_ckpt(model_path, model) + + vision_tower = model.get_vision_tower() + + if "multi_enc" in model_name.lower(): + model.get_model().init_encoders(model.config) + + if not vision_tower.is_loaded: + vision_tower.load_model(device_map=device_map) + + if device_map != 'auto': + vision_tower.to(device=device_map, dtype=torch.float16) + + try: + if vision_tower.device != model.device: + vision_tower = vision_tower.to(model.device) + except: + pass + + mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) + mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) + if mm_use_im_patch_token: + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + if mm_use_im_start_end: + tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + model.resize_token_embeddings(len(tokenizer)) + + image_processor = vision_tower.image_processor + + if hasattr(model.config, "max_sequence_length"): + context_len = model.config.max_sequence_length + else: + context_len = 4096 + return tokenizer, model, image_processor, context_len + + +def reload_from_ckpt(model_path, model, cache_dir=None): + import os + from safetensors import safe_open + from huggingface_hub import hf_hub_download, list_repo_files + + state_dict = {} + + # Check if the path is a local directory or HF Hub model + if os.path.isdir(model_path): + # Local directory: Load safetensors files + safetensors_paths = [os.path.join(model_path, f) for f in 
os.listdir(model_path) if f.endswith('.safetensors')] + else: + # HF Hub: Get list of safetensors files and download them + repo_files = list_repo_files(model_path) + safetensors_paths = [ + hf_hub_download(model_path, file_name, cache_dir=cache_dir) + for file_name in repo_files if file_name.endswith('.safetensors') + ] + + # Load safetensors files into the state_dict + for path in safetensors_paths: + with safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + if "vision_tower" in key: + state_dict[key] = f.get_tensor(key) + + # Load the state dict into the model + model.load_state_dict(state_dict, strict=False) + return model \ No newline at end of file diff --git a/ola_vlm/model/consolidate.py b/ola_vlm/model/consolidate.py new file mode 100644 index 0000000000000000000000000000000000000000..2260a874708c9ff946fb2cac3fe9dff276f613ce --- /dev/null +++ b/ola_vlm/model/consolidate.py @@ -0,0 +1,29 @@ +""" +Usage: +python3 -m ola_vlm.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate +""" +import argparse + +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM +from ola_vlm.model import * +from ola_vlm.model.utils import auto_upgrade + + +def consolidate_ckpt(src_path, dst_path): + print("Loading model") + auto_upgrade(src_path) + src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) + src_model.save_pretrained(dst_path) + src_tokenizer.save_pretrained(dst_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--src", type=str, required=True) + parser.add_argument("--dst", type=str, required=True) + + args = parser.parse_args() + + consolidate_ckpt(args.src, args.dst) diff --git a/ola_vlm/model/language_model/base_lm.py b/ola_vlm/model/language_model/base_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..100a8d3bd72017165c798f7a73198ef1ef5c8ec8 --- /dev/null +++ b/ola_vlm/model/language_model/base_lm.py @@ -0,0 +1,859 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
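For context, a minimal sketch of driving the `load_pretrained_model` helper defined in `builder.py` above (illustrative only; the checkpoint path is a placeholder, not a claim about any released weights, and the branch taken inside the builder depends on keywords such as `llava`, `clip`, `dsg`, or `probe` in the model name).

```python
# Illustrative only; shows how the builder added above is typically driven.
from ola_vlm.mm_utils import get_model_name_from_path
from ola_vlm.model.builder import load_pretrained_model

# Placeholder path; the name should contain one of the keywords the builder
# dispatches on (e.g. "llava") so the multimodal branch is selected.
model_path = "/path/to/a/local-llava-style-checkpoint"
model_name = get_model_name_from_path(model_path)

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path,
    model_base=None,        # only needed for LoRA / projector-only checkpoints
    model_name=model_name,
    load_8bit=False,
    load_4bit=False,
    device_map="auto",
)
print(type(model).__name__, context_len)
```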
+ + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from transformers.generation.utils import GenerateNonBeamOutput + +from transformers.utils import logging, is_accelerate_available +from transformers.generation.configuration_utils import GenerationConfig +from transformers.generation.logits_process import ( + LogitsProcessorList, +) +from transformers.generation.streamers import BaseStreamer +from transformers.generation.stopping_criteria import ( + StoppingCriteriaList, +) +from transformers.utils import ModelOutput, logging + +import os + +logger = logging.get_logger(__name__) + +import collections +import gc +import itertools +import os +import re +import shutil +import tempfile + +from transformers import PreTrainedModel + +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.pytorch_utils import id_tensor_storage +from transformers.modeling_utils import ( + is_fsdp_enabled, is_local_dist_rank_0, + load_state_dict, set_initialized_submodules, + _load_state_dict_into_model, + _load_state_dict_into_meta_model, + expand_device_map, get_disk_only_shard_files, + get_disk_only_shard_files, +) + +if is_accelerate_available(): + from accelerate.utils import ( + find_tied_parameters, + load_offloaded_weights, + save_offload_index, + set_module_tensor_to_device, + ) + +from transformers.utils import logging +from dataclasses import dataclass + +PARAM_RENAME_WARNING = "A parameter name that contains `{}` will be renamed internally to `{}`. Please use a different name to suppress this warning." + +@dataclass +class GenerateDecoderOnlyOutput(ModelOutput): + """ + Outputs of decoder-only generation models, when using non-beam methods. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`): + Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. 
+ past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. + Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value + tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + """ + + sequences: torch.LongTensor = None + scores: Optional[Tuple[torch.FloatTensor]] = None + logits: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None + +def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False): + # Convert old format to new format if needed from a PyTorch state_dict + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if "gamma" in key and ("vision_tower.vision_tower" not in key and "dav2_model" not in key): + logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight")) + new_key = key.replace("gamma", "weight") + if "beta" in key and "vision_tower.vision_tower" not in key: + logger.warning(PARAM_RENAME_WARNING.format("beta", "bias")) + new_key = key.replace("beta", "bias") + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. + def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + local_metadata["assign_to_params_buffers"] = assign_to_params_buffers + + args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) + # Parameters of module and children will start with prefix. We can exit early if there are none in this + # state_dict + if len([key for key in state_dict if key.startswith(prefix)]) > 0: + if is_deepspeed_zero3_enabled(): + import deepspeed + + # In sharded models, each shard has only part of the full state_dict, so only gather + # parameters that are in the current state_dict. 
+ named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False)) + params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters] + if len(params_to_gather) > 0: + # because zero3 puts placeholders in model params, this context + # manager gathers (unpartitions) the params of the current layer, then loads from + # the state dict and then re-partitions them again + with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0): + if torch.distributed.get_rank() == 0: + module._load_from_state_dict(*args) + else: + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, state_dict, prefix + name + ".", assign_to_params_buffers) + + load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers) + # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so + # it's safe to delete it. + del state_dict + + return error_msgs + +def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""): + """ + Checks if `model_to_load` supports param buffer assignment (such + as when loading in empty weights) by first checking + if the model explicitly disables it, then by ensuring that the state dict keys + are a subset of the model's parameters. + + Note: We fully disable this if we are using `deepspeed` + """ + if len([key for key in state_dict if key.startswith(start_prefix)]) == 0: + return False + + if is_deepspeed_zero3_enabled(): + return False + + # Some models explicitly do not support param buffer assignment + if not getattr(model_to_load, "_supports_param_buffer_assignment", True): + logger.debug( + f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower" + ) + return False + + # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype + first_key = list(model_to_load.state_dict().keys())[0] + if start_prefix + first_key in state_dict: + return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype + + # For cases when the `state_dict` doesn't contain real weights to the model (`test_model_weights_reload_no_missing_tied_weights`) + return False + + +class BaseCausalLM(PreTrainedModel): + + def __init__(self, config): + super().__init__(config) + + def _sample( + self, + input_ids: torch.LongTensor, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"], + logits_warper: Optional[LogitsProcessorList] = None, + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`): + An instance of [`StoppingCriteriaList`]. 
List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in + `generation_config`) + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`: + A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + """ + # init values + pad_token_id = generation_config.pad_token_id + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + + has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + do_sample = generation_config.do_sample + if do_sample is True and not isinstance(logits_warper, LogitsProcessorList): + raise ValueError( + "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is " + f"{logits_warper})." 
+ ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + batch_size = input_ids.shape[0] + this_peer_finished = False + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + if do_sample: + next_token_scores = logits_warper(input_ids, next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_logits: + raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + probs = nn.functional.softmax(next_token_scores, dim=-1) + # token selection + if do_sample: + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(next_token_scores, dim=-1) + + # finished sentences should have their next token be a padding token + if has_eos_stopping_criteria: + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if streamer is not None: + streamer.put(next_tokens.cpu()) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + this_peer_finished = unfinished_sequences.max() == 0 + + if streamer is not None: + streamer.end() + + if return_dict_in_generate: + return GenerateDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + 
attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return input_ids + + @classmethod + def _load_pretrained_model( + cls, + model, + state_dict, + loaded_keys, + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=False, + sharded_metadata=None, + _fast_init=True, + low_cpu_mem_usage=False, + device_map=None, + offload_folder=None, + offload_state_dict=None, + dtype=None, + hf_quantizer=None, + keep_in_fp32_modules=None, + gguf_path=None, + ): + is_safetensors = False + is_quantized = hf_quantizer is not None + state_dict_folder = None + state_dict_index = None + + if device_map is not None and "disk" in device_map.values(): + archive_file = ( + resolved_archive_file[0] if isinstance(resolved_archive_file, (list, tuple)) else resolved_archive_file + ) + is_safetensors = archive_file.endswith(".safetensors") + if offload_folder is None and not is_safetensors: + raise ValueError( + "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`" + " for them. Alternatively, make sure you have `safetensors` installed if the model you are using" + " offers the weights in this format." + ) + if offload_folder is not None: + os.makedirs(offload_folder, exist_ok=True) + if offload_state_dict is None: + offload_state_dict = True + + is_sharded_safetensors = is_safetensors and sharded_metadata is not None + + for key, param in model.state_dict().items(): + if param.device == torch.device("meta"): + try: + set_module_tensor_to_device( + model, key, "cuda", torch.empty(*param.size(), dtype=dtype) + ) + except: + pass + + # tie the model weights before retrieving the state_dict + model.tie_weights() + + # Retrieve missing & unexpected_keys + model_state_dict = model.state_dict() + expected_keys = list(model_state_dict.keys()) + prefix = model.base_model_prefix + + def _fix_key(key): + if "beta" in key and "vision_tower.vision_tower" not in key: + return key.replace("beta", "bias") + if "gamma" in key and ("vision_tower.vision_tower" not in key and "dav2_model" not in key): + return key.replace("gamma", "weight") + return key + + original_loaded_keys = loaded_keys + loaded_keys = [_fix_key(key) for key in loaded_keys] + + if len(prefix) > 0: + has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) + expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) + else: + has_prefix_module = False + expects_prefix_module = False + + # key re-naming operations are never done on the keys + # that are loaded, but always on the keys of the newly initialized model + remove_prefix_from_model = not has_prefix_module and expects_prefix_module + add_prefix_to_model = has_prefix_module and not expects_prefix_module + + if remove_prefix_from_model: + _prefix = f"{prefix}." 
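+            # Illustrative example (Llama-style naming, not guaranteed by every checkpoint): the
+            # checkpoint was saved from the bare base model, so it contains keys like
+            # "embed_tokens.weight", while this class nests the base model under "model."; the
+            # expected keys are therefore compared with that prefix stripped.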
+ expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(_prefix)] + expected_keys = [s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys] + elif add_prefix_to_model: + expected_keys = [".".join([prefix, s]) for s in expected_keys] + + missing_keys = sorted(set(expected_keys) - set(loaded_keys)) + unexpected_keys = set(loaded_keys) - set(expected_keys) + + # Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model + # buffers + model_buffers = {n for n, _ in model.named_buffers()} + if remove_prefix_from_model: + model_buffers = {key[len(_prefix) :] if key.startswith(_prefix) else key for key in model_buffers} + elif add_prefix_to_model: + model_buffers = {".".join([prefix, key]) for key in model_buffers} + unexpected_keys = sorted(unexpected_keys - model_buffers) + + model.tie_weights() + if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled(): + ptrs = collections.defaultdict(list) + for name, tensor in model.state_dict().items(): + id_tensor = id_tensor_storage(tensor) + ptrs[id_tensor].append(name) + + # These are all the pointers of shared tensors. + tied_params = [names for _, names in ptrs.items() if len(names) > 1] + else: + # id function doesn't work for meta tensor so we need this function + tied_params = find_tied_parameters(model) + + for group in tied_params: + if remove_prefix_from_model: + group = [key[len(_prefix) :] if key.startswith(_prefix) else key for key in group] + elif add_prefix_to_model: + group = [".".join([prefix, key]) for key in group] + missing_in_group = [k for k in missing_keys if k in group] + if len(missing_in_group) > 0 and len(missing_in_group) < len(group): + missing_keys = [k for k in missing_keys if k not in missing_in_group] + + # Some models may have keys that are not in the state by design, removing them before needlessly warning + # the user. + if cls._keys_to_ignore_on_load_missing is not None: + for pat in cls._keys_to_ignore_on_load_missing: + missing_keys = [k for k in missing_keys if re.search(pat, k) is None] + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + if hf_quantizer is not None: + missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix) + + # retrieve weights on meta device and put them back on CPU. 
+ # This is not ideal in terms of memory, but if we don't do that not, we can't initialize them in the next step + if low_cpu_mem_usage: + for key in missing_keys: + if key in list(model_state_dict.keys()): + key = key + elif f"{prefix}.{key}" in list(model_state_dict.keys()): + key = f"{prefix}.{key}" + elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in list(model_state_dict.keys()): + key = ".".join(key.split(".")[1:]) + param = model_state_dict[key] + + # upcast in fp32 if any + target_dtype = dtype + if ( + keep_in_fp32_modules is not None + and dtype == torch.float16 + and any( + module_to_keep_in_fp32 in key.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules + ) + ): + target_dtype = torch.float32 + + if param.device == torch.device("meta"): + value = torch.empty(*param.size(), dtype=target_dtype) + if ( + not is_quantized + or getattr(hf_quantizer, "requires_parameters_quantization", False) + or not hf_quantizer.check_quantized_param( + model, param_value=value, param_name=key, state_dict={} + ) + ): + set_module_tensor_to_device(model, key, "cpu", value) + else: + hf_quantizer.create_quantized_param(model, value, key, "cpu", state_dict, unexpected_keys) + + # retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights. + if _fast_init: + if not ignore_mismatched_sizes: + if remove_prefix_from_model: + _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys] + elif add_prefix_to_model: + _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys] + else: + _loaded_keys = loaded_keys + not_initialized_submodules = set_initialized_submodules(model, _loaded_keys) + # If we're about to tie the output embeds to the input embeds we don't need to init them + if hasattr(model.config, "tie_word_embeddings") and model.config.tie_word_embeddings: + output_embeddings = model.get_output_embeddings() + if output_embeddings is not None: + # Still need to initialize if there is a bias term since biases are not tied. + if not hasattr(output_embeddings, "bias") or output_embeddings.bias is None: + output_embeddings._is_hf_initialized = True + else: + not_initialized_submodules = dict(model.named_modules()) + # This will only initialize submodules that are not marked as initialized by the line above. + if is_deepspeed_zero3_enabled() and not is_quantized: + import deepspeed + + not_initialized_parameters = list( + set( + itertools.chain.from_iterable( + submodule.parameters(recurse=False) for submodule in not_initialized_submodules.values() + ) + ) + ) + with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0): + model.apply(model._initialize_weights) + else: + model.apply(model._initialize_weights) + + # Set some modules to fp32 if any + if keep_in_fp32_modules is not None: + for name, param in model.named_parameters(): + if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules): + # param = param.to(torch.float32) does not work here as only in the local scope. + param.data = param.data.to(torch.float32) + + # Make sure we are able to load base models as well as derived models (with heads) + start_prefix = "" + model_to_load = model + if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module: + start_prefix = cls.base_model_prefix + "." 
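+        # Concrete example of the two situations handled here (illustrative names): when a bare
+        # base model is instantiated but the checkpoint keys carry the wrapper prefix (e.g.
+        # "model.layers.0..."), start_prefix lets those keys still be matched; conversely, when a
+        # model with a head is loaded from a base-only checkpoint, the block below switches
+        # model_to_load to the inner base module.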
+ if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module: + model_to_load = getattr(model, cls.base_model_prefix) + base_model_expected_keys = list(model_to_load.state_dict().keys()) + if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys): + raise ValueError( + "The state dictionary of the model you are trying to load is corrupted. Are you sure it was " + "properly saved?" + ) + if device_map is not None: + device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()} + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + add_prefix_to_model, + remove_prefix_from_model, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + # If the checkpoint is sharded, we may not have the key here. + if checkpoint_key not in state_dict: + continue + model_key = checkpoint_key + if remove_prefix_from_model: + # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it. + model_key = f"{prefix}.{checkpoint_key}" + elif add_prefix_to_model: + # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it. + model_key = ".".join(checkpoint_key.split(".")[1:]) + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + if ( + state_dict[checkpoint_key].shape[-1] == 1 + and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel() + ): + # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences. + # Without matching with module type or paramter type it seems like a practical way to detect valid 4bit weights. 
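+                            # Worked example (illustrative shapes): a float16 weight of shape
+                            # (4096, 11008) has 45,088,768 elements; packed as 4-bit, two values
+                            # share one uint8, so the checkpoint stores 22,544,384 elements with a
+                            # trailing dimension of 1, and numel() * 2 matches the model tensor.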
+ pass + else: + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if resolved_archive_file is not None: + folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1]) + else: + folder = None + if device_map is not None and is_safetensors: + param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix) + str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32" + if sharded_metadata is None: + archive_file = ( + resolved_archive_file[0] + if isinstance(resolved_archive_file, (list, tuple)) + else resolved_archive_file + ) + weight_map = {p: archive_file for p in original_loaded_keys} + else: + weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()} + offload_index = { + p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype} + for p, f in weight_map.items() + if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk" + } + else: + offload_index = None + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + add_prefix_to_model, + remove_prefix_from_model, + ignore_mismatched_sizes, + ) + + # For GGUF models `state_dict` is never set to None as the state dict is always small + if gguf_path: + error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( + model_to_load, + state_dict, + loaded_keys, + start_prefix, + expected_keys, + device_map=device_map, + offload_folder=offload_folder, + offload_index=offload_index, + state_dict_folder=state_dict_folder, + state_dict_index=state_dict_index, + dtype=dtype, + hf_quantizer=hf_quantizer, + is_safetensors=is_safetensors, + keep_in_fp32_modules=keep_in_fp32_modules, + unexpected_keys=unexpected_keys, + ) + else: + # Sharded checkpoint or whole but low_cpu_mem_usage==True + assign_to_params_buffers = check_support_param_buffer_assignment( + model_to_load, state_dict, start_prefix + ) + error_msgs = _load_state_dict_into_model( + model_to_load, state_dict, start_prefix, assign_to_params_buffers + ) + + else: + # This should always be a list but, just to be sure. + if not isinstance(resolved_archive_file, list): + resolved_archive_file = [resolved_archive_file] + + error_msgs = [] + mismatched_keys = [] + if not is_safetensors: + offload_index = {} if device_map is not None and "disk" in device_map.values() else None + if offload_state_dict: + state_dict_folder = tempfile.mkdtemp() + state_dict_index = {} + else: + state_dict_folder = None + state_dict_index = None + + if is_sharded_safetensors: + disk_only_shard_files = get_disk_only_shard_files( + device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix + ) + disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files] + else: + disk_only_shard_files = [] + + if len(resolved_archive_file) > 1: + resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards") + assign_to_params_buffers = None + for shard_file in resolved_archive_file: + # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload. 
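+                # Within this loop each shard is read, its tensors are either materialized
+                # directly on their target device (low_cpu_mem_usage path) or copied in via
+                # _load_state_dict_into_model, and the shard's state_dict is dropped before the
+                # next file is opened, keeping peak host memory close to a single shard.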
+ if shard_file in disk_only_shard_files: + continue + state_dict = load_state_dict(shard_file, is_quantized=is_quantized) + + # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not + # matching the weights in the model. + mismatched_keys += _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + add_prefix_to_model, + remove_prefix_from_model, + ignore_mismatched_sizes, + ) + if low_cpu_mem_usage: + if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized: + for key, param in model_to_load.state_dict().items(): + if param.device == torch.device("meta"): + set_module_tensor_to_device( + model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype) + ) + else: + new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( + model_to_load, + state_dict, + loaded_keys, + start_prefix, + expected_keys, + device_map=device_map, + offload_folder=offload_folder, + offload_index=offload_index, + state_dict_folder=state_dict_folder, + state_dict_index=state_dict_index, + dtype=dtype, + hf_quantizer=hf_quantizer, + is_safetensors=is_safetensors, + keep_in_fp32_modules=keep_in_fp32_modules, + unexpected_keys=unexpected_keys, + ) + error_msgs += new_error_msgs + else: + # Sharded checkpoint or whole but low_cpu_mem_usage==True + if assign_to_params_buffers is None: + assign_to_params_buffers = check_support_param_buffer_assignment( + model_to_load, state_dict, start_prefix + ) + error_msgs += _load_state_dict_into_model( + model_to_load, state_dict, start_prefix, assign_to_params_buffers + ) + + # force memory release + del state_dict + gc.collect() + + if offload_index is not None and len(offload_index) > 0: + if model != model_to_load: + # We need to add the prefix of the base model + prefix = cls.base_model_prefix + if not is_safetensors: + for weight_name in offload_index: + shutil.move( + os.path.join(offload_folder, f"{weight_name}.dat"), + os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"), + ) + offload_index = {f"{prefix}.{key}": value for key, value in offload_index.items()} + if not is_safetensors: + save_offload_index(offload_index, offload_folder) + offload_index = None + + if offload_state_dict: + # Load back temporarily offloaded state dict + load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder) + shutil.rmtree(state_dict_folder) + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if len(unexpected_keys) > 0: + archs = [] if model.config.architectures is None else model.config.architectures + warner = logger.warning if model.__class__.__name__ in archs else logger.info + warner( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" + " with another architecture (e.g. 
initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical" + " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" + f" was trained on, you can already use {model.__class__.__name__} for predictions without further" + " training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" + " to use it for predictions and inference." + ) + + return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs \ No newline at end of file diff --git a/ola_vlm/model/language_model/base_ola_vlm.py b/ola_vlm/model/language_model/base_ola_vlm.py new file mode 100644 index 0000000000000000000000000000000000000000..3a042712c701a46d5dac6cab554d3837e9920fc9 --- /dev/null +++ b/ola_vlm/model/language_model/base_ola_vlm.py @@ -0,0 +1,643 @@ +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + + +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.generation.utils import GenerateOutput +from transformers.generation.utils import GenerateOutput + +from ola_vlm.model.aux_heads import GenHead, DepthHead, DAv2_Head, TaskTokenGenHead, TaskTokenDepthHead +from ola_vlm.model.aux_heads.depth_anything_v2.dpt import DepthAnythingV2 +from ola_vlm.model.aux_heads.oneformer_head import OneFormerHead, OneFormerSegHead, OneFormerTaskTokenSegHead + +from transformers import OneFormerProcessor + +from diffusers import ( + DPMSolverMultistepScheduler, + StableUnCLIPImg2ImgPipeline, +) + +import torch.distributed as dist +try: + import wandb +except: + pass +import os +import matplotlib +from .base_lm import BaseCausalLM +from tqdm import tqdm + +from ola_vlm.ola_utils import * + + +class BaseOLA_VLM(BaseCausalLM): + + def __init__(self, config): + super(BaseCausalLM, self).__init__(config) + self.steps = 0 + self.config = config + + if hasattr(config, "image_gen"): + self.init_heads(config) + + try: + if dist.get_rank() == 0: + wandb.init(project=os.environ['WANDB_PROJECT'], name=f"{os.environ['WANDB_NAME']}") + except: + pass + + def get_model(self): + return self.model + + def 
init_target_models(self, config):
+        if hasattr(config, "image_gen") and "gen" in self.mode:
+            if not os.path.exists(config.image_generator):
+                config.image_generator = "stabilityai/stable-diffusion-2-1-unclip"
+            self.pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(config.image_generator, torch_dtype=torch.float16, variant="fp16")
+            self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)
+            for p in self.pipe.image_encoder.parameters():
+                p.requires_grad = False
+            try:
+                self.pipe = self.pipe.to("cuda")
+            except:
+                pass
+
+        if hasattr(config, "image_depth") and "depth" in self.mode:
+            dav2_cfg = {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}
+            self.dav2_backbone = DepthAnythingV2(**dav2_cfg)
+
+            # Download the DepthAnything-V2 checkpoint only when it is not already available locally.
+            if not os.path.exists(config.depth_estimator):
+                url = "https://huggingface.co/depth-anything/Depth-Anything-V2-Large/resolve/main/depth_anything_v2_vitl.pth?download=true"
+                local_model_path = "depth_anything_v2_vitl.pth"
+                if not os.path.exists(local_model_path):
+                    os.system(f"wget -O {local_model_path} {url}")
+                config.depth_estimator = local_model_path
+
+            self.dav2_backbone.load_state_dict(torch.load(config.depth_estimator, map_location='cpu'))
+            for p in self.dav2_backbone.parameters():
+                p.requires_grad = False
+
+        if hasattr(config, "image_seg") and "seg" in self.mode:
+            if not os.path.exists(config.image_segmentor):
+                config.image_segmentor = "oneformer/oneformer_coco_swin_large"
+            self.oneformer_processor = OneFormerProcessor.from_pretrained(config.image_segmentor)
+            self.oneformer = OneFormerHead.from_pretrained(config.image_segmentor)
+            for p in self.oneformer.parameters():
+                p.requires_grad = False
+            try:
+                self.oneformer = self.oneformer.to("cuda")
+            except:
+                pass
+
+    def _get_layer_loss_weight(self, config, prefix):
+        layer_indices = config[f"{prefix}_layer_indices"]
+        layer_indices = layer_indices.split("-")
+        layer_indices = [int(i) - 1 for i in layer_indices]
+        loss_weight = config[f"{prefix}_loss_weight"]
+        return layer_indices, loss_weight
+
+    def init_heads(self, config):
+        self.mode = getattr(config, "aux_mode", "gen-depth-seg")
+        self.pass_text_to_aux_head = getattr(config, "pass_text_to_aux", True)
+        self.use_ce = getattr(config, "use_ce", False)
+        self.contrastive_loss_weight = config.contrastive_loss_weight
+        num_task_tokens = config.num_task_tokens
+
+        if hasattr(config, "image_gen") and "gen" in self.mode:
+            self.img_layer_indices, self.img_gen_loss_weight = self._get_layer_loss_weight(config.image_gen, "img")
+            if getattr(config, "use_contrastive", True):
+                self.gen_logit_scale = nn.Parameter(torch.tensor(2.0))
+            else:
+                self.gen_logit_scale = None
+
+            self.image_gen_heads = nn.ModuleList([
+                TaskTokenGenHead(config.image_gen, llm_hidden_size=config.hidden_size) if num_task_tokens > 0 else GenHead(proj_config=config.image_gen, llm_hidden_size=config.hidden_size)
+                for _ in self.img_layer_indices
+            ])
+
+        if hasattr(config, "image_depth") and "depth" in self.mode:
+            self.depth_layer_indices, self.img_depth_loss_weight = self._get_layer_loss_weight(config.image_depth, "depth")
+            self.img_depth_loss_weight = config.image_depth["depth_loss_weight"]
+
+            if getattr(config, "use_contrastive", True):
+                self.depth_logit_scale = nn.Parameter(torch.tensor(2.0))
+            else:
+                self.depth_logit_scale = None
+
+            self.use_intermediate_depth = config.image_depth.get("use_intermediate_depth", True)
+
+            self.image_depth_heads = nn.ModuleList([
TaskTokenDepthHead(proj_config=config.image_depth, llm_hidden_size=config.hidden_size, use_intermediate_depth=self.use_intermediate_depth) if num_task_tokens > 0 else DepthHead(proj_config=config.image_depth, llm_hidden_size=config.hidden_size, use_intermediate_depth=self.use_intermediate_depth) + for _ in self.depth_layer_indices + ]) + + self.da_v2_head = DAv2_Head() + + if not os.path.exists(config.depth_estimator): + url = "https://huggingface.co/depth-anything/Depth-Anything-V2-Large/resolve/main/depth_anything_v2_vitl.pth?download=true" + local_model_path = "depth_anything_v2_vitl.pth" + if not os.path.exists(local_model_path): + os.system(f"wget -O {local_model_path} {url}") + config.depth_estimator = local_model_path + + self.da_v2_head.load_state_dict(torch.load(config.depth_estimator), strict=False) + + for p in self.da_v2_head.parameters(): + p.requires_grad = False + + if hasattr(config, "image_seg") and "seg" in self.mode: + self.seg_layer_indices, self.img_seg_loss_weight = self._get_layer_loss_weight(config.image_seg, "seg") + + self.seg_teacher = config.image_seg.get("seg_teacher", "sam") + + assert self.seg_teacher in ["sam", "oneformer"] + + if getattr(config, "use_contrastive", True): + self.seg_logit_scale = nn.Parameter(torch.tensor(2.0)) + else: + self.seg_logit_scale = None + + self.image_seg_heads = nn.ModuleList([ + OneFormerTaskTokenSegHead(config.image_seg, llm_hidden_size=config.hidden_size) if num_task_tokens > 0 else OneFormerSegHead(config.image_seg, llm_hidden_size=config.hidden_size) + for _ in self.seg_layer_indices + ]) + + + def log_gen(self, img_embeds, pil_images, layer_idx, is_train=False): + pipe = self.pipe.to("cuda") + + images = [] + + for img_embed in img_embeds: + image = pipe(image_embeds=img_embed.float().detach(), + num_inference_steps=25, + ).images[0] + images.append(image) + + if not is_train: + return images + + n = len(images) + c = min(n, 16) + r = n // c + images = images[:c*r] + image_grid = make_grid(images, pil_images) + + wandb.log({ + f"val_gen_images/step_{self.steps}": wandb.Image(image_grid, caption=f"Layer-{layer_idx}") + }) + + def log_depth(self, depth_preds, layer_idx, depth_targets=None, is_train=False): + cmap = matplotlib.colormaps.get_cmap('Spectral_r') + depth_preds = depth_preds.float().detach() + def _visualize_depth(depth): + depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0 + depth = depth.cpu().numpy().astype(np.uint8) + colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8) + return Image.fromarray(colored_depth) + + pred_depths, gt_depths = [], [] + + if depth_targets is None: + depth_targets = [None] * len(depth_preds) + + for pred, target in tqdm(zip(depth_preds, depth_targets), desc="Visualizing Depth..."): + if target is not None: + gt = _visualize_depth(target.float()) + gt_depths.append(gt) + + pred = _visualize_depth(pred) + pred_depths.append(pred) + + if not is_train: + return pred_depths + + n = len(pred_depths) + c = min(n, 16) + r = n // c + pred_depths = pred_depths[:c*r] + gt_depths = gt_depths[:c*r] + masks_grid = make_grid(pred_depths, gt_depths) + + wandb.log({ + f"val_depth_images/step_{self.steps}": wandb.Image(masks_grid, caption=f"Layer-{layer_idx}") + }) + + def log_seg(self, seg_embeds, pil_images, layer_idx, seg_targets=None, is_train=False): + def _oneformer_prepare_panoptic_instance_prediction( + segmentation: torch.Tensor, segments_info: dict + ): + masks = [] + classes = [] + + for segment in segments_info: + id = segment["id"] + label_id = 
segment["label_id"] + label = self.oneformer.config.id2label[label_id] + mask = segmentation == id + masks.append(mask.float()) + classes.append(label) + + return masks, classes + + pred_masks, gt_masks = [], [] + + seg_embeds = seg_embeds.detach() + + if seg_targets is None: + seg_targets = [None] * len(seg_embeds) + + for emb, target, img in tqdm(zip(seg_embeds, seg_targets, pil_images), desc=f"Predicting Segmentation Map..."): + with torch.no_grad(): + inputs = self.oneformer_processor(img, ["panoptic"], return_tensors="pt") + inputs["pixel_values"] = inputs["pixel_values"].to(emb.device, emb.dtype) + inputs["task_inputs"] = inputs["task_inputs"].to(emb.device, emb.dtype) + gt = self.oneformer.get_masks(**inputs, backbone_last_feature=target.unsqueeze(0)) + gt = self.oneformer_processor.post_process_panoptic_segmentation( + gt, target_sizes=[img.size[::-1]] + )[0] + gt_msk, gt_cls = _oneformer_prepare_panoptic_instance_prediction(**gt) + gt = visualize_oneformer_masks_on_image(img, gt_msk, gt_cls) + + pred = self.oneformer.get_masks(**inputs, backbone_last_feature=emb.unsqueeze(0)) + pred = self.oneformer_processor.post_process_panoptic_segmentation( + pred, target_sizes=[img.size[::-1]] + )[0] + pred_msk, pred_cls = _oneformer_prepare_panoptic_instance_prediction(**pred) + pred = visualize_oneformer_masks_on_image(img, pred_msk, pred_cls) + + gt_masks.append(gt) + pred_masks.append(pred) + + n = len(pred_masks) + c = min(n, 16) + r = n // c + pred_masks = pred_masks[:c*r] + gt_masks = gt_masks[:c*r] + masks_grid = make_grid(pred_masks, gt_masks) + + wandb.log({ + f"val_seg_images/step_{self.steps}": wandb.Image(masks_grid, caption=f"Layer-{layer_idx}") + }) + + + def _emb_loss(self, emb_preds, emb_mask, emb_targets, logit_scale): + emb_targets = emb_targets.to(emb_preds.dtype).to(emb_preds.device) + + if emb_targets.shape[0] != emb_preds.shape[0]: + repeat_factor = emb_preds.shape[0] // emb_targets.shape[0] + emb_targets = emb_targets.repeat(repeat_factor, 1, 1) + emb_mask = emb_mask.repeat(repeat_factor, 1, 1) + + if emb_targets.shape[0] != emb_preds.shape[0]: + emb_targets = emb_targets[:emb_preds.shape[0]] + emb_mask = emb_mask[:emb_preds.shape[0]] + + if emb_preds.ndim == 3: + emb_mask = emb_mask.view(emb_preds.shape[0], 1, 1) + else: + emb_mask = emb_mask.view(emb_preds.shape[0], 1, 1, 1) + + sl1_loss = F.smooth_l1_loss( + emb_preds.float(), emb_targets.float(), reduction="none" + ) + + if logit_scale is not None: + contrastive_loss = calculate_contrastive_loss(emb_preds, emb_targets, logit_scale) + else: + contrastive_loss = 0 + + sl1_loss = (sl1_loss * emb_mask.float()).mean() + contrastive_loss = (self.contrastive_loss_weight * contrastive_loss * emb_mask.float()).mean() + + emb_loss = sl1_loss + contrastive_loss + + return emb_loss, sl1_loss, contrastive_loss + + + def _get_gen_feats(self, pil_images, device): + gen_feats = [] + for img in pil_images: + with torch.no_grad(): + clip_ims = self.pipe.feature_extractor(images=img, return_tensors="pt").pixel_values.to(device) + feat = self.pipe.image_encoder(clip_ims).image_embeds + gen_feats.append(feat) + + gen_feats = torch.stack(gen_feats, dim=0) + return gen_feats + + def _forward_gen(self, gen_preds, layer_index, pil_images, gen_mask, gen_targets): + gen_loss, gen_sl1_loss, gen_cont_loss = self._emb_loss(gen_preds, gen_mask, gen_targets, self.gen_logit_scale) + + if dist.get_rank() == 0: + if self.steps % 4000 == 0: + try: + self.log_gen(gen_preds.detach(), pil_images, layer_index, is_train=True) + except: + pass + + return 
gen_loss, gen_cont_loss, gen_sl1_loss + + + def _get_dav2_feats(self, pil_images, device): + dav2_gts = [] + depth_targets = [[]] + for img in pil_images: + img = img.resize((336, 336)) + img = np.array(img) + with torch.no_grad(): + feat = self.dav2_backbone.infer_image(img, is_dsg=True) + ft_gt = (feat[0][0] + feat[1][0] + feat[2][0] + feat[3][0]) / 4 + depth_gts = self.da_v2_head([(ft_gt, None)] * 4) + depth_targets[0].append(ft_gt) + min_val = depth_gts.amin(dim=(1, 2), keepdim=True) + max_val = depth_gts.amax(dim=(1, 2), keepdim=True) + depth_gts = (depth_gts - min_val) / (max_val - min_val) + dav2_gts.append(depth_gts.to(device)) + dav2_gts = torch.stack(dav2_gts, dim=0).squeeze(1) + for i in range(len(depth_targets)): + depth_targets[i] = (torch.stack(depth_targets[i], dim=0).squeeze(1), None) + return depth_targets, dav2_gts + + def _forward_depth(self, all_depth_feats, layer_index, depth_mask, all_depth_targets, depth_pred_maps, depth_gts): + + depth_feats, depth_targets = all_depth_feats[0][0], all_depth_targets[0][0] + depth_loss, sl1_loss, cont_loss = self._emb_loss(depth_feats, depth_mask, depth_targets, self.depth_logit_scale) + + if dist.get_rank() == 0: + if self.steps % 1000 == 0: + try: + self.log_depth(depth_pred_maps.detach(), layer_index, depth_gts, is_train=True) + except: + pass + + return depth_loss, sl1_loss, cont_loss + + + def _get_seg_targets(self, pil_images, seg_preds): + def _get_feats(img): + img = img.resize((768, 768)) + inputs = self.oneformer_processor(img, ["panoptic"], return_tensors="pt") + inputs["pixel_values"] = inputs["pixel_values"].to(seg_preds.device, seg_preds.dtype) + with torch.no_grad(): + feats = self.oneformer.forward_features(**inputs) + return feats + + seg_targets = [] + for img in pil_images: + feat = _get_feats(img) + seg_targets.append(feat) + + seg_targets = torch.stack(seg_targets, dim=0).squeeze(1) + return seg_targets + + def _forward_seg(self, seg_preds, layer_index, pil_images, seg_targets, seg_mask): + + seg_loss, sl1_loss, cont_loss = self._emb_loss(seg_preds, seg_mask, seg_targets, self.seg_logit_scale) + + if dist.get_rank() == 0: + if self.steps % 1000 == 0: + try: + self.log_seg(seg_preds.detach(), pil_images, layer_index, seg_targets, is_train=True) + except: + pass + + return seg_loss, sl1_loss, cont_loss + + + def forward_emb_predictor(self, layer_states, idx, i, task, heads, special_tokens): + task_idx = self.token_order.index(task) + task_start_idx = self.NUM_SYS_TOKENS + 576 + (self.num_task_tokens * task_idx) + task_end_idx = task_start_idx + self.num_task_tokens + end_idx = self.NUM_SYS_TOKENS + 576 + (self.num_task_tokens * len(self.token_order)) + + inp_tokens = layer_states[idx][:, :self.NUM_SYS_TOKENS+576] + + if self.num_task_tokens == 0 or layer_states[idx].shape[1] < 600: + if self.pass_text_to_aux_head: + inp_tokens = layer_states[idx] + else: + inp_tokens = torch.cat([inp_tokens, layer_states[idx][:, task_start_idx:task_end_idx]], dim=1) + if self.pass_text_to_aux_head: + inp_tokens = torch.cat([inp_tokens, layer_states[idx][:, end_idx:]], dim=1) + + if self.num_task_tokens == 0: + task_emb = heads[i](inp_tokens) + else: + task_tokens = special_tokens + if task != "gen": + task_tokens = task_tokens.repeat(inp_tokens.shape[0], 1, 1) + else: + if not self.pass_text_to_aux_head: + task_tokens = inp_tokens[:, -self.num_task_tokens:] + else: + task_tokens = inp_tokens[:, self.NUM_SYS_TOKENS+576:self.NUM_SYS_TOKENS+576+self.num_task_tokens] + + task_emb = heads[i](inp_tokens, task_tokens) + + return task_emb + 
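+    # Token layout assumed by the slicing in forward_emb_predictor (a sketch with illustrative
+    # values, not guaranteed by the checkpoint): positions [0, NUM_SYS_TOKENS) hold the system
+    # prompt, the next 576 positions hold the visual tokens, then num_task_tokens positions per
+    # entry of self.token_order hold the special task tokens, and the remainder is text. For
+    # example, with NUM_SYS_TOKENS=35, num_task_tokens=8 and token_order=("gen", "depth", "seg"):
+    #   task_start_idx("depth") = 35 + 576 + 8 * 1 = 619
+    #   task_end_idx("depth")   = 619 + 8          = 627
+    #   end_idx                 = 35 + 576 + 8 * 3 = 635   # first text position after the task tokens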
+ def depth_emb_forward(self, pil_images, layer_states, depth_mask): + depth_preds = [] + depth_embs = [] + depth_loss = 0 + depth_l1_loss = 0 + depth_cont_loss = 0 + if "depth" in self.mode and layer_states[0].shape[1] > self.NUM_SYS_TOKENS: + if pil_images is not None: + depth_targets, depth_gts = self._get_dav2_feats(pil_images, layer_states[0].device) + else: + depth_targets, depth_gts = None, None + + for i, idx in enumerate(self.depth_layer_indices): + + depth_feats = self.forward_emb_predictor(layer_states, idx, i, "depth", self.image_depth_heads, self.depth_tokens) + depth_embs.append(depth_feats) + + with torch.no_grad(): + if self.use_intermediate_depth: + depth_pred = self.da_v2_head(depth_feats) + else: + depth_pred = self.da_v2_head([depth_feats[0]] * 4) + min_val = depth_pred.amin(dim=(1, 2), keepdim=True) + max_val = depth_pred.amax(dim=(1, 2), keepdim=True) + depth_pred = (depth_pred - min_val) / (max_val - min_val) + depth_preds.append(depth_pred) + + if depth_mask is not None: + depth_mask.zero_() + + if depth_targets is not None: + layer_depth_loss, layer_l1_loss, layer_cont_loss = self._forward_depth(depth_feats, idx+1, depth_mask, depth_targets, depth_pred, depth_gts) + depth_loss += layer_depth_loss * self.img_depth_loss_weight + depth_l1_loss += layer_l1_loss * self.img_depth_loss_weight + depth_cont_loss += layer_cont_loss * self.img_depth_loss_weight + + return depth_preds, depth_embs, depth_loss, depth_l1_loss, depth_cont_loss + + def seg_emb_forward(self, pil_images, hidden_states, layer_states, seg_mask): + seg_embs = [] + seg_loss = 0 + seg_l1_loss = 0 + seg_contrastive_loss = 0 + if "seg" in self.mode and layer_states[0].shape[1] > self.NUM_SYS_TOKENS: + if pil_images is not None: + seg_targets = self._get_seg_targets(pil_images, hidden_states) + else: + seg_targets = None + for i, idx in enumerate(self.seg_layer_indices): + + seg_emb = self.forward_emb_predictor(layer_states, idx, i, "seg", self.image_seg_heads, self.seg_tokens) + seg_embs.append(seg_emb) + + if seg_mask is not None: + seg_mask.zero_() + + if seg_targets is not None: + layer_seg_loss, seg_l1_loss, seg_contrastive_loss = self._forward_seg(seg_emb, idx+1, pil_images, seg_targets, seg_mask) + seg_loss += layer_seg_loss * self.img_seg_loss_weight + seg_l1_loss += seg_l1_loss * self.img_seg_loss_weight + seg_contrastive_loss += seg_contrastive_loss * self.img_seg_loss_weight + + return seg_embs, seg_loss, seg_l1_loss, seg_contrastive_loss + + def gen_emb_forward(self, pil_images, hidden_states, layer_states, gen_mask): + img_embs = [] + gen_loss = 0 + gen_con_loss = 0 + gen_mse_loss = 0 + if "gen" in self.mode and layer_states[0].shape[1] > self.NUM_SYS_TOKENS: + if pil_images is not None: + gen_targets = self._get_gen_feats(pil_images, hidden_states.device) + else: + gen_targets = None + + for i, idx in enumerate(self.img_layer_indices): + + img_emb = self.forward_emb_predictor(layer_states, idx, i, "gen", self.image_gen_heads, self.gen_tokens) + img_embs.append(img_emb) + + if gen_mask is not None: + gen_mask.zero_() + + if gen_targets is not None: + layer_gen_loss, layer_gen_con_loss, layer_gen_mse_loss = self._forward_gen(img_emb, idx+1, pil_images, gen_mask, gen_targets) + gen_loss += layer_gen_loss * self.img_gen_loss_weight + gen_con_loss += layer_gen_con_loss * self.img_gen_loss_weight + gen_mse_loss += layer_gen_mse_loss * self.img_gen_loss_weight + + return img_embs, gen_loss, gen_mse_loss, gen_con_loss + + + @torch.no_grad() + def get_visual_interpretations( + self, + inputs: 
Optional[torch.Tensor] = None, + images: Optional[torch.Tensor] = None, + image_sizes: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + **kwargs + ) -> Union[Tuple, CausalLMOutputWithPast]: + + position_ids = kwargs.pop("position_ids", None) + attention_mask = kwargs.pop("attention_mask", None) + if True: + ( + inputs, + position_ids, + attention_mask, + _, + inputs_embeds, + _ + ) = self.prepare_inputs_labels_for_multimodal( + inputs, + position_ids, + attention_mask, + None, + None, + images, + image_sizes=image_sizes + ) + + + return self.forward( + input_ids=inputs, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + attention_mask=attention_mask, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=True, + ) + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + images: Optional[torch.Tensor] = None, + image_sizes: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + position_ids = kwargs.pop("position_ids", None) + attention_mask = kwargs.pop("attention_mask", None) + if "inputs_embeds" in kwargs: + raise NotImplementedError("`inputs_embeds` is not supported") + + if images is not None: + ( + inputs, + position_ids, + attention_mask, + _, + inputs_embeds, + _ + ) = self.prepare_inputs_labels_for_multimodal( + inputs, + position_ids, + attention_mask, + None, + None, + images, + image_sizes=image_sizes + ) + else: + inputs_embeds = self.get_model().embed_tokens(inputs) + + return super().generate( + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs + ) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, + inputs_embeds=None, **kwargs): + images = kwargs.pop("images", None) + image_sizes = kwargs.pop("image_sizes", None) + pil_images = kwargs.pop("pil_images", None) + + depth_mask = kwargs.pop("seg_mask", None) + gen_mask = kwargs.pop("seg_mask", None) + seg_mask = kwargs.pop("seg_mask", None) + + inputs = super().prepare_inputs_for_generation( + input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs + ) + if images is not None: + inputs['images'] = images + if image_sizes is not None: + inputs['image_sizes'] = image_sizes + if pil_images is not None: + inputs['pil_images'] = pil_images + if depth_mask is not None: + inputs['depth_mask'] = depth_mask + if gen_mask is not None: + inputs['gen_mask'] = gen_mask + if seg_mask is not None: + inputs['seg_mask'] = seg_mask + return inputs \ No newline at end of file diff --git a/ola_vlm/model/language_model/base_probe_vlm.py b/ola_vlm/model/language_model/base_probe_vlm.py new file mode 100644 index 0000000000000000000000000000000000000000..df7dfdab6a79d95631fd2599b806a7d6fe8af812 --- /dev/null +++ b/ola_vlm/model/language_model/base_probe_vlm.py @@ -0,0 +1,546 @@ +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + + +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.generation.utils import GenerateOutput +from transformers.generation.utils import GenerateOutput + +from ola_vlm.model.aux_heads import GenHead, DepthHead, DAv2_Head +from ola_vlm.model.aux_heads.depth_anything_v2.dpt import DepthAnythingV2 +from ola_vlm.model.aux_heads.oneformer_head import OneFormerHead, OneFormerSegHead + +from transformers import OneFormerProcessor + +from diffusers import ( + 
DPMSolverMultistepScheduler, + StableUnCLIPImg2ImgPipeline, +) + +import torch.distributed as dist +try: + import wandb +except: + pass +import os +import matplotlib +from .base_lm import BaseCausalLM +from tqdm import tqdm + +from ola_vlm.ola_utils import * + +class BaseProbe_VLM(BaseCausalLM): + + def __init__(self, config): + super(BaseCausalLM, self).__init__(config) + self.steps = 0 + self.config = config + self.num_layers = config.num_hidden_layers + + # Initialize weights and apply final processing + self.post_init() + self.is_trained = False + if hasattr(config, "probe_mode"): + self.is_trained = True + self.init_heads(config) + + try: + if dist.get_rank() == 0: + wandb.init(project=os.environ['WANDB_PROJECT'], name=f"{os.environ['WANDB_NAME']}") + except: + pass + + def get_model(self): + return self.model + + def init_heads(self, config): + self.mode = config.probe_mode + + if self.mode == "gen": + self.image_gen_heads = nn.ModuleList([ + GenHead(config.image_gen, llm_hidden_size=config.hidden_size) + for _ in range(self.num_layers) + ]) + + if not self.is_trained: + self.pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(config.image_generator, torch_dtype=torch.float16, variant="fp16") + self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config) + self.gen_encoder = self.pipe.image_encoder + self.feature_extractor = self.pipe.feature_extractor + for p in self.gen_encoder.parameters(): + p.requires_grad = False + + elif self.mode == "seg": + if not self.is_trained: + self.oneformer_processor = OneFormerProcessor.from_pretrained(config.image_segmentor) + self.oneformer = OneFormerHead.from_pretrained(config.image_segmentor) + for p in self.oneformer.parameters(): + p.requires_grad = False + try: + self.oneformer = self.oneformer.to("cuda") + except: + pass + self.image_seg_heads = nn.ModuleList([ + OneFormerSegHead(config.image_seg, llm_hidden_size=config.hidden_size) + for _ in range(self.num_layers) + ]) + + + if self.mode == "depth": + self.image_depth_heads = nn.ModuleList([ + DepthHead(proj_config=config.image_depth, llm_hidden_size=config.hidden_size, use_intermediate_depth=False) + for _ in range(self.num_layers) + ]) + + dav2_cfg = {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]} + self.dav2_backbone = DepthAnythingV2(**dav2_cfg) + self.dav2_backbone.load_state_dict(torch.load(config.depth_estimator, map_location='cpu')) + for p in self.dav2_backbone.parameters(): + p.requires_grad = False + + self.da_v2_head = DAv2_Head() + self.da_v2_head.load_state_dict(torch.load(config.depth_estimator), strict=False) + for p in self.da_v2_head.parameters(): + p.requires_grad = False + + def _get_layer_loss_weight(self, config, prefix): + layer_indices = config[f"{prefix}_layer_indices"] + layer_indices = layer_indices.split("-") + layer_indices = [int(i) - 1 for i in layer_indices] + loss_weight = config[f"{prefix}_loss_weight"] + return layer_indices, loss_weight + + def log_gen(self, img_embeds, pil_images, layer_idx, is_train=False): + device = "cuda" if torch.cuda.is_available() else "hip" + pipe = self.pipe.to(device) + + images = [] + + if len(pil_images) > 2: + pil_images = pil_images[:2] + img_embeds = img_embeds[:2] + + for img_embed in img_embeds: + image = pipe(image_embeds=img_embed.float().detach(), + num_inference_steps=25, + # guidance_scale=1,, + ).images[0] + images.append(image) + + if not is_train: + return images + + n = len(images) + c = min(n, 16) + r = n // c + images = images[:c*r] + image_grid = 
make_grid(images, pil_images) + + wandb.log({ + f"val_gen_images/step_{self.steps}": wandb.Image(image_grid, caption=f"Layer-{layer_idx}") + }) + + def log_depth(self, depth_preds, layer_idx, depth_targets=None, is_train=False): + cmap = matplotlib.colormaps.get_cmap('Spectral_r') + depth_preds = depth_preds.float().detach() + def _visualize_depth(depth): + depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0 + depth = depth.cpu().numpy().astype(np.uint8) + colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8) + return Image.fromarray(colored_depth) + + pred_depths, gt_depths = [], [] + + if depth_targets is None: + depth_targets = [None] * len(depth_preds) + + from tqdm import tqdm + for pred, target in tqdm(zip(depth_preds, depth_targets), desc="Visualizing Depth..."): + if target is not None: + gt = _visualize_depth(target.float()) + gt_depths.append(gt) + + pred = _visualize_depth(pred) + pred_depths.append(pred) + + if not is_train: + return pred_depths + + n = len(pred_depths) + c = min(n, 16) + r = n // c + pred_depths = pred_depths[:c*r] + gt_depths = gt_depths[:c*r] + masks_grid = make_grid(pred_depths, gt_depths) + + wandb.log({ + f"val_depth_images/step_{self.steps}": wandb.Image(masks_grid, caption=f"Layer-{layer_idx}") + }) + + def log_seg(self, seg_embeds, pil_images, layer_idx, seg_targets=None, is_train=False): + def _oneformer_prepare_panoptic_instance_prediction( + segmentation: torch.Tensor, segments_info: dict + ): + masks = [] + classes = [] + + for segment in segments_info: + id = segment["id"] + label_id = segment["label_id"] + label = self.oneformer.config.id2label[label_id] + mask = segmentation == id + masks.append(mask.float()) + classes.append(label) + + return masks, classes + + pred_masks, gt_masks = [], [] + + seg_embeds = seg_embeds.detach() + + if seg_targets is None: + seg_targets = [None] * len(seg_embeds) + + if len(pil_images) > 2: + pil_images = pil_images[:2] + seg_embeds = seg_embeds[:2] + seg_targets = seg_targets[:2] + + from tqdm import tqdm + for emb, target, img in tqdm(zip(seg_embeds, seg_targets, pil_images), desc=f"Predicting Segmentation Map..."): + with torch.no_grad(): + inputs = self.oneformer_processor(img, ["panoptic"], return_tensors="pt") + inputs["pixel_values"] = inputs["pixel_values"].to(emb.device, emb.dtype) + inputs["task_inputs"] = inputs["task_inputs"].to(emb.device, emb.dtype) + gt = self.oneformer.get_masks(**inputs, backbone_last_feature=target.unsqueeze(0)) + gt = self.oneformer_processor.post_process_panoptic_segmentation( + gt, target_sizes=[img.size[::-1]] + )[0] + gt_msk, gt_cls = _oneformer_prepare_panoptic_instance_prediction(**gt) + gt = visualize_oneformer_masks_on_image(img, gt_msk, gt_cls) + + pred = self.oneformer.get_masks(**inputs, backbone_last_feature=emb.unsqueeze(0)) + pred = self.oneformer_processor.post_process_panoptic_segmentation( + pred, target_sizes=[img.size[::-1]] + )[0] + pred_msk, pred_cls = _oneformer_prepare_panoptic_instance_prediction(**pred) + pred = visualize_oneformer_masks_on_image(img, pred_msk, pred_cls) + + gt_masks.append(gt) + pred_masks.append(pred) + + if not is_train: + return pred_masks + + n = len(pred_masks) + c = min(n, 16) + r = n // c + pred_masks = pred_masks[:c*r] + gt_masks = gt_masks[:c*r] + masks_grid = make_grid(pred_masks, gt_masks) + + wandb.log({ + f"val_seg_images/step_{self.steps}": wandb.Image(masks_grid, caption=f"Layer-{layer_idx}") + }) + + + def _emb_loss(self, emb_preds, emb_targets): + emb_targets = 
emb_targets.to(emb_preds.dtype).to(emb_preds.device) + + if emb_targets.shape[0] != emb_preds.shape[0]: + repeat_factor = emb_preds.shape[0] // emb_targets.shape[0] + emb_targets = emb_targets.repeat(repeat_factor, 1, 1) + + if emb_targets.shape[0] != emb_preds.shape[0]: + emb_targets = emb_targets[:emb_preds.shape[0]] + emb_mask = emb_mask[:emb_preds.shape[0]] + + emb_loss = F.smooth_l1_loss( + emb_preds.float(), emb_targets.float(), reduction="none" + ).mean() + + return emb_loss + + + def _get_gen_feats(self, pil_images, device): + gen_feats = [] + for img in pil_images: + with torch.no_grad(): + clip_ims = self.pipe.feature_extractor(images=img, return_tensors="pt").pixel_values.to(device) + feat = self.pipe.image_encoder(clip_ims).image_embeds + gen_feats.append(feat) + + gen_feats = torch.stack(gen_feats, dim=0) + return gen_feats + + def _forward_gen(self, gen_preds, layer_index, pil_images, gen_targets): + gen_loss = self._emb_loss(gen_preds, gen_targets) + + if dist.get_rank() == 0: + if self.steps % 500 == 0: + try: + self.log_gen(gen_preds.detach(), pil_images, layer_index, is_train=True) + except: + pass + + return gen_loss + + + def _get_dav2_feats(self, pil_images, device): + dav2_gts = [] + depth_targets = [[]] + for img in pil_images: + img = img.resize((336, 336)) + img = np.array(img) + with torch.no_grad(): + feat = self.dav2_backbone.infer_image(img, is_dsg=True) + depth_gts = self.da_v2_head([feat[-1]] * 4) + depth_targets[0].append(feat[-1][0]) + min_val = depth_gts.amin(dim=(1, 2), keepdim=True) + max_val = depth_gts.amax(dim=(1, 2), keepdim=True) + depth_gts = (depth_gts - min_val) / (max_val - min_val) + dav2_gts.append(depth_gts.to(device)) + dav2_gts = torch.stack(dav2_gts, dim=0).squeeze(1) + for i in range(len(depth_targets)): + depth_targets[i] = (torch.stack(depth_targets[i], dim=0).squeeze(1), None) + return depth_targets, dav2_gts + + def _forward_depth(self, all_depth_feats, layer_index, all_depth_targets, depth_pred_maps, depth_gts): + + depth_feats, depth_targets = all_depth_feats[0][0], all_depth_targets[0][0] + depth_loss = self._emb_loss(depth_feats, depth_targets) + + if dist.get_rank() == 0: + if self.steps % 200 == 0: + try: + self.log_depth(depth_pred_maps.detach(), layer_index, depth_gts, is_train=True) + except: + pass + + return depth_loss + + + def _get_seg_targets(self, pil_images, seg_preds): + def _get_feats(img): + img = img.resize((768, 768)) + inputs = self.oneformer_processor(img, ["panoptic"], return_tensors="pt") + inputs["pixel_values"] = inputs["pixel_values"].to(seg_preds.device, seg_preds.dtype) + with torch.no_grad(): + feats = self.oneformer.forward_features(**inputs) + return feats + + seg_targets = [] + for img in pil_images: + feat = _get_feats(img) + seg_targets.append(feat) + + seg_targets = torch.stack(seg_targets, dim=0).squeeze(1) + return seg_targets + + def _forward_seg(self, seg_preds, layer_index, pil_images, seg_targets): + + seg_loss = self._emb_loss(seg_preds, seg_targets) + + if dist.get_rank() == 0: + if self.steps % 200 == 0: + try: + self.log_seg(seg_preds.detach(), pil_images, layer_index, seg_targets, is_train=True) + except: + pass + + return seg_loss + + + def forward_emb_predictor(self, layer_states, idx, i, heads): + inp_tokens = layer_states[idx] + task_emb = heads[i](inp_tokens) + return task_emb + + def depth_emb_forward(self, pil_images, layer_states): + depth_preds = [] + depth_embs = [] + depth_loss = 0 + log_dict = {} + if self.mode == "depth": + if pil_images is not None: + depth_targets, 
depth_gts = self._get_dav2_feats(pil_images, layer_states[0].device)
+            else:
+                depth_targets, depth_gts = None, None
+
+            # Probe every decoder layer with its own depth head.
+            for i, idx in enumerate(range(self.num_layers)):
+
+                depth_feats = self.forward_emb_predictor(layer_states, idx, i, self.image_depth_heads)
+                depth_embs.append(depth_feats)
+
+                with torch.no_grad():
+                    depth_pred = self.da_v2_head([depth_feats[0]] * 4)
+                    min_val = depth_pred.amin(dim=(1, 2), keepdim=True)
+                    max_val = depth_pred.amax(dim=(1, 2), keepdim=True)
+                    depth_pred = (depth_pred - min_val) / (max_val - min_val)
+                    depth_preds.append(depth_pred)
+
+                if depth_targets is not None:
+                    layer_depth_loss = self._forward_depth(depth_feats, idx+1, depth_targets, depth_pred, depth_gts)
+                    depth_loss += layer_depth_loss
+                    if dist.get_rank() == 0:
+                        log_dict = {
+                            **log_dict,
+                            f"{idx}_depth_loss": layer_depth_loss.item(),
+                        }
+
+        return depth_preds, depth_embs, depth_loss, log_dict
+
+    def seg_emb_forward(self, pil_images, hidden_states, layer_states):
+        seg_embs = []
+        seg_loss = 0
+        log_dict = {}
+        if "seg" in self.mode:
+            if pil_images is not None:
+                seg_targets = self._get_seg_targets(pil_images, hidden_states)
+            else:
+                seg_targets = None
+            for i, idx in enumerate(range(self.num_layers)):
+
+                seg_emb = self.forward_emb_predictor(layer_states, idx, i, self.image_seg_heads)
+                seg_embs.append(seg_emb)
+
+                if seg_targets is not None:
+                    layer_seg_loss = self._forward_seg(seg_emb, idx+1, pil_images, seg_targets)
+                    seg_loss += layer_seg_loss
+                    if dist.get_rank() == 0:
+                        log_dict = {
+                            **log_dict,
+                            f"{idx}_seg_loss": layer_seg_loss.item(),
+                        }
+
+        return seg_embs, seg_loss, log_dict
+
+    def gen_emb_forward(self, pil_images, hidden_states, layer_states):
+        img_embs = []
+        gen_loss = 0
+        log_dict = {}
+        if "gen" in self.mode:
+            if pil_images is not None:
+                gen_targets = self._get_gen_feats(pil_images, hidden_states.device)
+            else:
+                gen_targets = None
+
+            for i, idx in enumerate(range(self.num_layers)):
+
+                img_emb = self.forward_emb_predictor(layer_states, idx, i, self.image_gen_heads)
+                img_embs.append(img_emb)
+
+                if gen_targets is not None:
+                    layer_gen_loss = self._forward_gen(img_emb, idx+1, pil_images, gen_targets)
+                    gen_loss += layer_gen_loss
+                    if dist.get_rank() == 0:
+                        log_dict = {
+                            **log_dict,
+                            f"{idx}_gen_loss": layer_gen_loss.item(),
+                        }
+
+        return img_embs, gen_loss, log_dict
+
+    @torch.no_grad()
+    def get_visual_interpretations(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        images: Optional[torch.Tensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        **kwargs
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+        position_ids = kwargs.pop("position_ids", None)
+        attention_mask = kwargs.pop("attention_mask", None)
+        (
+            inputs,
+            position_ids,
+            attention_mask,
+            _,
+            inputs_embeds,
+            _
+        ) = self.prepare_inputs_labels_for_multimodal(
+            inputs,
+            position_ids,
+            attention_mask,
+            None,
+            None,
+            images,
+            image_sizes=image_sizes
+        )
+
+        return self.forward(
+            input_ids=inputs,
+            inputs_embeds=inputs_embeds,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            return_dict=True,
+            output_attentions=output_attentions,
+            output_hidden_states=True,
+        )
+
+    @torch.no_grad()
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        images: Optional[torch.Tensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[GenerateOutput, torch.LongTensor]:
+        position_ids = kwargs.pop("position_ids", None)
+        attention_mask = kwargs.pop("attention_mask", None)
+        if
"inputs_embeds" in kwargs: + raise NotImplementedError("`inputs_embeds` is not supported") + + if images is not None: + ( + inputs, + position_ids, + attention_mask, + _, + inputs_embeds, + _ + ) = self.prepare_inputs_labels_for_multimodal( + inputs, + position_ids, + attention_mask, + None, + None, + images, + image_sizes=image_sizes + ) + else: + inputs_embeds = self.get_model().embed_tokens(inputs) + + return super().generate( + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs + ) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, + inputs_embeds=None, **kwargs): + images = kwargs.pop("images", None) + image_sizes = kwargs.pop("image_sizes", None) + pil_images = kwargs.pop("pil_images", None) + + inputs = super().prepare_inputs_for_generation( + input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs + ) + if images is not None: + inputs['images'] = images + if image_sizes is not None: + inputs['image_sizes'] = image_sizes + if pil_images is not None: + inputs['pil_images'] = pil_images + return inputs \ No newline at end of file diff --git a/ola_vlm/model/language_model/llava_llama.py b/ola_vlm/model/language_model/llava_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..dc52083c8a67c194df558e81769c81c3aedf6773 --- /dev/null +++ b/ola_vlm/model/language_model/llava_llama.py @@ -0,0 +1,175 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from transformers import AutoConfig, AutoModelForCausalLM, \ + LlamaConfig, LlamaModel, LlamaForCausalLM + +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.generation.utils import GenerateOutput +from transformers.utils import logging + +import warnings +logger = logging.get_logger(__name__) + +from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM +from .base_lm import BaseCausalLM +import torch.distributed as dist +try: + import wandb +except: + pass +import os + +class LlavaConfig(LlamaConfig): + model_type = "llava_llama" + + +class LlavaLlamaModel(LlavaMetaModel, LlamaModel): + config_class = LlavaConfig + + def __init__(self, config: LlamaConfig): + super(LlavaLlamaModel, self).__init__(config) + + +class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM, BaseCausalLM): + config_class = LlavaConfig + + def __init__(self, config): + super(LlamaForCausalLM, self).__init__(config) + self.model = LlavaLlamaModel(config) + self.pretraining_tp = config.pretraining_tp + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + try: + if dist.get_rank() == 0: + wandb.init(project=os.environ['WANDB_PROJECT'], name=f"{os.environ['WANDB_NAME']}") + except: + pass + + # Initialize weights and apply final processing + self.post_init() + + def get_model(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + image_sizes: Optional[List[List[int]]] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + + if inputs_embeds is None: + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels + ) = self.prepare_inputs_labels_for_multimodal( + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + images, + image_sizes + ) + + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + images: Optional[torch.Tensor] = None, + image_sizes: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + position_ids = kwargs.pop("position_ids", None) + attention_mask = kwargs.pop("attention_mask", None) + if "inputs_embeds" in kwargs: + raise NotImplementedError("`inputs_embeds` is not supported") + + if images is not None: + ( + inputs, + position_ids, + attention_mask, + _, + inputs_embeds, + _ + ) = self.prepare_inputs_labels_for_multimodal( + inputs, + position_ids, + attention_mask, + None, + None, + images, + image_sizes=image_sizes + ) + else: + inputs_embeds = self.get_model().embed_tokens(inputs) + + return super().generate( + 
position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs + ) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, + inputs_embeds=None, **kwargs): + images = kwargs.pop("images", None) + image_sizes = kwargs.pop("image_sizes", None) + inputs = super().prepare_inputs_for_generation( + input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs + ) + if images is not None: + inputs['images'] = images + if image_sizes is not None: + inputs['image_sizes'] = image_sizes + return inputs + +AutoConfig.register("llava_llama", LlavaConfig) +AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) diff --git a/ola_vlm/model/language_model/llava_phi3.py b/ola_vlm/model/language_model/llava_phi3.py new file mode 100644 index 0000000000000000000000000000000000000000..c435c7db107cb2cbd46a40e4b65d588fb4da923a --- /dev/null +++ b/ola_vlm/model/language_model/llava_phi3.py @@ -0,0 +1,175 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from transformers import AutoConfig, AutoModelForCausalLM, \ + Phi3Config, Phi3Model, Phi3ForCausalLM + +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.generation.utils import GenerateOutput +from transformers.utils import logging + +import torch.distributed as dist +try: + import wandb +except: + pass +import os + +logger = logging.get_logger(__name__) + +from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM +from .base_lm import BaseCausalLM + + +class LlavaPhi3Config(Phi3Config): + model_type = "llava_phi3" + + +class LlavaPhi3Model(LlavaMetaModel, Phi3Model): + config_class = LlavaPhi3Config + + def __init__(self, config: Phi3Config): + super(LlavaPhi3Model, self).__init__(config) + + +class LlavaPhi3ForCausalLM(Phi3ForCausalLM, LlavaMetaForCausalLM, BaseCausalLM): + config_class = LlavaPhi3Config + + def __init__(self, config): + super(Phi3ForCausalLM, self).__init__(config) + self.model = LlavaPhi3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + try: + if dist.get_rank() == 0: + wandb.init(project=os.environ['WANDB_PROJECT'], name=f"{os.environ['WANDB_NAME']}") + except: + pass + + # Initialize weights and apply final processing + self.post_init() + + def get_model(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + image_sizes: Optional[List[List[int]]] = 
None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + + if inputs_embeds is None: + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels + ) = self.prepare_inputs_labels_for_multimodal( + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + images, + image_sizes + ) + + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + images: Optional[torch.Tensor] = None, + image_sizes: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + position_ids = kwargs.pop("position_ids", None) + attention_mask = kwargs.pop("attention_mask", None) + if "inputs_embeds" in kwargs: + raise NotImplementedError("`inputs_embeds` is not supported") + + if images is not None: + ( + inputs, + position_ids, + attention_mask, + _, + inputs_embeds, + _ + ) = self.prepare_inputs_labels_for_multimodal( + inputs, + position_ids, + attention_mask, + None, + None, + images, + image_sizes=image_sizes + ) + else: + inputs_embeds = self.get_model().embed_tokens(inputs) + + return super().generate( + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs + ) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, + inputs_embeds=None, **kwargs): + images = kwargs.pop("images", None) + image_sizes = kwargs.pop("image_sizes", None) + inputs = super().prepare_inputs_for_generation( + input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs + ) + if images is not None: + inputs['images'] = images + if image_sizes is not None: + inputs['image_sizes'] = image_sizes + return inputs + +AutoConfig.register("llava_phi3", LlavaPhi3Config) +AutoModelForCausalLM.register(LlavaPhi3Config, LlavaPhi3ForCausalLM) diff --git a/ola_vlm/model/language_model/ola_llama.py b/ola_vlm/model/language_model/ola_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..60902c1ef4a572c9139eed44fb0a37681d0e9efb --- /dev/null +++ b/ola_vlm/model/language_model/ola_llama.py @@ -0,0 +1,248 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
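+
+"""
+OLA-VLM variant of the Llama-based LLaVA model. In addition to the standard
+next-token cross-entropy, `_forward` below adds auxiliary embedding losses
+computed by predictor heads attached to selected intermediate decoder layers,
+so the total objective is
+
+    loss = text_loss + seg_loss + depth_loss + gen_loss
+
+As the logged keys (`*_l1_loss`, `*_contrastive_loss`, `gen_mse_loss`) suggest,
+each auxiliary term mixes an element-wise regression with a contrastive
+component, with targets taken from frozen teachers: OneFormer features for
+segmentation, Depth-Anything-v2 outputs for depth, and the CLIP image encoder
+of the Stable-unCLIP pipeline for generation. This docstring is a summary
+only; the exact definitions live in `BaseOLA_VLM` and the modules under
+`ola_vlm/model/aux_heads`.
+"""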
+ + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from transformers import AutoConfig, AutoModelForCausalLM, \ + LlamaConfig, LlamaForCausalLM, LlamaModel + +from transformers.modeling_outputs import CausalLMOutputWithPast +from dataclasses import dataclass + +from ..ola_arch import OlaLlavaMetaModel, OlaLlavaMetaForCausalLM +import torch.distributed as dist +try: + import wandb +except: + pass +from torch.nn import CrossEntropyLoss +from .base_lm import BaseCausalLM +from .base_ola_vlm import BaseOLA_VLM + + + +@dataclass +class OlaCausalLLMOutputWithPast(CausalLMOutputWithPast): + image_embs: Optional[Tuple[torch.FloatTensor]] = None + seg_embs: Optional[Tuple[torch.FloatTensor]] = None + depth_embs: Optional[Tuple[torch.FloatTensor]] = None + depth_preds: Optional[Tuple[torch.FloatTensor]] = None + + +class OlaLlavaLlamaConfig(LlamaConfig): + model_type = "ola_llama" + + +class OlaLlavaLlamaModel(OlaLlavaMetaModel, LlamaModel): + config_class = OlaLlavaLlamaConfig + + def __init__(self, config: LlamaConfig): + super(OlaLlavaLlamaModel, self).__init__(config) + + +class OlaLlavaLlamaForCausalLM(LlamaForCausalLM, OlaLlavaMetaForCausalLM, BaseOLA_VLM): + config_class = OlaLlavaLlamaConfig + + def __init__(self, config): + super(LlamaForCausalLM, self).__init__(config) + self.model = OlaLlavaLlamaModel(config) + self.vocab_size = config.vocab_size + if self.vocab_size < 128000: + self.NUM_SYS_TOKENS = 26 # vicuna-7b + else: + self.NUM_SYS_TOKENS = 38 # llama3-8b + print(f"Number of System Tokens: {self.NUM_SYS_TOKENS}") + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.config = config + + # Initialize weights and apply final processing + self.post_init() + + def get_model(self): + return self.model + + def _forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pil_images = None, + gen_mask: Optional[torch.FloatTensor] = None, + seg_mask: Optional[torch.FloatTensor] = None, + depth_mask: Optional[torch.FloatTensor] = None, + + ) -> Union[Tuple, OlaCausalLLMOutputWithPast]: + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + layer_states = outputs[-1][1:] + + logits = self.lm_head(hidden_states) + logits = logits.float() + + text_loss = None + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # 
Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + text_loss = loss_fct(shift_logits, shift_labels) + + + depth_preds, depth_embs, depth_loss, depth_l1_loss, depth_cont_loss = self.depth_emb_forward(pil_images, layer_states, depth_mask) + seg_embs, seg_loss, seg_l1_loss, seg_contrastive_loss = self.seg_emb_forward(pil_images, hidden_states, layer_states, seg_mask) + img_embs, gen_loss, gen_mse_loss, gen_con_loss = self.gen_emb_forward(pil_images, hidden_states, layer_states, gen_mask) + + if text_loss is not None: + loss = text_loss + seg_loss + depth_loss + gen_loss + + try: + if dist.get_rank() == 0: + if loss > text_loss: + log_dict = { + "depth_loss": depth_loss, + "gen_loss": gen_loss, + "depth_l1_loss": depth_l1_loss, + "depth_contrastive_loss": depth_cont_loss, + "dinov2_loss": dinov2_loss, + "gen_mse_loss": gen_mse_loss, + "gen_contrastive_loss": gen_con_loss, + "seg_loss": seg_loss, + "seg_l1_loss": seg_l1_loss, + "seg_contrastive_loss": seg_contrastive_loss, + "text_loss": text_loss, + "loss": loss, + } + filtered_log_dict = {key: value for key, value in log_dict.items() if value > 0} + wandb.log(filtered_log_dict) + else: + wandb.log({ + "text_loss": text_loss, + "loss": loss, + }) + + self.steps += 1 + except: + pass + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return OlaCausalLLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_embs=img_embs, + seg_embs=seg_embs, + depth_embs=depth_embs, + depth_preds=depth_preds, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + image_sizes: Optional[List[List[int]]] = None, + return_dict: Optional[bool] = None, + pil_images: Optional[List[object]] = None, + gen_mask: Optional[torch.FloatTensor] = None, + seg_mask: Optional[torch.FloatTensor] = None, + depth_mask: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + + if inputs_embeds is None: + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels, + ) = self.prepare_inputs_labels_for_multimodal( + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + images, + image_sizes + ) + + return self._forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + pil_images=pil_images, + gen_mask=gen_mask, + seg_mask=seg_mask, + depth_mask=depth_mask, + ) + +AutoConfig.register("ola_llama", OlaLlavaLlamaConfig) +AutoModelForCausalLM.register(OlaLlavaLlamaConfig, OlaLlavaLlamaForCausalLM) \ No newline at end of file diff --git 
a/ola_vlm/model/language_model/ola_phi3.py b/ola_vlm/model/language_model/ola_phi3.py new file mode 100644 index 0000000000000000000000000000000000000000..493b9e83b1ec36600eb73d32a6f4a29e8f72d130 --- /dev/null +++ b/ola_vlm/model/language_model/ola_phi3.py @@ -0,0 +1,248 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from transformers import AutoConfig, AutoModelForCausalLM, \ + Phi3Config, Phi3Model, Phi3ForCausalLM + +from transformers.modeling_outputs import CausalLMOutputWithPast +from dataclasses import dataclass + +from ..ola_arch import OlaLlavaMetaModel, OlaLlavaMetaForCausalLM + +import torch.distributed as dist +try: + import wandb +except: + pass +import os +from torch.nn import CrossEntropyLoss +from .base_lm import BaseCausalLM +from .base_ola_vlm import BaseOLA_VLM + +@dataclass +class OlaCausalLLMOutputWithPast(CausalLMOutputWithPast): + image_embs: Optional[Tuple[torch.FloatTensor]] = None + seg_embs: Optional[Tuple[torch.FloatTensor]] = None + depth_embs: Optional[Tuple[torch.FloatTensor]] = None + depth_preds: Optional[Tuple[torch.FloatTensor]] = None + + +class OlaLlavaPhi3Config(Phi3Config): + model_type = "ola_phi3" + + +class OlaLlavaPhi3Model(OlaLlavaMetaModel, Phi3Model): + config_class = OlaLlavaPhi3Config + + def __init__(self, config: Phi3Config): + super(OlaLlavaPhi3Model, self).__init__(config) + + +class OlaLlavaPhi3ForCausalLM(Phi3ForCausalLM, OlaLlavaMetaForCausalLM, BaseOLA_VLM): + config_class = OlaLlavaPhi3Config + + def __init__(self, config): + super(Phi3ForCausalLM, self).__init__(config) + self.model = OlaLlavaPhi3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.config = config + self.is_use_reference_model = False + self.NUM_SYS_TOKENS = 13 + + # Initialize weights and apply final processing + self.post_init() + + def get_model(self): + return self.model + + def _forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pil_images = None, + gen_mask: Optional[torch.FloatTensor] = None, + seg_mask: Optional[torch.FloatTensor] = None, + depth_mask: Optional[torch.FloatTensor] = None, + + ) -> Union[Tuple, OlaCausalLLMOutputWithPast]: + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else 
self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + layer_states = outputs[-1][1:] + + logits = self.lm_head(hidden_states) + logits = logits.float() + + text_loss = None + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + text_loss = loss_fct(shift_logits, shift_labels) + + + dinov2_emb, dinov2_loss = self.dinov2_emb_forward(pil_images, layer_states, hidden_states, seg_mask) + depth_preds, depth_embs, depth_loss, depth_l1_loss, depth_cont_loss = self.depth_emb_forward(pil_images, layer_states, depth_mask) + seg_embs, seg_loss, seg_l1_loss, seg_contrastive_loss = self.seg_emb_forward(pil_images, hidden_states, layer_states, seg_mask) + img_embs, gen_loss, gen_mse_loss, gen_con_loss = self.gen_emb_forward(pil_images, hidden_states, layer_states, gen_mask) + + + if text_loss is not None: + loss = text_loss + seg_loss + depth_loss + gen_loss + dinov2_loss + + try: + if dist.get_rank() == 0: + if loss > text_loss: + log_dict = { + "depth_loss": depth_loss, + "gen_loss": gen_loss, + "depth_l1_loss": depth_l1_loss, + "depth_contrastive_loss": depth_cont_loss, + "dinov2_loss": dinov2_loss, + "gen_mse_loss": gen_mse_loss, + "gen_contrastive_loss": gen_con_loss, + "seg_loss": seg_loss, + "seg_l1_loss": seg_l1_loss, + "seg_cosine-emb_loss": seg_ce_loss, + "seg_contrastive_loss": seg_contrastive_loss, + "text_loss": text_loss, + "loss": loss, + } + filtered_log_dict = {key: value for key, value in log_dict.items() if value > 0} + wandb.log(filtered_log_dict) + else: + wandb.log({ + "text_loss": text_loss, + "loss": loss, + }) + + self.steps += 1 + except: + pass + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return OlaCausalLLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_embs=img_embs, + seg_embs=seg_embs, + depth_embs=depth_embs, + depth_preds=depth_preds, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + image_sizes: Optional[List[List[int]]] = None, + return_dict: Optional[bool] = None, + pil_images: Optional[List[object]] = None, + gen_mask: Optional[torch.FloatTensor] = None, + seg_mask: Optional[torch.FloatTensor] = None, + depth_mask: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> Union[Tuple, 
CausalLMOutputWithPast]: + + if inputs_embeds is None: + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels, + ) = self.prepare_inputs_labels_for_multimodal( + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + images, + image_sizes + ) + + return self._forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + pil_images=pil_images, + gen_mask=gen_mask, + seg_mask=seg_mask, + depth_mask=depth_mask, + ) + +AutoConfig.register("ola_phi3", OlaLlavaPhi3Config) +AutoModelForCausalLM.register(OlaLlavaPhi3Config, OlaLlavaPhi3ForCausalLM) \ No newline at end of file diff --git a/ola_vlm/model/language_model/probe_llava_llama.py b/ola_vlm/model/language_model/probe_llava_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..9ff0450ca28e29ab49a1d543d5a3ea04acffe561 --- /dev/null +++ b/ola_vlm/model/language_model/probe_llava_llama.py @@ -0,0 +1,236 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
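+
+"""
+Probing variant of the Llama-based LLaVA model. In `_forward` below the
+language model itself runs under `torch.no_grad()`, so only the probe heads
+attached to the intermediate layer states are trained, and the loss carries no
+text term:
+
+    loss = seg_loss + depth_loss + gen_loss
+
+Each probe is trained with the smooth-L1 embedding loss of the shared probe
+base class, roughly
+
+    F.smooth_l1_loss(emb_preds.float(), emb_targets.float(), reduction="none").mean()
+
+against targets produced by the frozen depth (Depth-Anything-v2), segmentation
+(OneFormer), and generation (Stable-unCLIP CLIP image encoder) teachers.
+"""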
+ + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from transformers import AutoConfig, AutoModelForCausalLM +from transformers import LlamaConfig, LlamaModel, LlamaForCausalLM + + +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.generation.utils import GenerateOutput +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.generation.utils import GenerateOutput +from dataclasses import dataclass +from ..ola_arch import OlaLlavaMetaModel, OlaLlavaMetaForCausalLM + +from ola_vlm.model.aux_heads import GenHead, DepthHead, DAv2_Head +from ola_vlm.model.aux_heads.depth_anything_v2.dpt import DepthAnythingV2 +from ola_vlm.model.aux_heads.oneformer_head import OneFormerHead, OneFormerSegHead + +from transformers import OneFormerProcessor + +from diffusers import ( + DPMSolverMultistepScheduler, + StableUnCLIPImg2ImgPipeline, +) + +import torch.distributed as dist +try: + import wandb +except: + pass +import os +import matplotlib + +from ola_vlm.model.language_model.base_probe_vlm import BaseProbe_VLM + +@dataclass +class ProbeDSGCausalLLMOutputWithPast(CausalLMOutputWithPast): + image_embs: Optional[Tuple[torch.FloatTensor]] = None + seg_embs: Optional[Tuple[torch.FloatTensor]] = None + depth_embs: Optional[Tuple[torch.FloatTensor]] = None + depth_preds: Optional[Tuple[torch.FloatTensor]] = None + + +class ProbeDSGLlavaLlamaConfig(LlamaConfig): + model_type = "probe_dsg_llava_llama" + + +class ProbeDSGLlavaLlamaModel(OlaLlavaMetaModel, LlamaModel): + config_class = ProbeDSGLlavaLlamaConfig + + def __init__(self, config: LlamaConfig): + super(ProbeDSGLlavaLlamaModel, self).__init__(config) + + +class ProbeDSGLlavaLlamaForCausalLM(LlamaForCausalLM, OlaLlavaMetaForCausalLM, BaseProbe_VLM): + config_class = ProbeDSGLlavaLlamaConfig + + def __init__(self, config): + super(LlamaForCausalLM, self).__init__(config) + config.rope_scaling = None + self.model = ProbeDSGLlavaLlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def _forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pil_images = None, + + ) -> Union[Tuple, ProbeDSGCausalLLMOutputWithPast]: + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + with torch.no_grad(): + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + layer_states = outputs[-1][1:] + + logits = self.lm_head(hidden_states) + logits = 
logits.float() + + log_dict = {} + + seg_embs, seg_loss, seg_log_dict = self.seg_emb_forward(pil_images, hidden_states, layer_states) + if self.mode == "seg": + if pil_images is not None: + if dist.get_rank() == 0: + log_dict = { + **log_dict, + **seg_log_dict + } + + depth_preds, depth_embs, depth_loss, depth_log_dict = self.depth_emb_forward(pil_images, layer_states) + if self.mode == "depth" and hidden_states.shape[1] > 1: + if dist.get_rank() == 0: + log_dict = { + **log_dict, + **depth_log_dict + } + + img_embs, gen_loss, log_dict = self.gen_emb_forward(pil_images, hidden_states, layer_states) + if self.mode == "gen" and hidden_states.shape[1] > 1: + if dist.get_rank() == 0: + log_dict = { + **log_dict, + **depth_log_dict + } + + loss = seg_loss + depth_loss + gen_loss + + try: + if dist.get_rank() == 0: + log_dict = { + **log_dict, + "depth_loss": depth_loss, + "gen_loss": gen_loss, + "seg_loss": seg_loss, + } + filtered_log_dict = {key: value for key, value in log_dict.items() if value > 0} + wandb.log(filtered_log_dict) + self.steps += 1 + except: + pass + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return ProbeDSGCausalLLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_embs=img_embs, + seg_embs=seg_embs, + depth_embs=depth_embs, + depth_preds=depth_preds, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + image_sizes: Optional[List[List[int]]] = None, + return_dict: Optional[bool] = None, + pil_images: Optional[List[object]] = None, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + + if inputs_embeds is None: + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels, + _ + ) = self.prepare_inputs_labels_for_multimodal( + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + images, + image_sizes + ) + + return self._forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + pil_images=pil_images + ) + +AutoConfig.register("probe_dsg_llava_llama", ProbeDSGLlavaLlamaConfig) +AutoModelForCausalLM.register(ProbeDSGLlavaLlamaConfig, ProbeDSGLlavaLlamaForCausalLM) \ No newline at end of file diff --git a/ola_vlm/model/llava_arch.py b/ola_vlm/model/llava_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..9083efc926138e758a32ecd5ddfa53e30aae61e1 --- /dev/null +++ b/ola_vlm/model/llava_arch.py @@ -0,0 +1,530 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from abc import ABC, abstractmethod + +import torch +import torch.nn as nn + +from .multimodal_encoder.builder import build_vision_tower +from .multimodal_projector.builder import build_vision_projector + +from ola_vlm.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN + +from ola_vlm.mm_utils import get_anyres_image_grid_shape +import numpy as np + +def build_mlp(in_hidden_size, hidden_size): + modules = [nn.Linear(in_hidden_size, hidden_size)] + modules.append(nn.GELU()) + modules.append(nn.Linear(hidden_size, hidden_size)) + return nn.Sequential(*modules) + +class LlavaMetaModel: + + def __init__(self, config): + super(LlavaMetaModel, self).__init__(config) + self.num_task_tokens = 0 + + if hasattr(config, "mm_vision_tower"): + self.vision_tower = build_vision_tower(config, delay_load=False) + self.mm_projector = build_vision_projector(config) + + if 'unpad' in getattr(config, 'mm_patch_merge_type', ''): + self.image_newline = nn.Parameter( + torch.empty(config.hidden_size, dtype=self.dtype) + ) + + if hasattr(config, 'num_task_tokens'): + self.initialize_special_tokens(config) + + def get_vision_tower(self): + vision_tower = getattr(self, 'vision_tower', None) + if type(vision_tower) is list: + vision_tower = vision_tower[0] + return vision_tower + + def get_task_mlp(self): + mlp_task = getattr(self, 'mlp_task', None) + return mlp_task + + def get_special_tokens(self): + depth_tokens = getattr(self, 'special_depth_tokens', None) + seg_tokens = getattr(self, 'special_seg_tokens', None) + gen_tokens = getattr(self, 'special_gen_tokens', None) + return depth_tokens, seg_tokens, gen_tokens + + def initialize_special_tokens(self, config): + self.num_task_tokens = config.num_task_tokens + task_token_format = getattr(config, "task_token_format", "emb") + self.task_token_format = task_token_format + self.is_sample_tokens = getattr(config, "sample_tokens", False) + self.aux_tokens = config.aux_mode + self.token_order = config.aux_mode.split("-") + if self.num_task_tokens > 0: + if "depth" in config.aux_mode: + self.special_depth_tokens = nn.Parameter( + torch.randn( + config.image_depth["num_tokens"], config.hidden_size + ) + ) + if "seg" in config.aux_mode: + self.special_seg_tokens = nn.Parameter( + torch.randn( + config.image_seg["num_tokens"], config.hidden_size + ) + ) + if "gen" in config.aux_mode: + self.special_gen_tokens = nn.Parameter( + torch.randn( + config.num_task_tokens, config.hidden_size + ) + ) + + def initialize_vision_modules(self, model_args, fsdp=None): + vision_tower = model_args.vision_tower + mm_vision_select_layer = model_args.mm_vision_select_layer + mm_vision_select_feature = model_args.mm_vision_select_feature + pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter + mm_patch_merge_type = model_args.mm_patch_merge_type + + self.config.mm_vision_tower = vision_tower + + if self.get_vision_tower() is None: + vision_tower = build_vision_tower(model_args) + + if fsdp is not None and len(fsdp) > 0: + self.vision_tower = [vision_tower] + else: + self.vision_tower = vision_tower + 
else: + if fsdp is not None and len(fsdp) > 0: + vision_tower = self.vision_tower[0] + else: + vision_tower = self.vision_tower + vision_tower.load_model() + + self.config.use_mm_proj = True + self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear') + self.config.mm_hidden_size = vision_tower.hidden_size + self.config.mm_vision_select_layer = mm_vision_select_layer + self.config.mm_vision_select_feature = mm_vision_select_feature + self.config.mm_patch_merge_type = mm_patch_merge_type + + if getattr(self, 'mm_projector', None) is None: + self.mm_projector = build_vision_projector(self.config) + + if 'unpad' in mm_patch_merge_type: + embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype)) + self.image_newline = nn.Parameter( + torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std + ) + else: + # In case it is frozen by LoRA + for p in self.mm_projector.parameters(): + p.requires_grad = True + + if pretrain_mm_mlp_adapter is not None: + mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') + def get_w(weights, keyword): + return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} + + self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) + + +def unpad_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format. + original_size (tuple): The original size of PIL image (width, height). + + Returns: + torch.Tensor: The unpadded image tensor. + """ + original_width, original_height = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding:current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding:current_width - padding] + + return unpadded_tensor + +def unpad_prep_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format. + original_size (tuple): The original size of PIL image (width, height). + + Returns: + torch.Tensor: The unpadded image tensor. 
+ """ + original_width, original_height = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + mode = "height" + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding:current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding:current_width - padding] + mode = "width" + + return unpadded_tensor, mode, padding + + +class LlavaMetaForCausalLM(ABC): + + @abstractmethod + def get_model(self): + pass + + def get_vision_tower(self): + return self.get_model().get_vision_tower() + + @property + def depth_tokens(self): + return self.get_model().get_special_tokens()[0] + + @property + def seg_tokens(self): + return self.get_model().get_special_tokens()[1] + + @property + def gen_tokens(self): + return self.get_model().get_special_tokens()[2] + + @property + def num_task_tokens(self): + return self.get_model().num_task_tokens + + @property + def task_token_format(self): + return self.get_model().task_token_format + + @property + def aux_tokens(self): + return self.get_model().aux_tokens + + @property + def token_order(self): + return self.get_model().token_order + + @property + def is_sample_tokens(self): + return self.get_model().is_sample_tokens + + def append_special_tokens(self, cur_new_input_embeds, cur_image_features, cur_labels, cur_new_labels): + + if self.num_task_tokens == 0: + return cur_new_input_embeds, cur_new_labels + + def _get_tokens(self, tokens): + if self.task_token_format == "text": + tk_weights = self.get_model().embed_tokens(tokens.to(cur_image_features.device)) + elif self.task_token_format == "emb": + tk_weights = tokens + elif self.task_token_format == "expand_emb": + tk_weights = tokens.view(self.num_task_tokens, tokens.shape[0] // self.num_task_tokens, tokens.shape[1]) + tk_weights = tk_weights.mean(dim=1) + else: + raise ValueError(f"Unexpected task_token_format: {self.task_token_format}") + return tk_weights + + token_types = { + "depth": self.depth_tokens, + "seg": self.seg_tokens, + "gen": self.gen_tokens + } + + for token_type in self.token_order: + tk_weights = None + + if token_type in self.aux_tokens: + if token_type == "depth" and token_types["depth"] is not None: + tk_weights = _get_tokens(self, token_types["depth"]) + elif token_type == "seg" and token_types["seg"] is not None: + tk_weights = _get_tokens(self, token_types["seg"]) + elif token_type == "gen" and token_types["gen"] is not None: + tk_weights = ( + self.get_model().embed_tokens(self.gen_tokens.to(cur_image_features.device)) + if self.task_token_format == "text" else + self.gen_tokens + ) + + if tk_weights is not None: + cur_new_input_embeds.append(tk_weights) + + cur_new_labels.append(torch.full((tk_weights.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + return cur_new_input_embeds, cur_new_labels + + def encode_images(self, images): + image_features = self.get_model().get_vision_tower()(images).to(images.dtype).to(images.device) + image_features = self.get_model().mm_projector(image_features) + return image_features + + def prepare_inputs_labels_for_multimodal( + self, input_ids, position_ids, 
attention_mask, past_key_values, labels, + images, image_sizes=None + ): + vision_tower = self.get_vision_tower() + if vision_tower is None or images is None or input_ids.shape[1] == 1: + return input_ids, position_ids, attention_mask, past_key_values, None, labels + + if type(images) is list or images.ndim == 5: + if type(images) is list: + images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images] + concat_images = torch.cat([image for image in images], dim=0) + image_features = self.encode_images(concat_images) + split_sizes = [image.shape[0] for image in images] + image_features = torch.split(image_features, split_sizes, dim=0) + mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat') + image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square') + if mm_patch_merge_type == 'flat': + image_features = [x.flatten(0, 1) for x in image_features] + elif mm_patch_merge_type.startswith('spatial'): + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = self.get_vision_tower().num_patches_per_side + assert height * width == base_image_feature.shape[0] + if image_aspect_ratio == 'anyres': + num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size) + image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) + else: + raise NotImplementedError + if 'unpad' in mm_patch_merge_type: + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + image_feature = torch.cat(( + image_feature, + self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device) + ), dim=-1) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + else: + image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() + image_feature = image_feature.flatten(0, 3) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + else: + image_feature = image_feature[0] + if 'unpad' in mm_patch_merge_type: + image_feature = torch.cat(( + image_feature, + self.model.image_newline[None].to(image_feature.device) + ), dim=0) + new_image_features.append(image_feature) + image_features = new_image_features + else: + raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}") + else: + image_features = self.encode_images(images) + + # TODO: image start / end is not implemented here to support pretraining. + if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): + raise NotImplementedError + + # Let's just add dummy tensors if they do not exist, + # it is a headache to deal with None all the time. + # But it is not ideal, and if you have a better idea, + # please open an issue / submit a PR, thanks. 
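+        # Illustration (hypothetical sequence): for input ids such as
+        #   [bos, "Describe", IMAGE_TOKEN_INDEX, "briefly", eos]
+        # the loop below replaces the IMAGE_TOKEN_INDEX placeholder with all of the
+        # projected image-patch features and labels those positions IGNORE_INDEX,
+        # so image embeddings never contribute to the next-token loss.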
+ _labels = labels + _position_ids = position_ids + _attention_mask = attention_mask + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + else: + attention_mask = attention_mask.bool() + if position_ids is None: + position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + if labels is None: + labels = torch.full_like(input_ids, IGNORE_INDEX) + + # remove the padding using attention_mask -- FIXME + _input_ids = input_ids + input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] + labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] + + new_input_embeds = [] + new_labels = [] + cur_image_idx = 0 + for batch_idx, cur_input_ids in enumerate(input_ids): + num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + if num_images == 0: + cur_image_features = image_features[cur_image_idx] + cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) + cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0) + new_input_embeds.append(cur_input_embeds) + new_labels.append(labels[batch_idx]) + cur_image_idx += 1 + continue + + image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] + cur_input_ids_noim = [] + cur_labels = labels[batch_idx] + cur_labels_noim = [] + for i in range(len(image_token_indices) - 1): + cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]]) + cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]]) + split_sizes = [x.shape[0] for x in cur_labels_noim] + cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim)) + cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) + cur_new_input_embeds = [] + cur_new_labels = [] + + for i in range(num_images + 1): + cur_new_input_embeds.append(cur_input_embeds_no_im[i]) + cur_new_labels.append(cur_labels_noim[i]) + if i < num_images: + cur_image_features = image_features[cur_image_idx] + cur_image_idx += 1 + + cur_new_input_embeds.append(cur_image_features) + cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + + cur_new_input_embeds, cur_new_labels = self.append_special_tokens( + cur_new_input_embeds, cur_image_features, + cur_labels, cur_new_labels, + ) + + + cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] + + cur_new_input_embeds = torch.cat(cur_new_input_embeds) + cur_new_labels = torch.cat(cur_new_labels) + + new_input_embeds.append(cur_new_input_embeds) + new_labels.append(cur_new_labels) + + # Truncate sequences to max length as image embeddings can make the sequence longer + tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None) + if tokenizer_model_max_length is not None: + new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] + new_labels = [x[:tokenizer_model_max_length] for x in new_labels] + + # Combine them + max_len = max(x.shape[0] for x in new_input_embeds) + batch_size = len(new_input_embeds) + + new_input_embeds_padded = [] + new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device) + attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) + position_ids = 
torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) + + for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): + cur_len = cur_new_embed.shape[0] + if getattr(self.config, 'tokenizer_padding_side', 'right') == "left": + new_input_embeds_padded.append(torch.cat(( + torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), + cur_new_embed + ), dim=0)) + if cur_len > 0: + new_labels_padded[i, -cur_len:] = cur_new_labels + attention_mask[i, -cur_len:] = True + position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + else: + new_input_embeds_padded.append(torch.cat(( + cur_new_embed, + torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device) + ), dim=0)) + if cur_len > 0: + new_labels_padded[i, :cur_len] = cur_new_labels + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + + new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) + + if _labels is None: + new_labels = None + else: + new_labels = new_labels_padded + + if _attention_mask is None: + attention_mask = None + else: + attention_mask = attention_mask.to(dtype=_attention_mask.dtype) + + if _position_ids is None: + position_ids = None + + return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels + + def initialize_vision_tokenizer(self, model_args, tokenizer): + if model_args.mm_use_im_patch_token: + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if model_args.mm_use_im_start_end: + num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = self.get_input_embeddings().weight.data + output_embeddings = self.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + if model_args.tune_mm_mlp_adapter: + for p in self.get_input_embeddings().parameters(): + p.requires_grad = True + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False + + if model_args.pretrain_mm_mlp_adapter: + mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') + embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] + assert num_new_tokens == 2 + if input_embeddings.shape == embed_tokens_weight.shape: + input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] + elif embed_tokens_weight.shape[0] == num_new_tokens: + input_embeddings[-num_new_tokens:] = embed_tokens_weight + else: + raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. 
Numer of new tokens: {num_new_tokens}.") + elif model_args.mm_use_im_patch_token: + if model_args.tune_mm_mlp_adapter: + for p in self.get_input_embeddings().parameters(): + p.requires_grad = False + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False diff --git a/ola_vlm/model/llava_one_arch.py b/ola_vlm/model/llava_one_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..5f161db05e1b9ae34cef6e324c4c9a2573166513 --- /dev/null +++ b/ola_vlm/model/llava_one_arch.py @@ -0,0 +1,598 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from abc import ABC, abstractmethod + +import math +import re +import time +import torch +import torch.nn as nn +from ola_vlm.model.llava_one.multimodal_encoder.builder import build_vision_tower +from ola_vlm.model.llava_one.multimodal_resampler.builder import build_vision_resampler +from ola_vlm.model.llava_one.multimodal_projector.builder import build_vision_projector + +from ola_vlm.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN + +from ola_vlm.mm_utils import get_anyres_image_grid_shape +# from ola_vlm.utils import rank0_print +import random + + +class LlavaMetaModel: + + def __init__(self, config): + super(LlavaMetaModel, self).__init__(config) + + if hasattr(config, "mm_vision_tower"): + delay_load = getattr(config, "delay_load", False) + self.vision_tower = build_vision_tower(config, delay_load=delay_load) + self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower) + self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config) + + if "unpad" in getattr(config, "mm_patch_merge_type", ""): + self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype)) + + def get_vision_tower(self): + vision_tower = getattr(self, "vision_tower", None) + if type(vision_tower) is list: + vision_tower = vision_tower[0] + return vision_tower + + def initialize_vision_modules(self, model_args, fsdp=None): + vision_tower = model_args.vision_tower + mm_vision_select_layer = model_args.mm_vision_select_layer + mm_vision_select_feature = model_args.mm_vision_select_feature + pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter + mm_patch_merge_type = model_args.mm_patch_merge_type + + self.config.mm_vision_tower = vision_tower + self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "") + + if self.get_vision_tower() is None: + vision_tower = build_vision_tower(model_args) + vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower) + for k, v in vision_resampler.config.items(): + setattr(self.config, k, v) + + if fsdp is not None and len(fsdp) > 0: + self.vision_tower = [vision_tower] + self.vision_resampler = [vision_resampler] + else: + self.vision_tower = vision_tower + self.vision_resampler = vision_resampler + else: + if fsdp is not None and len(fsdp) > 0: + vision_resampler = 
self.vision_resampler[0] + vision_tower = self.vision_tower[0] + else: + vision_resampler = self.vision_resampler + vision_tower = self.vision_tower + vision_tower.load_model() + + # In case it is frozen by LoRA + for p in self.vision_resampler.parameters(): + p.requires_grad = True + + self.config.use_mm_proj = True + self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear") + self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size) + self.config.mm_vision_select_layer = mm_vision_select_layer + self.config.mm_vision_select_feature = mm_vision_select_feature + self.config.mm_patch_merge_type = mm_patch_merge_type + + + if not hasattr(self.config, 'add_faster_video'): + if model_args.add_faster_video: + embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype)) + self.faster_token = nn.Parameter( + torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std + ) + + if getattr(self, "mm_projector", None) is None: + self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config) + + if "unpad" in mm_patch_merge_type: + embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype)) + self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std) + else: + # In case it is frozen by LoRA + for p in self.mm_projector.parameters(): + p.requires_grad = True + + if pretrain_mm_mlp_adapter is not None: + mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu") + + def get_w(weights, keyword): + return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k} + + incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector")) + # rank0_print(f"Loaded mm projector weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}") + incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False) + # rank0_print(f"Loaded vision resampler weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}") + + +def unpad_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format. + original_size (tuple): The original size of the image (height, width). + + Returns: + torch.Tensor: The unpadded image tensor. 
+ """ + original_width, original_height = original_size + current_height, current_width = tensor.shape[1:] + + # Compute aspect ratios + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + # Determine padding size and direction + if original_aspect_ratio > current_aspect_ratio: + # Padding was added to the height + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding : current_height - padding, :] + else: + # Padding was added to the width + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding : current_width - padding] + + return unpadded_tensor + + +class LlavaMetaForCausalLM(ABC): + + @abstractmethod + def get_model(self): + pass + + def get_vision_tower(self): + return self.get_model().get_vision_tower() + + def get_2dPool(self, image_feature, stride=2): + height = width = self.get_vision_tower().num_patches_per_side + num_frames, num_tokens, num_dim = image_feature.shape + image_feature = image_feature.view(num_frames, height, width, -1) + image_feature = image_feature.permute(0, 3, 1, 2).contiguous() + # image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride) + if self.config.mm_spatial_pool_mode == "average": + image_feature = nn.functional.avg_pool2d(image_feature, stride) + elif self.config.mm_spatial_pool_mode == "max": + image_feature = nn.functional.max_pool2d(image_feature, stride) + elif self.config.mm_spatial_pool_mode == "bilinear": + height, width = image_feature.shape[2:] + scaled_shape = [math.ceil(height / stride), math.ceil(width / stride)] + image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode='bilinear') + + else: + raise ValueError(f"Unexpected mm_spatial_pool_mode: {self.config.mm_spatial_pool_mode}") + image_feature = image_feature.permute(0, 2, 3, 1) + image_feature = image_feature.view(num_frames, -1, num_dim) + return image_feature + + def encode_images(self, images): + image_features = self.get_model().get_vision_tower()(images) + # image_features = self.get_model().vision_resampler(image_features, images=images) + image_features = self.get_model().mm_projector(image_features) + return image_features + + def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None): + videos_or_images_features = self.get_model().get_vision_tower()(videos_or_images) + per_videos_or_images_features = torch.split(videos_or_images_features, split_sizes, dim=0) # tuple, (dim_1, 576, 4096) + all_videos_or_images_features = [] + all_faster_video_features = [] + cur_mm_spatial_pool_stride = self.config.mm_spatial_pool_stride + + for idx, feat in enumerate(per_videos_or_images_features): + + feat = self.get_model().mm_projector(feat) + faster_video_feature = 0 + slower_img_feat = 0 + if idx in video_idx_in_batch and cur_mm_spatial_pool_stride > 1: + slower_img_feat = self.get_2dPool(feat,cur_mm_spatial_pool_stride) + if self.config.add_faster_video: + cur_mm_spatial_pool_stride = cur_mm_spatial_pool_stride * 2 + faster_video_feature = self.get_2dPool(feat,cur_mm_spatial_pool_stride) + if slower_img_feat != 0: + all_videos_or_images_features.append(slower_img_feat) + else: + all_videos_or_images_features.append(feat) + all_faster_video_features.append(faster_video_feature) + return 
all_videos_or_images_features,all_faster_video_features + + def add_token_per_grid(self, image_feature): + resize_h = int(math.sqrt(image_feature.shape[1])) + num_frames = image_feature.shape[0] + feature_dim = image_feature.shape[-1] + + image_feature = image_feature.view(num_frames, 1, resize_h, resize_h, -1) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1) + if getattr(self.config, "add_faster_video", False): + # import pdb; pdb.set_trace() + # (3584, 832, 14) -> (3584, 64, 13, 14) + image_feature = image_feature.view(feature_dim, num_frames,resize_h, -1) + # (3584, 64, 13, 14) -> (64, 13, 14, 3584) + image_feature = image_feature.permute(1, 2, 3, 0).contiguous() + # (64, 13, 14, 3584) -> (64, 13*14, 3584) + image_feature = image_feature.flatten(1, 2) + # import pdb; pdb.set_trace() + return image_feature + # import pdb; pdb.set_trace() + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + return image_feature + + def add_token_per_frame(self, image_feature): + image_feature = image_feature.permute(2, 0, 1).contiguous() + image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1) + image_feature = image_feature.permute(1, 2, 0).contiguous() + return image_feature + + def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None): + vision_tower = self.get_vision_tower() + # rank_print(modalities) + if vision_tower is None or images is None or input_ids.shape[1] == 1: + return input_ids, position_ids, attention_mask, past_key_values, None, labels, None + + modalities = ["image"] * len(images) + if isinstance(modalities, str): + modalities = [modalities] + + # import pdb; pdb.set_trace() + if type(images) is list or images.ndim == 5: + if type(images) is list: + images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images] + + video_idx_in_batch = [] + for _ in range(len(modalities)): + if modalities[_] == "video": + video_idx_in_batch.append(_) + + images_list = [] + for image in images: + if image.ndim == 4: + images_list.append(image) + else: + images_list.append(image.unsqueeze(0)) + + concat_images = torch.cat([image for image in images_list], dim=0) + split_sizes = [image.shape[0] for image in images_list] + encoded_image_features = self.encode_images(concat_images) + # image_features,all_faster_video_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes) + + # This is a list, each element is [num_images, patch * patch, dim] + # rank_print(f"Concat images : {concat_images.shape}") + encoded_image_features = torch.split(encoded_image_features, split_sizes) + image_features = [] + for idx, image_feat in enumerate(encoded_image_features): + if idx in video_idx_in_batch: + image_features.append(self.get_2dPool(image_feat)) + else: + image_features.append(image_feat) + # image_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes) + # rank_print(f"Encoded image feats : {[x.shape for x in image_features]}") + # image_features = torch.split(image_features, split_sizes, dim=0) + mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat") + image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square") 
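# Editor's note: a minimal, self-contained sketch of the 2D pooling that get_2dPool
# applies to video-frame patch tokens above. The shapes (24x24 patches, hidden size 8)
# and the helper name are illustrative assumptions, not values read from the model config.
import math
import torch
import torch.nn as nn

def pool_frame_tokens(feats: torch.Tensor, side: int, stride: int = 2, mode: str = "average") -> torch.Tensor:
    """feats: (num_frames, side*side, dim) -> (num_frames, (side//stride)**2, dim)."""
    num_frames, num_tokens, dim = feats.shape
    assert num_tokens == side * side
    x = feats.view(num_frames, side, side, dim).permute(0, 3, 1, 2)  # to NCHW
    if mode == "average":
        x = nn.functional.avg_pool2d(x, stride)
    elif mode == "max":
        x = nn.functional.max_pool2d(x, stride)
    else:  # mirrors the "bilinear" branch above
        h, w = x.shape[2:]
        x = nn.functional.interpolate(x, size=(math.ceil(h / stride), math.ceil(w / stride)), mode="bilinear")
    return x.permute(0, 2, 3, 1).reshape(num_frames, -1, dim)

if __name__ == "__main__":
    frames = torch.randn(4, 24 * 24, 8)                  # 4 frames of 24x24 patch tokens
    print(pool_frame_tokens(frames, side=24).shape)      # -> torch.Size([4, 144, 8])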
+ mm_newline_position = getattr(self.config, "mm_newline_position", "one_token") + + if mm_patch_merge_type == "flat": + image_features = [x.flatten(0, 1) for x in image_features] + + elif mm_patch_merge_type.startswith("spatial"): + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + # FIXME: now assume the image is square, and split to 2x2 patches + # num_patches = h * w, where h = w = sqrt(num_patches) + # currently image_feature is a tensor of shape (4, num_patches, hidden_size) + # we want to first unflatten it to (2, 2, h, w, hidden_size) + # rank0_print("At least we are reaching here") + # import pdb; pdb.set_trace() + if image_idx in video_idx_in_batch: # video operations + # rank0_print("Video") + if mm_newline_position == "grid": + # Grid-wise + image_feature = self.add_token_per_grid(image_feature) + if getattr(self.config, "add_faster_video", False): + faster_video_feature = self.add_token_per_grid(all_faster_video_features[image_idx]) + # Add a token for each frame + concat_slow_fater_token = [] + # import pdb; pdb.set_trace() + for _ in range(image_feature.shape[0]): + if _ % self.config.faster_token_stride == 0: + concat_slow_fater_token.append(torch.cat((image_feature[_], self.model.faster_token[None].to(image_feature.device)), dim=0)) + else: + concat_slow_fater_token.append(torch.cat((faster_video_feature[_], self.model.faster_token[None].to(image_feature.device)), dim=0)) + # import pdb; pdb.set_trace() + image_feature = torch.cat(concat_slow_fater_token) + + # print("!!!!!!!!!!!!") + + new_image_features.append(image_feature) + elif mm_newline_position == "frame": + # Frame-wise + image_feature = self.add_token_per_frame(image_feature) + + new_image_features.append(image_feature.flatten(0, 1)) + + elif mm_newline_position == "one_token": + # one-token + image_feature = image_feature.flatten(0, 1) + if 'unpad' in mm_patch_merge_type: + image_feature = torch.cat(( + image_feature, + self.model.image_newline[None].to(image_feature.device) + ), dim=0) + new_image_features.append(image_feature) + elif mm_newline_position == "no_token": + new_image_features.append(image_feature.flatten(0, 1)) + else: + raise ValueError(f"Unexpected mm_newline_position: {mm_newline_position}") + elif image_feature.shape[0] > 1: # multi patches and multi images operations + # rank0_print("Single-images") + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = self.get_vision_tower().num_patches_per_side + assert height * width == base_image_feature.shape[0] + + if "anyres_max" in image_aspect_ratio: + matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", image_aspect_ratio) + if matched_anyres_max_num_patches: + max_num_patches = int(matched_anyres_max_num_patches.group(1)) + + if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio: + if hasattr(self.get_vision_tower(), "image_size"): + vision_tower_image_size = self.get_vision_tower().image_size + else: + raise ValueError("vision_tower_image_size is not found in the vision tower.") + try: + num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size) + except Exception as e: + print(f"Error: {e}") + num_patch_width, num_patch_height = 2, 2 + image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) + else: + image_feature = image_feature.view(2, 2, height, width, -1) + + if "maxpool2x2" in mm_patch_merge_type: + image_feature = 
image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = nn.functional.max_pool2d(image_feature, 2) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + elif "unpad" in mm_patch_merge_type and "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches: + unit = image_feature.shape[2] + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + c, h, w = image_feature.shape + times = math.sqrt(h * w / (max_num_patches * unit**2)) + if times > 1.1: + image_feature = image_feature[None] + image_feature = nn.functional.interpolate(image_feature, [int(h // times), int(w // times)], mode="bilinear")[0] + image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + elif "unpad" in mm_patch_merge_type: + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + else: + image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() + image_feature = image_feature.flatten(0, 3) + if "nobase" in mm_patch_merge_type: + pass + else: + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + new_image_features.append(image_feature) + else: # single image operations + image_feature = image_feature[0] + if "unpad" in mm_patch_merge_type: + image_feature = torch.cat((image_feature, self.model.image_newline[None]), dim=0) + + new_image_features.append(image_feature) + image_features = new_image_features + else: + raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}") + else: + image_features = self.encode_images(images) + + # TODO: image start / end is not implemented here to support pretraining. + if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False): + raise NotImplementedError + # rank_print(f"Total images : {len(image_features)}") + + # Let's just add dummy tensors if they do not exist, + # it is a headache to deal with None all the time. + # But it is not ideal, and if you have a better idea, + # please open an issue / submit a PR, thanks. 
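# Editor's note: a toy illustration of the token splicing performed just below in
# prepare_inputs_labels_for_multimodal: each IMAGE_TOKEN_INDEX placeholder in input_ids
# is cut out and the corresponding image-feature rows are spliced in, with IGNORE_INDEX
# labels over the image span. The tiny embedding table, ids, and feature sizes are made
# up for the example; only the splicing pattern is the point.
import torch

IMAGE_TOKEN_INDEX = -200
IGNORE_INDEX = -100

def splice_image_features(input_ids, labels, image_feats, embed):
    out_embeds, out_labels = [], []
    idx = [-1] + torch.where(input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [input_ids.shape[0]]
    for i in range(len(idx) - 1):
        chunk = input_ids[idx[i] + 1: idx[i + 1]]          # text segment between placeholders
        out_embeds.append(embed(chunk))
        out_labels.append(labels[idx[i] + 1: idx[i + 1]])
        if i < len(idx) - 2:                               # one image per placeholder
            out_embeds.append(image_feats[i])
            out_labels.append(torch.full((image_feats[i].shape[0],), IGNORE_INDEX, dtype=labels.dtype))
    return torch.cat(out_embeds), torch.cat(out_labels)

if __name__ == "__main__":
    embed = torch.nn.Embedding(10, 4)
    ids = torch.tensor([1, 2, IMAGE_TOKEN_INDEX, 3, 4])
    lbls = torch.tensor([1, 2, IMAGE_TOKEN_INDEX, 3, 4])
    feats = [torch.randn(6, 4)]                            # 6 "visual tokens" for the one image
    e, l = splice_image_features(ids, lbls, feats, embed)
    print(e.shape, l.tolist())                             # torch.Size([10, 4]), labels with a -100 span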
+ _labels = labels + _position_ids = position_ids + _attention_mask = attention_mask + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + else: + attention_mask = attention_mask.bool() + if position_ids is None: + position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + if labels is None: + labels = torch.full_like(input_ids, IGNORE_INDEX) + + # remove the padding using attention_mask -- FIXME + _input_ids = input_ids + input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] + labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] + + new_input_embeds = [] + new_labels = [] + cur_image_idx = 0 + # rank_print("Inserting Images embedding") + for batch_idx, cur_input_ids in enumerate(input_ids): + num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + # rank0_print(num_images) + if num_images == 0: + cur_image_features = image_features[cur_image_idx] + cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) + cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0) + new_input_embeds.append(cur_input_embeds) + new_labels.append(labels[batch_idx]) + cur_image_idx += 1 + continue + + image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] + cur_input_ids_noim = [] + cur_labels = labels[batch_idx] + cur_labels_noim = [] + for i in range(len(image_token_indices) - 1): + cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]) + cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]) + split_sizes = [x.shape[0] for x in cur_labels_noim] + cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim)) + cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) + cur_new_input_embeds = [] + cur_new_labels = [] + + for i in range(num_images + 1): + cur_new_input_embeds.append(cur_input_embeds_no_im[i]) + cur_new_labels.append(cur_labels_noim[i]) + if i < num_images: + try: + cur_image_features = image_features[cur_image_idx] + except IndexError: + cur_image_features = image_features[cur_image_idx - 1] + cur_image_idx += 1 + cur_new_input_embeds.append(cur_image_features) + cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + + cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] + + # import pdb; pdb.set_trace() + cur_new_input_embeds = torch.cat(cur_new_input_embeds) + cur_new_labels = torch.cat(cur_new_labels) + + new_input_embeds.append(cur_new_input_embeds) + new_labels.append(cur_new_labels) + + # Truncate sequences to max length as image embeddings can make the sequence longer + tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None) + # rank_print("Finishing Inserting") + + new_input_embeds = [x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)] + new_labels = [x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)] + # TODO: Hard code for control loss spike + # if tokenizer_model_max_length is not None: + # new_input_embeds = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)] + # new_labels = [x[:4096] if modality != "video" else 
x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)] + + # Combine them + max_len = max(x.shape[0] for x in new_input_embeds) + batch_size = len(new_input_embeds) + + new_input_embeds_padded = [] + new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device) + attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) + position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) + # rank0_print("Prepare pos id") + + for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): + cur_len = cur_new_embed.shape[0] + if getattr(self.config, "tokenizer_padding_side", "right") == "left": + new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0)) + if cur_len > 0: + new_labels_padded[i, -cur_len:] = cur_new_labels + attention_mask[i, -cur_len:] = True + position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + else: + new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)) + if cur_len > 0: + new_labels_padded[i, :cur_len] = cur_new_labels + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + + new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) + # rank0_print("tokenizer padding") + + if _labels is None: + new_labels = None + else: + new_labels = new_labels_padded + + if _attention_mask is None: + attention_mask = None + else: + attention_mask = attention_mask.to(dtype=_attention_mask.dtype) + + if _position_ids is None: + position_ids = None + if getattr(self.config, "use_pos_skipping", False) and self.training: + position_ids = torch.arange(new_input_embeds.size(1), device=new_input_embeds.device).unsqueeze(0).to(new_input_embeds.device) + split_position = random.randint(0, new_input_embeds.size(1)) + left_add = random.randint(0, self.config.pos_skipping_range) + right_add = random.randint(left_add, self.config.pos_skipping_range) + position_ids[:, :split_position] += left_add + position_ids[:, split_position:] += right_add + # import pdb; pdb.set_trace() + # rank0_print("Finish preparing") + return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels, None + + def initialize_vision_tokenizer(self, model_args, tokenizer): + if model_args.mm_use_im_patch_token: + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if model_args.mm_use_im_start_end: + num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = self.get_input_embeddings().weight.data + output_embeddings = self.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + if model_args.tune_mm_mlp_adapter: + for p in 
self.get_input_embeddings().parameters(): + p.requires_grad = True + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False + + if model_args.pretrain_mm_mlp_adapter: + mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu") + embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"] + assert num_new_tokens == 2 + if input_embeddings.shape == embed_tokens_weight.shape: + input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] + elif embed_tokens_weight.shape[0] == num_new_tokens: + input_embeddings[-num_new_tokens:] = embed_tokens_weight + else: + raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.") + elif model_args.mm_use_im_patch_token: + if model_args.tune_mm_mlp_adapter: + for p in self.get_input_embeddings().parameters(): + p.requires_grad = False + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False \ No newline at end of file diff --git a/ola_vlm/model/make_delta.py b/ola_vlm/model/make_delta.py new file mode 100644 index 0000000000000000000000000000000000000000..3bca3e2638ffea8214787bfe36dfe06fb32ed775 --- /dev/null +++ b/ola_vlm/model/make_delta.py @@ -0,0 +1,52 @@ +""" +Usage: +python3 -m ola_vlm.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta +""" +import argparse + +import torch +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM +from ola_vlm.model.utils import auto_upgrade + + +def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): + print("Loading base model") + base = AutoModelForCausalLM.from_pretrained( + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + + print("Loading target model") + auto_upgrade(target_model_path) + target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + + print("Calculating delta") + for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): + if name not in base.state_dict(): + assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' + continue + if param.data.shape == base.state_dict()[name].shape: + param.data -= base.state_dict()[name] + else: + assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' + bparam = base.state_dict()[name] + param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam + + print("Saving delta") + if hub_repo_id: + kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} + else: + kwargs = {} + target.save_pretrained(delta_path, **kwargs) + target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) + target_tokenizer.save_pretrained(delta_path, **kwargs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--delta-path", type=str, required=True) + parser.add_argument("--hub-repo-id", type=str, default=None) + args = parser.parse_args() + + make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) diff --git a/ola_vlm/model/multi_enc_llava_arch.py 
b/ola_vlm/model/multi_enc_llava_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..7e6fd76dce43b55a12fd5d0aac1eeec416f148e6 --- /dev/null +++ b/ola_vlm/model/multi_enc_llava_arch.py @@ -0,0 +1,616 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from abc import ABC, abstractmethod + +import torch +import torch.nn as nn + +from .multimodal_encoder.builder import build_vision_tower +from .multimodal_projector.builder import build_vision_projector + +from ola_vlm.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN + +from ola_vlm.mm_utils import get_anyres_image_grid_shape +import numpy as np + +from ola_vlm.model.aux_heads.sam_utils.build_sam import sam_model_registry +from ola_vlm.model.aux_heads.sam_utils.automatic_mask_generator import SamAutomaticMaskGenerator +from ola_vlm.model.aux_heads.depth_anything_v2.dpt import DepthAnythingV2 +from diffusers import StableUnCLIPImg2ImgPipeline +import torch.nn.functional as F +import copy + +from ola_vlm.model.aux_heads.oneformer_head import OneFormerHead, OneFormerSegHead, OneFormerTaskTokenSegHead +from transformers import OneFormerProcessor, OneFormerConfig + +# import torch +from torchvision import transforms +from PIL import Image + + + +def build_mlp(in_hidden_size, hidden_size): + modules = [nn.Linear(in_hidden_size, hidden_size)] + modules.append(nn.GELU()) + modules.append(nn.Linear(hidden_size, hidden_size)) + return nn.Sequential(*modules) + +model_configs = { + 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, + 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]}, + 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}, + 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]} +} + +class MultiEncLlavaMetaModel: + + def __init__(self, config): + super(MultiEncLlavaMetaModel, self).__init__(config) + self.attn_mask_type = 'causal' + + if hasattr(config, "mm_vision_tower"): + self.vision_tower = build_vision_tower(config, delay_load=False) + self.mm_projector = build_vision_projector(config) + + if 'unpad' in getattr(config, 'mm_patch_merge_type', ''): + self.image_newline = nn.Parameter( + torch.empty(config.hidden_size, dtype=self.dtype) + ) + + self.aggr = getattr(config, 'aggregation', "features") + if self.aggr == "tokens": + depth_config = copy.deepcopy(config) + depth_config.mm_hidden_size = config.depth_dim + self.depth_projector = build_vision_projector(depth_config) + + gen_config = copy.deepcopy(config) + gen_config.mm_hidden_size = config.gen_dim + self.gen_projector = build_vision_projector(gen_config) + + seg_config = copy.deepcopy(config) + seg_config.mm_hidden_size = config.seg_dim + self.seg_projector = build_vision_projector(seg_config) + + self.init_encoders(config) + + self.set_attn_mask_type(config) + + def init_encoders(self, config): + encoder = 'vitl' # or 'vits', 
'vitb', 'vitg' + self.dav2_model = DepthAnythingV2(**model_configs[encoder]) + self.dav2_model.load_state_dict(torch.load(config.depth_estimator, map_location='cpu')) + self.dav2_model.eval() + + self.aggr = getattr(config, 'aggregation', "features") + + try: + self.dav2_model = self.dav2_model.cuda() + except: + pass + + self.pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(config.image_generator, torch_dtype=torch.float16, variant="fp16") + + self.seg_teacher = getattr(config, "seg_teacher", "oneformer") + if self.seg_teacher == "sam": + self.sam = sam_model_registry["vit_l"](checkpoint=self.config.image_segmentor) + try: + self.sam = self.sam.to("cuda") + except: + pass + for p in self.sam.parameters(): + p.requires_grad = False + self.mask_generator = SamAutomaticMaskGenerator(self.sam) + + elif self.seg_teacher == "oneformer": + self.oneformer_processor = OneFormerProcessor.from_pretrained(config.image_segmentor) + self.oneformer = OneFormerHead.from_pretrained(config.image_segmentor) + for p in self.oneformer.parameters(): + p.requires_grad = False + try: + self.oneformer = self.oneformer.to("cuda") + except: + pass + self.mask_generator = None + + def set_attn_mask_type(self, config): + if hasattr(config, 'attn_mask_type'): + self.attn_mask_type = config.attn_mask_type + else: + self.attn_mask_type = 'causal' + print(f"Setting attn_mask_type to {self.attn_mask_type}") + + def get_vision_tower(self): + vision_tower = getattr(self, 'vision_tower', None) + if type(vision_tower) is list: + vision_tower = vision_tower[0] + return vision_tower + + def initialize_vision_modules(self, model_args, fsdp=None): + vision_tower = model_args.vision_tower + mm_vision_select_layer = model_args.mm_vision_select_layer + mm_vision_select_feature = model_args.mm_vision_select_feature + pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter + mm_patch_merge_type = model_args.mm_patch_merge_type + + self.config.mm_vision_tower = vision_tower + + if self.get_vision_tower() is None: + vision_tower = build_vision_tower(model_args) + + if fsdp is not None and len(fsdp) > 0: + self.vision_tower = [vision_tower] + else: + self.vision_tower = vision_tower + else: + if fsdp is not None and len(fsdp) > 0: + vision_tower = self.vision_tower[0] + else: + vision_tower = self.vision_tower + vision_tower.load_model() + + self.config.use_mm_proj = True + self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear') + self.config.mm_hidden_size = vision_tower.hidden_size + self.config.mm_vision_select_layer = mm_vision_select_layer + self.config.mm_vision_select_feature = mm_vision_select_feature + self.config.mm_patch_merge_type = mm_patch_merge_type + + if getattr(self, 'mm_projector', None) is None: + if getattr(model_args, 'aggregation', "features") == "features": + self.config.mm_hidden_size = self.config.mm_hidden_size + model_args.depth_dim + model_args.seg_dim + model_args.gen_dim + self.mm_projector = build_vision_projector(self.config) + + if getattr(model_args, 'aggregation', "features") == "tokens": + depth_config = copy.deepcopy(self.config) + depth_config.mm_hidden_size = model_args.depth_dim + self.depth_projector = build_vision_projector(depth_config) + + gen_config = copy.deepcopy(self.config) + gen_config.mm_hidden_size = model_args.gen_dim + self.gen_projector = build_vision_projector(gen_config) + + seg_config = copy.deepcopy(self.config) + seg_config.mm_hidden_size = model_args.seg_dim + self.seg_projector = build_vision_projector(seg_config) + + if 'unpad' in 
mm_patch_merge_type: + embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype)) + self.image_newline = nn.Parameter( + torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std + ) + else: + # In case it is frozen by LoRA + for p in self.mm_projector.parameters(): + p.requires_grad = True + + if pretrain_mm_mlp_adapter is not None: + mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') + def get_w(weights, keyword): + return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} + + self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) + + +def unpad_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format. + original_size (tuple): The original size of PIL image (width, height). + + Returns: + torch.Tensor: The unpadded image tensor. + """ + original_width, original_height = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding:current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding:current_width - padding] + + return unpadded_tensor + +def unpad_prep_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format. + original_size (tuple): The original size of PIL image (width, height). + + Returns: + torch.Tensor: The unpadded image tensor. 
+ """ + original_width, original_height = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + mode = "height" + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding:current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding:current_width - padding] + mode = "width" + + return unpadded_tensor, mode, padding + + +def reverse_convnext_preprocess(preprocessed_tensor): + unnormalize = transforms.Normalize(mean=[-0.5/0.5, -0.5/0.5, -0.5/0.5], std=[1/0.5, 1/0.5, 1/0.5]) + image_tensor = torch.clamp(unnormalize(preprocessed_tensor), 0, 1) + return transforms.ToPILImage()(image_tensor) + +class MultiEncLlavaMetaForCausalLM(ABC): + + @abstractmethod + def get_model(self): + pass + + def get_vision_tower(self): + return self.get_model().get_vision_tower() + + @property + def attn_mask_type(self): + return self.get_model().attn_mask_type + + def get_seg_targets(self, pil_images, preds): + def _get_feats(img, mask_generator): + if self.get_model().seg_teacher == "oneformer": + img = img.resize((768, 768)) + inputs = self.get_model().oneformer_processor(img, ["panoptic"], return_tensors="pt") + self.get_model().oneformer = self.get_model().oneformer.to(preds.device, preds.dtype) + inputs["pixel_values"] = inputs["pixel_values"].to(preds.device, preds.dtype) + with torch.no_grad(): + feats = self.get_model().oneformer.forward_features(**inputs) + else: + img = np.array(img) + mask_generator.predictor.set_image(img, dtype=preds.dtype) + feats = mask_generator.predictor.features + mask_generator.predictor.reset_image() + feats = F.interpolate(feats, (24, 24), mode="bicubic", align_corners=False) + feats = feats.permute(0, 2, 3, 1) + feats = feats.reshape(1, -1, feats.shape[-1]) + return feats + + seg_targets = [] + for img in pil_images: + feat = _get_feats(img, self.get_model().mask_generator) + seg_targets.append(feat) + + seg_targets = torch.stack(seg_targets, dim=0).squeeze(1) + return seg_targets + + def get_dav2_feats(self, pil_images, device): + self.get_model().dav2_model = self.get_model().dav2_model.to(device) + dav2_feats = [] + for img in pil_images: + img = img.resize((336, 336)) + img = np.array(img) + feat = self.get_model().dav2_model.infer_image(img, is_dsg=True) + feat = (feat[0][0] + feat[1][0] + feat[2][0] + feat[3][0]) / 4 + dav2_feats.append(feat.to(device)) + + dav2_feats = torch.stack(dav2_feats, dim=0).squeeze(1) + return dav2_feats + + def get_gen_feats(self, pil_images, device): + gen_feats = [] + self.get_model().pipe.image_encoder = self.get_model().pipe.image_encoder.to(device) + for img in pil_images: + clip_ims = self.get_model().pipe.feature_extractor(images=img, return_tensors="pt").pixel_values.to(device) + feat = self.get_model().pipe.image_encoder(clip_ims).image_embeds + gen_feats.append(feat) + + gen_feats = torch.stack(gen_feats, dim=0) + return gen_feats + + def encode_images(self, images): + image_features = self.get_model().get_vision_tower()(images).to(images.dtype).to(images.device) + + if self.get_model().aggr == "tokens": + image_features = self.get_model().mm_projector(image_features) + + 
pil_images = [reverse_convnext_preprocess(images[i].float()) for i in range(images.shape[0])] + + depth_feats = self.get_dav2_feats(pil_images, image_features.device).to(image_features.dtype) + + if self.get_model().aggr == "tokens": + depth_feats = depth_feats.permute(0, 2, 1) + depth_feats = F.avg_pool1d(depth_feats, kernel_size=72) + depth_feats = depth_feats.permute(0, 2, 1) + depth_feats = self.get_model().depth_projector(depth_feats) + + gen_feats = self.get_gen_feats(pil_images, image_features.device).to(image_features.dtype) + + if self.get_model().aggr == "tokens": + gen_feats = gen_feats.repeat(1, 8, 1) + gen_feats = self.get_model().gen_projector(gen_feats) + else: + gen_feats = gen_feats.repeat(1, image_features.shape[1], 1) + + seg_feats = self.get_seg_targets(pil_images, image_features).to(image_features.dtype) + + if self.get_model().aggr == "tokens": + seg_feats = seg_feats.permute(0, 2, 1) + seg_feats = F.avg_pool1d(seg_feats, kernel_size=72) + seg_feats = seg_feats.permute(0, 2, 1) + seg_feats = self.get_model().seg_projector(seg_feats) + + if self.get_model().aggr == "tokens": + # image_features = torch.cat([image_features, depth_feats, seg_feats, gen_feats], dim=1) + image_features = torch.cat([image_features, gen_feats, depth_feats, seg_feats], dim=1) + else: + # image_features = torch.cat([image_features, depth_feats, seg_feats, gen_feats], dim=2) + image_features = torch.cat([image_features, gen_feats, depth_feats, seg_feats], dim=2) + image_features = self.get_model().mm_projector(image_features) + + return image_features + + def prepare_inputs_labels_for_multimodal( + self, input_ids, position_ids, attention_mask, past_key_values, labels, + images, image_sizes=None + ): + vision_tower = self.get_vision_tower() + if vision_tower is None or images is None or input_ids.shape[1] == 1: + return input_ids, position_ids, attention_mask, past_key_values, None, labels, None + + if type(images) is list or images.ndim == 5: + if type(images) is list: + images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images] + concat_images = torch.cat([image for image in images], dim=0) + image_features = self.encode_images(concat_images) + split_sizes = [image.shape[0] for image in images] + image_features = torch.split(image_features, split_sizes, dim=0) + mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat') + image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square') + if mm_patch_merge_type == 'flat': + image_features = [x.flatten(0, 1) for x in image_features] + elif mm_patch_merge_type.startswith('spatial'): + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = self.get_vision_tower().num_patches_per_side + assert height * width == base_image_feature.shape[0] + if image_aspect_ratio == 'anyres': + num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size) + image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) + else: + raise NotImplementedError + if 'unpad' in mm_patch_merge_type: + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + image_feature = torch.cat(( + image_feature, + 
self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device) + ), dim=-1) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + else: + image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() + image_feature = image_feature.flatten(0, 3) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + else: + image_feature = image_feature[0] + if 'unpad' in mm_patch_merge_type: + image_feature = torch.cat(( + image_feature, + self.model.image_newline[None].to(image_feature.device) + ), dim=0) + new_image_features.append(image_feature) + image_features = new_image_features + else: + raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}") + else: + image_features = self.encode_images(images) + + # TODO: image start / end is not implemented here to support pretraining. + if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): + raise NotImplementedError + + # Let's just add dummy tensors if they do not exist, + # it is a headache to deal with None all the time. + # But it is not ideal, and if you have a better idea, + # please open an issue / submit a PR, thanks. + _labels = labels + _position_ids = position_ids + _attention_mask = attention_mask + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + else: + attention_mask = attention_mask.bool() + if position_ids is None: + position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + if labels is None: + labels = torch.full_like(input_ids, IGNORE_INDEX) + do_sample = False + else: + do_sample = True + + # remove the padding using attention_mask -- FIXME + _input_ids = input_ids + input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] + labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] + + new_input_embeds = [] + new_labels = [] + block_indices = [None] * len(input_ids) + cur_image_idx = 0 + for batch_idx, cur_input_ids in enumerate(input_ids): + num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + if num_images == 0: + cur_image_features = image_features[cur_image_idx] + cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) + cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0) + new_input_embeds.append(cur_input_embeds) + new_labels.append(labels[batch_idx]) + cur_image_idx += 1 + continue + + image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] + cur_input_ids_noim = [] + cur_labels = labels[batch_idx] + cur_labels_noim = [] + for i in range(len(image_token_indices) - 1): + cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]]) + cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]]) + split_sizes = [x.shape[0] for x in cur_labels_noim] + cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim)) + cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) + cur_new_input_embeds = [] + cur_new_labels = [] + + num_tokens = 0 + for i in range(num_images + 1): + cur_new_input_embeds.append(cur_input_embeds_no_im[i]) + cur_new_labels.append(cur_labels_noim[i]) + if i < num_images: + num_tokens += cur_input_embeds_no_im[i].shape[0] + cur_image_features = image_features[cur_image_idx] + 
cur_image_idx += 1 + + cur_new_input_embeds.append(cur_image_features) + cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + num_tokens += cur_image_features.shape[0] + + if self.attn_mask_type == "block-causal": + indices = ["block-causal", image_token_indices[1], num_tokens] + block_indices[batch_idx] = indices + + cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] + + cur_new_input_embeds = torch.cat(cur_new_input_embeds) + cur_new_labels = torch.cat(cur_new_labels) + + new_input_embeds.append(cur_new_input_embeds) + new_labels.append(cur_new_labels) + + # Truncate sequences to max length as image embeddings can make the sequence longer + tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None) + if tokenizer_model_max_length is not None: + new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] + new_labels = [x[:tokenizer_model_max_length] for x in new_labels] + + # Combine them + max_len = max(x.shape[0] for x in new_input_embeds) + batch_size = len(new_input_embeds) + + new_input_embeds_padded = [] + new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device) + attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) + position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) + + for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): + cur_len = cur_new_embed.shape[0] + if getattr(self.config, 'tokenizer_padding_side', 'right') == "left": + new_input_embeds_padded.append(torch.cat(( + torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), + cur_new_embed + ), dim=0)) + if cur_len > 0: + new_labels_padded[i, -cur_len:] = cur_new_labels + attention_mask[i, -cur_len:] = True + position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + else: + new_input_embeds_padded.append(torch.cat(( + cur_new_embed, + torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device) + ), dim=0)) + if cur_len > 0: + new_labels_padded[i, :cur_len] = cur_new_labels + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + + new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) + + if _labels is None: + new_labels = None + else: + new_labels = new_labels_padded + + if _attention_mask is None: + attention_mask = None + else: + attention_mask = attention_mask.to(dtype=_attention_mask.dtype) + + if _position_ids is None: + position_ids = None + + return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels, block_indices + + def initialize_vision_tokenizer(self, model_args, tokenizer): + if model_args.mm_use_im_patch_token: + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if model_args.mm_use_im_start_end: + num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = self.get_input_embeddings().weight.data + output_embeddings = self.get_output_embeddings().weight.data + + 
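# Editor's note: the statements that follow initialise the freshly added
# <im_start>/<im_end> embedding rows to the mean of the pre-existing rows. A standalone
# toy version of that trick (the 12/2 sizes are made up); torch.no_grad() stands in for
# the .weight.data access used in the original code.
import torch

emb = torch.nn.Embedding(12, 4)                          # 10 old tokens + 2 new ones
num_new = 2
with torch.no_grad():
    avg = emb.weight[:-num_new].mean(dim=0, keepdim=True)
    emb.weight[-num_new:] = avg                          # new rows start at the old mean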
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + if model_args.tune_mm_mlp_adapter: + for p in self.get_input_embeddings().parameters(): + p.requires_grad = True + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False + + if model_args.pretrain_mm_mlp_adapter: + mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') + embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] + assert num_new_tokens == 2 + if input_embeddings.shape == embed_tokens_weight.shape: + input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] + elif embed_tokens_weight.shape[0] == num_new_tokens: + input_embeddings[-num_new_tokens:] = embed_tokens_weight + else: + raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.") + elif model_args.mm_use_im_patch_token: + if model_args.tune_mm_mlp_adapter: + for p in self.get_input_embeddings().parameters(): + p.requires_grad = False + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False \ No newline at end of file diff --git a/ola_vlm/model/multimodal_encoder/base_encoder.py b/ola_vlm/model/multimodal_encoder/base_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7f73e8f829448ccbe213685fc46e1d3010614348 --- /dev/null +++ b/ola_vlm/model/multimodal_encoder/base_encoder.py @@ -0,0 +1,140 @@ +from abc import abstractmethod + +import torch +import torch.nn as nn +import logging as logger + + +class ProcessorWrapper: + def __init__(self, transform, height=378, width=378, image_mean = [0.48145466, 0.4578275, 0.40821073]): + self._crop_size = { + "height": height, + "width": width, + } + self._transforms = transform + #print(transform) + self.image_mean = image_mean + + @property + def crop_size(self): + return self._crop_size + + def __call__(self, image, return_tensors='pt'): + # Ensure image is a PIL Image + if isinstance(image, list): + image = image[0] + output = {} + output['pixel_values'] = [self._transforms(image)] + return output + + def preprocess(self, image, return_tensors='pt'): + # Ensure image is a PIL Image + if isinstance(image, list): + image = image[0] + output = {} + output['pixel_values'] = [self._transforms(image)] + return output + + +class BaseVisionTower(nn.Module): + def __init__(self, vision_tower_name, args, delay_load=False): + super(BaseVisionTower, self).__init__() + + self.is_loaded = False + self.args = args + + self.vision_tower_name = vision_tower_name + self.select_layer = args.mm_vision_select_layer + self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') + self.unfreeze_mm_vision_tower = getattr(args, 'unfreeze_mm_vision_tower', False) + logger.warning(f"Unfreezing MM Vision Tower: {self.unfreeze_mm_vision_tower}") + self.delay_load = delay_load + + @abstractmethod + def load_model(self, device_map=None): + raise NotImplementedError("Subclasses must implement load_model") + + @abstractmethod + def _forward(self, images): + raise NotImplementedError("Subclasses must implement forward") + + def forward(self, images): + if type(images) is list: + image_features = [ + self._forward(image.unsqueeze(0)) + for image in 
images + ] + else: + image_features = self._forward(images) + + return image_features + + @property + def dummy_feature(self): + return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) + + @property + def dtype(self): + # Dynamically infer the dtype from the first parameter, if not explicitly specified + if hasattr(self.vision_tower, 'dtype'): + return self.vision_tower.dtype + else: + params = list(self.vision_tower.parameters()) + return params[0].dtype if len(params) > 0 else torch.float32 # Default to torch.float32 if no parameters + + @property + def device(self): + # Dynamically infer the device from the first parameter, if not explicitly specified + if hasattr(self.vision_tower, 'device'): + return self.vision_tower.device + else: + params = list(self.vision_tower.parameters()) + return params[0].device if len(params) > 0 else torch.device("cpu") # Default to CPU if no parameters + + @property + def config(self): + if self.is_loaded: + return self.vision_tower.config + else: + return self.cfg_only + + @property + def hidden_size(self): + try: + return self.config.hidden_size + except: + return self._hidden_size + + @property + def image_size(self): # resolution + # return self.config.image_size + try: + return self.config.image_size + except: + return self._image_size + + @property + def patch_size(self): + # return self.config.patch_size + try: + return self.config.patch_size + except: + return self._patch_size + + @property + def num_patches_per_side(self): + if self._interp_size is not None: + return int(self._interp_size**0.5) + try: + return self.image_size // self.patch_size + except: + return self._num_patches_per_side + + @property + def num_patches(self): + if self._interp_size is not None: + return self._interp_size + try: + return self.num_patches_per_side ** 2 + except: + return self._num_patches \ No newline at end of file diff --git a/ola_vlm/model/multimodal_encoder/builder.py b/ola_vlm/model/multimodal_encoder/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..21f182e8573115810ecf447ee035e0e7289835d8 --- /dev/null +++ b/ola_vlm/model/multimodal_encoder/builder.py @@ -0,0 +1,15 @@ +import os +from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2 +from .clip_convnext_encoder import CLIPConvNextVisionTower + + +def build_vision_tower(vision_tower_cfg, **kwargs): + vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) + if "clip" in vision_tower and "convnext" not in vision_tower: + return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) + elif "convnext" in vision_tower.lower(): + return CLIPConvNextVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) + elif "sam" in vision_tower.lower(): + return SAMVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) + + raise ValueError(f'Unknown vision tower: {vision_tower}') diff --git a/ola_vlm/model/multimodal_encoder/clip_convnext_encoder.py b/ola_vlm/model/multimodal_encoder/clip_convnext_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..fe4e8dc40200b701697b89298cd6aca08febfb60 --- /dev/null +++ b/ola_vlm/model/multimodal_encoder/clip_convnext_encoder.py @@ -0,0 +1,206 @@ +import torch +import torch.nn as nn +from ola_vlm.model.multimodal_encoder.openclip_utils import create_model_from_pretrained +from open_clip.model import CLIPVisionCfg, CLIPTextCfg, _build_vision_tower +from timm.models.convnext import ConvNeXt +import torch +from torch import nn +import 
torch.nn.functional as F +from .base_encoder import BaseVisionTower, ProcessorWrapper +from typing import Optional + +class CLIP(nn.Module): + output_dict: torch.jit.Final[bool] + + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, + drop_path: bool = False, + ): + super().__init__() + self.output_dict = output_dict + + # Fix drop path during training + if not drop_path: + print('Not using drop path during training.') + vision_cfg['timm_drop_path'] = 0.0 + + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + +def extract_res_interp(model_name): + valid_model_prefixes = { + "CLIP-convnext_large":"hf-hub:laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup", + "CLIP-convnext_xxlarge":"hf-hub:laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup" + } + + + res = None + interp = None + + for prefix in valid_model_prefixes: + if model_name.split("/")[-1].startswith(prefix): + base_model_name = valid_model_prefixes[prefix] + break + else: + raise ValueError(f"Unknown vision tower: {model_name}") + + parts = model_name.split("-") + for part in parts: + if part.startswith("res"): + res = int(part[3:]) + elif part.startswith("interp"): + interp = int(part[6:]) + return base_model_name, res, interp + + +class CLIPConvNextVisionTower(BaseVisionTower): + def __init__(self, vision_tower, args, delay_load=False): + """ + Initialize the CLIPConvNextTower. + + Args: + vision_tower (str): The name of the vision tower model in the format "clip-convnext-resXXX-interpYYY". + args (argparse.Namespace): The arguments parsed from the command line. + delay_load (bool, optional): Whether to delay loading the model. Defaults to False. + """ + super().__init__(vision_tower, args, delay_load) + + self.is_multi_stage = "multi-stage" in vision_tower + base_model_name, res, interp = extract_res_interp(vision_tower) + self.vision_tower_name = base_model_name + self.ckpt_path = vision_tower.split("-res")[0] + self._image_size = res if res is not None else 768 + self._interp_size = interp + self._reduction = 32 + + self.select_layer = args.mm_vision_select_layer + self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') + self.unfreeze_mm_vision_tower = getattr(args, 'unfreeze_mm_vision_tower', False) + self.is_loaded = False + + if not delay_load: + self.load_model() + elif self.unfreeze_mm_vision_tower: + self.load_model() + else: + assert "CLIP-convnext_large" in vision_tower or "CLIP-convnext_xxlarge" in vision_tower + if "CLIP-convnext_large" in vision_tower: + if "multi-stage" in vision_tower: + self._hidden_size = sum([192, 384, 768, 1536]) + else: + self._hidden_size = 1536 + else: + if "multi-stage" in vision_tower: + self._hidden_size = sum([384, 768, 1536, 3072]) + else: + self._hidden_size = 3072 + + + def load_model(self, device_map=None): + """ + Load the CLIP-ConvNext model. 
+ """ + + assert "clip-convnext" in self.vision_tower_name.lower() + self.vision_model = "convnext" + try: + clip_model, processor = create_model_from_pretrained(self.vision_tower_name, load_ckpt=True) + except: + clip_model, processor = create_model_from_pretrained(self.vision_tower_name, load_ckpt=False) + processor.transforms[0].size = self._image_size + processor.transforms[1].size = (self._image_size, self._image_size) + self.image_processor = ProcessorWrapper(processor, height=self._image_size, width=self._image_size) + + self.vision_tower: ConvNeXt = clip_model.visual.trunk + self.vision_tower.output_tokens = True + feature_info = self.vision_tower.feature_info + if self.is_multi_stage: + self._hidden_size = sum([stage['num_chs'] for stage in feature_info]) + else: + self._hidden_size = feature_info[-1]['num_chs'] + self.is_loaded = True + + def interpolate(self, image_forward_outs): + """ + Interpolate the image features to the desired number of patches. + + Args: + image_forward_outs (torch.Tensor): The output features from the vision tower. + + Returns: + torch.Tensor: The interpolated image features. + """ + if self._interp_size is None: + return image_forward_outs + + image_features = F.interpolate( + image_forward_outs.float(), + size=(self.num_patches_per_side, self.num_patches_per_side), + mode='bilinear', + align_corners=False + ).to(dtype=image_forward_outs.dtype) + image_features = image_features.flatten(2, 3).permute(0, 2, 1).contiguous() + return image_features + + def _forward(self, images): + """ + Perform the forward pass of the CLIPConvNextTower. + + Args: + images (torch.Tensor): The input images. + + Returns: + torch.Tensor: The output features from the vision tower after interpolation. + """ + image_features_stages = [] + x = self.vision_tower.stem(images.to(device=self.device, dtype=self.dtype)) + for stage in self.vision_tower.stages: + x = stage(x) + image_features_stages.append(x) + image_features = self.vision_tower.norm_pre(x).contiguous() + # if not self.is_multi_stage: + # image_features_stages = image_features_stages[-1:] + # image_features_stages_rescaled = [] + # for image_features_single_stage in image_features_stages: + # image_features_single_stage_rescaled = self.interpolate(image_features_single_stage) + # image_features_stages_rescaled.append(image_features_single_stage_rescaled) + # image_features = torch.cat(image_features_stages_rescaled, -1) + image_features = image_features.flatten(2, 3).permute(0, 2, 1).contiguous() + return image_features + + @property + def image_size(self): + return self._image_size + + @property + def num_patches_per_side(self): + """ + Get the number of patches per side. + + Returns: + int: The number of patches per side. + """ + if self._interp_size is None: + return self._image_size // self._reduction + else: + return int(self._interp_size ** 0.5) + + @property + def num_patches(self): + """ + Get the total number of patches. + + Default: 256 + + Returns: + int: The total number of patches. 
+ """ + if self._interp_size is None: + return (self._image_size // self._reduction) ** 2 + else: + return self._interp_size \ No newline at end of file diff --git a/ola_vlm/model/multimodal_encoder/clip_encoder.py b/ola_vlm/model/multimodal_encoder/clip_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..97f93701c1c1c117d9e537a3529252a95885e434 --- /dev/null +++ b/ola_vlm/model/multimodal_encoder/clip_encoder.py @@ -0,0 +1,149 @@ +import torch +import torch.nn as nn +import os +from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig + + +class CLIPVisionTower(nn.Module): + def __init__(self, vision_tower, args, delay_load=False): + super().__init__() + + self.is_loaded = False + + self.vision_tower_name = vision_tower + if not os.path.exists(self.vision_tower_name) and "clip-vit-large-patch14-336" in self.vision_tower_name: + self.vision_tower_name = "openai/clip-vit-large-patch14-336" + self.select_layer = args.mm_vision_select_layer + self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') + + if not delay_load: + self.load_model() + elif getattr(args, 'unfreeze_mm_vision_tower', False): + self.load_model() + else: + self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) + + def load_model(self, device_map=None): + if self.is_loaded: + print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) + return + + self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) + self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) + self.vision_tower.requires_grad_(False) + + self.is_loaded = True + + def feature_select(self, image_forward_outs): + image_features = image_forward_outs.hidden_states[self.select_layer] + if self.select_feature == 'patch': + image_features = image_features[:, 1:] + elif self.select_feature == 'cls_patch': + image_features = image_features + else: + raise ValueError(f'Unexpected select feature: {self.select_feature}') + return image_features + + @torch.no_grad() + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) + image_feature = self.feature_select(image_forward_out).to(image.dtype) + image_features.append(image_feature) + else: + image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) + image_features = self.feature_select(image_forward_outs).to(images.dtype) + + return image_features + + @property + def dummy_feature(self): + return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) + + @property + def dtype(self): + return self.vision_tower.dtype + + @property + def device(self): + return self.vision_tower.device + + @property + def config(self): + if self.is_loaded: + return self.vision_tower.config + else: + return self.cfg_only + + @property + def hidden_size(self): + return self.config.hidden_size + + @property + def num_patches_per_side(self): + return self.config.image_size // self.config.patch_size + + @property + def num_patches(self): + return (self.config.image_size // self.config.patch_size) ** 2 + + + +class CLIPVisionTowerS2(CLIPVisionTower): + def __init__(self, vision_tower, args, delay_load=False): + self.s2_scales = getattr(args, 's2_scales', '336,1008') + self.s2_scales = list(map(int, 
self.s2_scales.split(','))) + self.s2_scales.sort() + self.s2_split_size = self.s2_scales[0] + self.s2_image_size = self.s2_scales[-1] + + try: + from s2wrapper import forward as multiscale_forward + except ImportError: + raise ImportError('Package s2wrapper not found! Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git') + self.multiscale_forward = multiscale_forward + + super().__init__(vision_tower, args, delay_load) + + # change resize/crop size in preprocessing to the largest image size in s2_scale + if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False): + self.image_processor.size['shortest_edge'] = self.s2_image_size + self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size + + def load_model(self, device_map=None): + if self.is_loaded: + print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) + return + + self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) + self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) + self.vision_tower.requires_grad_(False) + + self.image_processor.size['shortest_edge'] = self.s2_image_size + self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size + + self.is_loaded = True + + @torch.no_grad() + def forward_feature(self, images): + image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) + image_features = self.feature_select(image_forward_outs).to(images.dtype) + return image_features + + @torch.no_grad() + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_feature = self.multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size) + image_features.append(image_feature) + else: + image_features = self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size) + + return image_features + + @property + def hidden_size(self): + return self.config.hidden_size * len(self.s2_scales) diff --git a/ola_vlm/model/multimodal_encoder/openclip_utils.py b/ola_vlm/model/multimodal_encoder/openclip_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e86318305db937af18d2864e18596be51c33fd52 --- /dev/null +++ b/ola_vlm/model/multimodal_encoder/openclip_utils.py @@ -0,0 +1,223 @@ +import logging +import os +from dataclasses import asdict +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from open_clip.model import CLIP, CustomTextCLIP, convert_weights_to_lp,\ + get_cast_dtype, set_model_preprocess_cfg, convert_to_custom_text_state_dict, resize_pos_embed +from open_clip.coca_model import CoCa +from open_clip.openai import load_openai_model +from open_clip.pretrained import get_pretrained_cfg, download_pretrained,\ + list_pretrained_tags_by_model, download_pretrained_from_hf +from open_clip.transform import image_transform_v2, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs +from open_clip.factory import get_model_config, _get_hf_config, list_models, load_state_dict, load_checkpoint +HF_HUB_PREFIX = 'hf-hub:' + +def create_model( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + 
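
`CLIPVisionTowerS2` above runs the tower at several image scales through `s2wrapper` and merges the per-scale features along the channel axis, which is why its `hidden_size` property multiplies the base width by `len(self.s2_scales)`. A shape-only sketch with synthetic tensors (no `s2wrapper` call) illustrates that bookkeeping:

```python
# Shape-only sketch: per-scale features concatenated on the channel axis,
# so hidden_size grows by len(s2_scales). Synthetic tensors only.
import torch

s2_scales = [336, 1008]            # default '336,1008' parsed above
num_patches, base_hidden = 576, 1024

per_scale_feats = [torch.randn(1, num_patches, base_hidden) for _ in s2_scales]
merged = torch.cat(per_scale_feats, dim=-1)

print(merged.shape)                # torch.Size([1, 576, 2048]) == base_hidden * len(s2_scales)
```
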
force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_patch_dropout: Optional[float] = None, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + force_preprocess_cfg: Optional[Dict[str, Any]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, + require_pretrained: bool = False, + load_ckpt: Optional[bool] = True, + **model_kwargs, +): + force_preprocess_cfg = force_preprocess_cfg or {} + preprocess_cfg = asdict(PreprocessCfg()) + has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) + if has_hf_hub_prefix: + model_id = model_name[len(HF_HUB_PREFIX):] + checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + config = _get_hf_config(model_id, cache_dir) + preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg']) + model_cfg = config['model_cfg'] + pretrained_hf = False # override, no need to load original HF text weights + else: + model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names + checkpoint_path = None + model_cfg = None + + if isinstance(device, str): + device = torch.device(device) + + if pretrained and pretrained.lower() == 'openai': + logging.info(f'Loading pretrained {model_name} from OpenAI.') + model = load_openai_model( + model_name, + precision=precision, + device=device, + cache_dir=cache_dir, + ) + else: + model_cfg = model_cfg or get_model_config(model_name) + if model_cfg is not None: + logging.info(f'Loaded {model_name} model config.') + else: + logging.error(f'Model config for {model_name} not found; available models {list_models()}.') + raise RuntimeError(f'Model config for {model_name} not found.') + + if force_quick_gelu: + # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + if force_patch_dropout is not None: + # override the default patch dropout value + model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout + + if force_image_size is not None: + # override model config's image size + model_cfg["vision_cfg"]["image_size"] = force_image_size + + is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {}) + if pretrained_image: + if is_timm_model: + # pretrained weight loading for timm models set via vision_cfg + model_cfg['vision_cfg']['timm_model_pretrained'] = True + else: + assert False, 'pretrained image towers currently only supported for timm models' + + # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes + cast_dtype = get_cast_dtype(precision) + is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) + if is_hf_model: + # load pretrained weights for HF text model IFF no CLIP weights being loaded + model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained + custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model + + model_cfg = dict(model_cfg, **model_kwargs) # merge cfg dict w/ kwargs (kwargs overrides cfg) + if custom_text: + if "multimodal_cfg" in model_cfg: + model = CoCa(**model_cfg, cast_dtype=cast_dtype) + else: + model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) + else: + model = CLIP(**model_cfg, cast_dtype=cast_dtype) + + if precision in ("fp16", "bf16"): + dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 + # manual mixed precision that matches original OpenAI behaviour + if is_timm_model: + # FIXME this is a bit janky, create timm based model 
in low-precision and + # then cast only LayerNormFp32 instances back to float32 so they don't break. + # Why? The convert_weights_to_lp fn only works with native models. + model.to(dtype=dtype) + from open_clip.transformer import LayerNormFp32 + + def _convert_ln(m): + if isinstance(m, LayerNormFp32): + m.weight.data = m.weight.data.to(torch.float32) + m.bias.data = m.bias.data.to(torch.float32) + model.apply(_convert_ln) + else: + convert_weights_to_lp(model, dtype=dtype) + elif precision in ("pure_fp16", "pure_bf16"): + dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 + model.to(dtype=dtype) + + pretrained_loaded = False + if pretrained: + checkpoint_path = '' + pretrained_cfg = get_pretrained_cfg(model_name, pretrained) + if pretrained_cfg: + checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) + preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg) + elif os.path.exists(pretrained): + checkpoint_path = pretrained + + if checkpoint_path: + logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') + if load_ckpt: + load_checkpoint(model, checkpoint_path) + else: + error_str = ( + f'Pretrained weights ({pretrained}) not found for model {model_name}.' + f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') + logging.warning(error_str) + raise RuntimeError(error_str) + pretrained_loaded = True + elif has_hf_hub_prefix: + if load_ckpt: + logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).') + load_checkpoint(model, checkpoint_path) + pretrained_loaded = True + + if require_pretrained and not pretrained_loaded: + # callers of create_model_from_pretrained always expect pretrained weights + raise RuntimeError( + f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.') + + if output_dict and hasattr(model, "output_dict"): + model.output_dict = True + + if jit: + model = torch.jit.script(model) + + # set image preprocessing configuration in model attributes for convenience + if getattr(model.visual, 'image_size', None) is not None: + # use image_size set on model creation (via config or force_image_size arg) + force_preprocess_cfg['size'] = model.visual.image_size + set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg)) + + return model + + +def create_model_from_pretrained( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + image_interpolation: Optional[str] = None, + image_resize_mode: Optional[str] = None, # only effective for inference + return_transform: bool = True, + cache_dir: Optional[str] = None, + load_ckpt: Optional[bool] = True, + **model_kwargs, +): + force_preprocess_cfg = merge_preprocess_kwargs( + {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode) + + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_text=force_custom_text, + force_image_size=force_image_size, + force_preprocess_cfg=force_preprocess_cfg, + cache_dir=cache_dir, + require_pretrained=True, + load_ckpt=load_ckpt, + **model_kwargs, + ) + + if 
not return_transform: + return model + + preprocess = image_transform_v2( + PreprocessCfg(**model.visual.preprocess_cfg), + is_train=False, + ) + + return model, preprocess \ No newline at end of file diff --git a/ola_vlm/model/multimodal_projector/builder.py b/ola_vlm/model/multimodal_projector/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..d72c45650339eda8608c5581f3186fa9f8fd1d22 --- /dev/null +++ b/ola_vlm/model/multimodal_projector/builder.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +import re +from ola_vlm.model.multimodal_projector.resampler import Resampler + + +class IdentityMap(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, *args, **kwargs): + return x + + @property + def config(self): + return {"mm_projector_type": 'identity'} + + +class SimpleResBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.pre_norm = nn.LayerNorm(channels) + + self.proj = nn.Sequential( + nn.Linear(channels, channels), + nn.GELU(), + nn.Linear(channels, channels) + ) + def forward(self, x): + x = self.pre_norm(x) + return x + self.proj(x) + + +def build_resampler(config, num_queries=None): + return Resampler( + dim=config["probe_output_dim"], + depth=config["probe_depth"], + dim_head=config["probe_dim_head"], + heads=config["probe_num_heads"], + num_queries=config["num_queries"] if num_queries is None else num_queries, + embedding_dim=config.hidden_size, + output_dim=config["probe_output_dim"], + ff_mult=config["probe_ff_mult"], + ) + + +def build_vision_projector(config, delay_load=False, **kwargs): + projector_type = getattr(config, 'mm_projector_type', 'linear') + + if projector_type == 'linear': + return nn.Linear(config.mm_hidden_size, config.hidden_size) + + mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) + if mlp_gelu_match: + mlp_depth = int(mlp_gelu_match.group(1)) + modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.hidden_size, config.hidden_size)) + return nn.Sequential(*modules) + + if projector_type == 'identity': + return IdentityMap() + + raise ValueError(f'Unknown projector type: {projector_type}') diff --git a/ola_vlm/model/multimodal_projector/resampler.py b/ola_vlm/model/multimodal_projector/resampler.py new file mode 100644 index 0000000000000000000000000000000000000000..acc19892b9a0427d68b90825d68cf13a698ef248 --- /dev/null +++ b/ola_vlm/model/multimodal_projector/resampler.py @@ -0,0 +1,368 @@ +# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + #(bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + 
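
`build_vision_projector()` above maps `mm_projector_type` strings such as `"mlp2x_gelu"` onto an MLP via a regex. The sketch below builds the same module for illustrative sizes (a 1024-dim CLIP ViT-L feature projected to a 4096-dim LLM width; the exact widths in a given checkpoint come from its config):

```python
# Sketch of the mlpNx_gelu path in build_vision_projector(), with illustrative sizes.
import re
import torch
import torch.nn as nn

mm_hidden_size, hidden_size = 1024, 4096       # e.g. CLIP ViT-L features -> Llama-3-8B width
projector_type = "mlp2x_gelu"

match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
mlp_depth = int(match.group(1))                # 2

modules = [nn.Linear(mm_hidden_size, hidden_size)]
for _ in range(1, mlp_depth):
    modules.append(nn.GELU())
    modules.append(nn.Linear(hidden_size, hidden_size))
projector = nn.Sequential(*modules)

print(projector(torch.randn(1, 576, mm_hidden_size)).shape)  # torch.Size([1, 576, 4096])
```
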
self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class AttentionPool2d(nn.Module): + + def __init__(self, seq_len: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(seq_len + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x, return_all_tokens=False): + # x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = x.permute(1, 0, 2) # (N(HW)C) => (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward(query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False) + if return_all_tokens: + return x + else: + return x[0] + + +class Resampler(nn.Module): + + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + ): + super().__init__() + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.in_dim = dim + self.out_dim = output_dim + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList([ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ])) + + def forward(self, x): + + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + output_embeds = self.norm_out(latents) + + return output_embeds + +class 
TaskTokenResampler(nn.Module): + + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + num_queries=8, + heads=16, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + ): + super().__init__() + + # self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + # self.task_w = nn.Parameter(torch.tensor(0.0)) + self.proj_in = nn.Linear(embedding_dim, dim) + self.num_queries = num_queries + self.dim = dim + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.in_dim = dim + self.out_dim = output_dim + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList([ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ])) + + def forward(self, x, latents): + + if latents is None: + latents = torch.zeros(x.size[0], self.num_queries, self.dim) + else: + if latents.shape[1] != self.num_queries: + if self.num_queries > 1 and self.num_queries % latents.shape[1] == 0: + n = latents.shape[1] + latents = latents.repeat(1, self.num_queries // n, 1) + else: + latents = latents.mean(dim=1, keepdim=True).repeat(1, self.num_queries, 1) + latents = self.proj_in(latents) + + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + output_embeds = self.norm_out(latents) + + return output_embeds + + +class ResamplerXL(nn.Module): + + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output1_dim=768, + output2_dim=1280, + ff_mult=4, + ): + super().__init__() + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + + self.proj_in = nn.Linear(embedding_dim, dim) + + # self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(dim) + + self.in_dim = dim + self.out_dim = output1_dim + output2_dim + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList([ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ])) + + self.unet_proj_1 = nn.Linear(self.in_dim, output1_dim) + self.unet_proj_2 = nn.Linear(self.in_dim, output2_dim) + self.unet_attnpool = AttentionPool2d(num_queries, self.in_dim, heads, output2_dim) + + def forward(self, x): + + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + hidden_embeds = self.norm_out(latents) + + encoder_hidden_1 = self.unet_proj_1(hidden_embeds) # [bs, 256, 768] + encoder_hidden_2 = self.unet_proj_2(hidden_embeds) # [bs, 256, 1280] + prompt_embeds = torch.cat([encoder_hidden_1, encoder_hidden_2], dim=-1) # [bs, 256, 2048] + pooled_prompt_embeds = self.unet_attnpool(hidden_embeds) # [bs, 1280] + + return prompt_embeds, pooled_prompt_embeds + + +class ResamplerXLV2(nn.Module): + + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output1_dim=768, + output2_dim=1280, + ff_mult=4, + normalize=True + ): + super().__init__() + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + + self.normalize = normalize + self.proj_in = nn.Linear(embedding_dim, dim) + + # self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(dim) + + self.in_dim = dim + self.out_dim = output1_dim + output2_dim + + self.layers = 
nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList([ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ])) + + self.unet_proj_1 = nn.Linear(self.in_dim, output1_dim) + self.unet_proj_2 = nn.Linear(self.in_dim, output2_dim) + self.unet_attnpool = AttentionPool2d(num_queries, self.in_dim, heads, output2_dim) + + def forward(self, x,pooled_text_embeds=None): + + latents = self.latents.repeat(x.size(0), 1, 1) + + if self.normalize: + x = F.normalize(x) + + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + hidden_embeds = self.norm_out(latents) + + encoder_hidden_1 = self.unet_proj_1(hidden_embeds) # [bs, 256, 768] + encoder_hidden_2 = self.unet_proj_2(hidden_embeds) # [bs, 256, 1280] + prompt_embeds = torch.cat([encoder_hidden_1, encoder_hidden_2], dim=-1) # [bs, 256, 2048] + pooled_prompt_embeds = self.unet_attnpool(hidden_embeds) # [bs, 1280] + + return prompt_embeds, pooled_prompt_embeds + +class ResamplerXLIdentity(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, pooled_text_embeds=None): + return x, pooled_text_embeds + + +if __name__ == '__main__': + image_proj_model = Resampler(dim=1024, + depth=4, + dim_head=64, + heads=12, + num_queries=1024, + embedding_dim=1024, + output_dim=1024, + ff_mult=4) + numel = 0 + for name, param in image_proj_model.named_parameters(): + numel += param.numel() + + print(f'Total params: {numel}') \ No newline at end of file diff --git a/ola_vlm/model/ola_arch.py b/ola_vlm/model/ola_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..91a7b75abfc969fd525bf46f7a3e079d5e094dc4 --- /dev/null +++ b/ola_vlm/model/ola_arch.py @@ -0,0 +1,489 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
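
The `Resampler` defined in `resampler.py` above compresses a variable-length sequence of hidden states (as `build_resampler()` wires it, `embedding_dim` is the LLM hidden size) into a fixed set of `num_queries` output embeddings. A quick shape and parameter check, in the spirit of the `__main__` block above, is sketched below; it assumes the `ola_vlm` package is importable, and the sizes are illustrative rather than taken from a specific checkpoint:

```python
# Shape/parameter check for the Resampler in ola_vlm/model/multimodal_projector/resampler.py
# (assumes the ola_vlm package is importable; sizes are illustrative).
import torch
from ola_vlm.model.multimodal_projector.resampler import Resampler

resampler = Resampler(dim=1024, depth=4, dim_head=64, heads=12,
                      num_queries=256, embedding_dim=4096, output_dim=1024, ff_mult=4)

x = torch.randn(2, 576, 4096)   # (batch, sequence length, hidden size fed to proj_in)
out = resampler(x)

print(out.shape)                # torch.Size([2, 256, 1024]) -> (batch, num_queries, output_dim)
print(sum(p.numel() for p in resampler.parameters()), "parameters")
```
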
+ + +from abc import ABC, abstractmethod + +import torch +import torch.nn as nn + +from .multimodal_encoder.builder import build_vision_tower +from .multimodal_projector.builder import build_vision_projector + +from ola_vlm.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN + +from ola_vlm.mm_utils import get_anyres_image_grid_shape +import numpy as np + +def build_mlp(in_hidden_size, hidden_size): + modules = [nn.Linear(in_hidden_size, hidden_size)] + modules.append(nn.GELU()) + modules.append(nn.Linear(hidden_size, hidden_size)) + return nn.Sequential(*modules) + +class OlaLlavaMetaModel: + + def __init__(self, config): + super(OlaLlavaMetaModel, self).__init__(config) + self.aux_tokens = 'depth-seg-gen' + self.token_order = ["depth", "seg", "gen"] + self.num_task_tokens = 0 + + if hasattr(config, "mm_vision_tower"): + self.vision_tower = build_vision_tower(config, delay_load=False) + self.mm_projector = build_vision_projector(config) + + if 'unpad' in getattr(config, 'mm_patch_merge_type', ''): + self.image_newline = nn.Parameter( + torch.empty(config.hidden_size, dtype=self.dtype) + ) + + if hasattr(config, 'num_task_tokens') and not hasattr(config, 'probe_mode'): + self.initialize_special_tokens(config) + + def get_vision_tower(self): + vision_tower = getattr(self, 'vision_tower', None) + if type(vision_tower) is list: + vision_tower = vision_tower[0] + return vision_tower + + def get_special_tokens(self): + depth_tokens = getattr(self, 'special_depth_tokens', None) + seg_tokens = getattr(self, 'special_seg_tokens', None) + gen_tokens = getattr(self, 'special_gen_tokens', None) + return depth_tokens, seg_tokens, gen_tokens + + def initialize_special_tokens(self, config): + self.num_task_tokens = config.num_task_tokens + task_token_format = getattr(config, "task_token_format", "emb") + self.task_token_format = task_token_format + self.is_sample_tokens = getattr(config, "sample_tokens", False) + self.aux_tokens = config.aux_mode + self.token_order = config.aux_mode.split("-") + if self.num_task_tokens > 0: + if "depth" in config.aux_mode: + assert config.image_depth["num_tokens"] % self.num_task_tokens == 0, f"{config.image_depth['num_tokens']} must be divisible by {self.num_task_tokens}" + self.special_depth_tokens = nn.Parameter( + torch.randn( + config.image_depth["num_tokens"], config.hidden_size + ) + ) + if "seg" in config.aux_mode: + assert config.image_seg["num_tokens"] % self.num_task_tokens == 0, f"{config.image_seg['num_tokens']} must be divisible by {self.num_task_tokens}" + self.special_seg_tokens = nn.Parameter( + torch.randn( + config.image_seg["num_tokens"], config.hidden_size + ) + ) + if "gen" in config.aux_mode: + self.special_gen_tokens = nn.Parameter( + torch.randn( + config.num_task_tokens, config.hidden_size + ) + ) + + def initialize_vision_modules(self, model_args, fsdp=None): + vision_tower = model_args.vision_tower + mm_vision_select_layer = model_args.mm_vision_select_layer + mm_vision_select_feature = model_args.mm_vision_select_feature + pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter + mm_patch_merge_type = model_args.mm_patch_merge_type + + self.config.mm_vision_tower = vision_tower + + if self.get_vision_tower() is None: + vision_tower = build_vision_tower(model_args) + + if fsdp is not None and len(fsdp) > 0: + self.vision_tower = [vision_tower] + else: + self.vision_tower = vision_tower + else: + if fsdp is not None and len(fsdp) > 0: + vision_tower = 
self.vision_tower[0] + else: + vision_tower = self.vision_tower + vision_tower.load_model() + + self.config.use_mm_proj = True + self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear') + self.config.mm_hidden_size = vision_tower.hidden_size + self.config.mm_vision_select_layer = mm_vision_select_layer + self.config.mm_vision_select_feature = mm_vision_select_feature + self.config.mm_patch_merge_type = mm_patch_merge_type + + if getattr(self, 'mm_projector', None) is None: + self.mm_projector = build_vision_projector(self.config) + + if 'unpad' in mm_patch_merge_type: + embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype)) + self.image_newline = nn.Parameter( + torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std + ) + else: + # In case it is frozen by LoRA + for p in self.mm_projector.parameters(): + p.requires_grad = True + + if pretrain_mm_mlp_adapter is not None: + mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') + def get_w(weights, keyword): + return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} + + self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) + + +def unpad_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format. + original_size (tuple): The original size of PIL image (width, height). + + Returns: + torch.Tensor: The unpadded image tensor. + """ + original_width, original_height = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding:current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding:current_width - padding] + + return unpadded_tensor + + +class OlaLlavaMetaForCausalLM(ABC): + + @abstractmethod + def get_model(self): + pass + + def get_vision_tower(self): + return self.get_model().get_vision_tower() + + def encode_images(self, images): + image_features = self.get_model().get_vision_tower()(images).to(images.dtype).to(images.device) + image_features = self.get_model().mm_projector(image_features) + return image_features + + @property + def depth_tokens(self): + return self.get_model().get_special_tokens()[0] + + @property + def seg_tokens(self): + return self.get_model().get_special_tokens()[1] + + @property + def gen_tokens(self): + return self.get_model().get_special_tokens()[2] + + @property + def num_task_tokens(self): + return self.get_model().num_task_tokens + + @property + def task_token_format(self): + return self.get_model().task_token_format + + @property + def aux_tokens(self): + return self.get_model().aux_tokens + + @property + def token_order(self): + return self.get_model().token_order + + @property + def is_sample_tokens(self): + return self.get_model().is_sample_tokens + + def append_special_tokens(self, cur_new_input_embeds, cur_image_features, cur_labels, cur_new_labels): + def _get_tokens(self, tokens): + tk_weights = tokens.view(self.num_task_tokens, 
tokens.shape[0] // self.num_task_tokens, tokens.shape[1]) + tk_weights = tk_weights.mean(dim=1) + return tk_weights + + token_types = { + "depth": self.depth_tokens, + "seg": self.seg_tokens, + "gen": self.gen_tokens + } + + for token_type in self.token_order: + tk_weights = None + + if token_type in self.aux_tokens: + if token_type == "depth" and token_types["depth"] is not None: + tk_weights = _get_tokens(self, token_types["depth"]) + elif token_type == "seg" and token_types["seg"] is not None: + tk_weights = _get_tokens(self, token_types["seg"]) + elif token_type == "gen" and token_types["gen"] is not None: + tk_weights = ( + self.get_model().embed_tokens(self.gen_tokens.to(cur_image_features.device)) + if self.task_token_format == "text" else + self.gen_tokens + ) + + if tk_weights is not None: + cur_new_input_embeds.append(tk_weights) + cur_new_labels.append(torch.full((tk_weights.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + return cur_new_input_embeds, cur_new_labels + + def prepare_inputs_labels_for_multimodal( + self, input_ids, position_ids, attention_mask, past_key_values, labels, + images, image_sizes=None + ): + vision_tower = self.get_vision_tower() + if vision_tower is None or images is None or input_ids.shape[1] == 1: + return input_ids, position_ids, attention_mask, past_key_values, None, labels + + if type(images) is list or images.ndim == 5: + if type(images) is list: + images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images] + concat_images = torch.cat([image for image in images], dim=0) + image_features = self.encode_images(concat_images) + split_sizes = [image.shape[0] for image in images] + image_features = torch.split(image_features, split_sizes, dim=0) + mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat') + image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square') + if mm_patch_merge_type == 'flat': + image_features = [x.flatten(0, 1) for x in image_features] + elif mm_patch_merge_type.startswith('spatial'): + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = self.get_vision_tower().num_patches_per_side + assert height * width == base_image_feature.shape[0] + if image_aspect_ratio == 'anyres': + num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size) + image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) + else: + raise NotImplementedError + if 'unpad' in mm_patch_merge_type: + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + image_feature = torch.cat(( + image_feature, + self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device) + ), dim=-1) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + else: + image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() + image_feature = image_feature.flatten(0, 3) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + else: + image_feature = image_feature[0] + if 'unpad' in mm_patch_merge_type: + image_feature = torch.cat(( + image_feature, + self.model.image_newline[None].to(image_feature.device) + ), dim=0) + 
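
`unpad_image()` above crops away the rows or columns that anyres padding introduced before the features were computed, so the merged grid matches the original aspect ratio. A small worked example with a 24x24 patch grid and a 640x480 source image makes the arithmetic concrete; it assumes the `ola_vlm` package and its dependencies are importable, and the feature tensor is synthetic:

```python
# Worked example for unpad_image() from ola_vlm/model/ola_arch.py
# (assumes the ola_vlm package and its dependencies are importable).
import torch
from ola_vlm.model.ola_arch import unpad_image

# 24x24 patch grid (e.g. one 336px CLIP ViT-L/14 tile), channels first.
features = torch.randn(1024, 24, 24)

# Original image is 640x480 (width, height): wider than the square grid, so
# (24 - int(480 * 24 / 640)) // 2 = 3 padded rows are cropped from top and bottom.
unpadded = unpad_image(features, original_size=(640, 480))
print(unpadded.shape)   # torch.Size([1024, 18, 24])
```
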
new_image_features.append(image_feature) + image_features = new_image_features + else: + raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}") + else: + image_features = self.encode_images(images) + + # TODO: image start / end is not implemented here to support pretraining. + if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): + raise NotImplementedError + + # Let's just add dummy tensors if they do not exist, + # it is a headache to deal with None all the time. + # But it is not ideal, and if you have a better idea, + # please open an issue / submit a PR, thanks. + _labels = labels + _position_ids = position_ids + _attention_mask = attention_mask + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + else: + attention_mask = attention_mask.bool() + if position_ids is None: + position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + if labels is None: + labels = torch.full_like(input_ids, IGNORE_INDEX) + + # remove the padding using attention_mask -- FIXME + _input_ids = input_ids + input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] + labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] + + new_input_embeds = [] + new_labels = [] + cur_image_idx = 0 + + + for batch_idx, cur_input_ids in enumerate(input_ids): + num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + + if num_images == 0: + cur_image_features = image_features[cur_image_idx] + cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) + cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0) + new_input_embeds.append(cur_input_embeds) + new_labels.append(labels[batch_idx]) + cur_image_idx += 1 + continue + + image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] + cur_input_ids_noim = [] + cur_labels = labels[batch_idx] + cur_labels_noim = [] + for i in range(len(image_token_indices) - 1): + cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]]) + cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]]) + split_sizes = [x.shape[0] for x in cur_labels_noim] + cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim)) + cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) + cur_new_input_embeds = [] + cur_new_labels = [] + + for i in range(num_images + 1): + cur_new_input_embeds.append(cur_input_embeds_no_im[i]) + cur_new_labels.append(cur_labels_noim[i]) + if i < num_images: + cur_image_features = image_features[cur_image_idx] + cur_image_idx += 1 + + cur_new_input_embeds.append(cur_image_features) + cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + + cur_new_input_embeds, cur_new_labels = self.append_special_tokens( + cur_new_input_embeds, cur_image_features, + cur_labels, cur_new_labels, + ) + + cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] + + cur_new_input_embeds = torch.cat(cur_new_input_embeds) + cur_new_labels = torch.cat(cur_new_labels) + + new_input_embeds.append(cur_new_input_embeds) + new_labels.append(cur_new_labels) + + # Truncate sequences to max length as image embeddings can make the sequence longer + 
tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None) + if tokenizer_model_max_length is not None: + new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] + new_labels = [x[:tokenizer_model_max_length] for x in new_labels] + + # Combine them + max_len = max(x.shape[0] for x in new_input_embeds) + batch_size = len(new_input_embeds) + + new_input_embeds_padded = [] + new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device) + attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) + position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) + + for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): + cur_len = cur_new_embed.shape[0] + if getattr(self.config, 'tokenizer_padding_side', 'right') == "left": + new_input_embeds_padded.append(torch.cat(( + torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), + cur_new_embed + ), dim=0)) + if cur_len > 0: + new_labels_padded[i, -cur_len:] = cur_new_labels + attention_mask[i, -cur_len:] = True + position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + else: + new_input_embeds_padded.append(torch.cat(( + cur_new_embed, + torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device) + ), dim=0)) + if cur_len > 0: + new_labels_padded[i, :cur_len] = cur_new_labels + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + + new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) + + if _labels is None: + new_labels = None + else: + new_labels = new_labels_padded + + if _attention_mask is None: + attention_mask = None + else: + attention_mask = attention_mask.to(dtype=_attention_mask.dtype) + + if _position_ids is None: + position_ids = None + + return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels + + def initialize_vision_tokenizer(self, model_args, tokenizer): + if model_args.mm_use_im_patch_token: + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if model_args.mm_use_im_start_end: + num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = self.get_input_embeddings().weight.data + output_embeddings = self.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + if model_args.tune_mm_mlp_adapter: + for p in self.get_input_embeddings().parameters(): + p.requires_grad = True + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False + + if model_args.pretrain_mm_mlp_adapter: + mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') + embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] + assert num_new_tokens == 2 + if input_embeddings.shape == 
embed_tokens_weight.shape: + input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] + elif embed_tokens_weight.shape[0] == num_new_tokens: + input_embeddings[-num_new_tokens:] = embed_tokens_weight + else: + raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.") + elif model_args.mm_use_im_patch_token: + if model_args.tune_mm_mlp_adapter: + for p in self.get_input_embeddings().parameters(): + p.requires_grad = False + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False + diff --git a/ola_vlm/model/utils.py b/ola_vlm/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2563f89c6cedf5e73508afec8f9979105df9b745 --- /dev/null +++ b/ola_vlm/model/utils.py @@ -0,0 +1,20 @@ +from transformers import AutoConfig + + +def auto_upgrade(config): + cfg = AutoConfig.from_pretrained(config) + if 'llava' in config and 'llava' not in cfg.model_type: + assert cfg.model_type == 'llama' + print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") + print("You must upgrade the checkpoint to the new code base (this can be done automatically).") + confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") + if confirm.lower() in ["y", "yes"]: + print("Upgrading checkpoint...") + assert len(cfg.architectures) == 1 + setattr(cfg.__class__, "model_type", "llava") + cfg.architectures[0] = 'LlavaLlamaForCausalLM' + cfg.save_pretrained(config) + print("Checkpoint upgraded.") + else: + print("Checkpoint upgrade aborted.") + exit(1) diff --git a/ola_vlm/ola_utils.py b/ola_vlm/ola_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..74f68ef966bbbafc9ebac534450f3f4f12b527e6 --- /dev/null +++ b/ola_vlm/ola_utils.py @@ -0,0 +1,201 @@ +from typing import List, Optional + +import torch +import torch.nn.functional as F +import numpy as np + + +import torch.distributed as dist +from PIL import Image, ImageDraw +import matplotlib.pyplot as plt +import diffdist.functional as diff_dist + +from typing import List, Optional +from torchvision.ops import masks_to_boxes +import io + + +def visualize_oneformer_masks_on_image( + image: torch.Tensor, + masks: List[torch.Tensor], + classes: List[str], + save_path: Optional[str] = None, +): + """ + inputs: + image: torch.Tensor of shape (3, H, W) + masks: List[torch.Tensor] of len NUM_MASKS + classes: List[str] of len NUM_MASKS + save_path: Optional[str] path to save the visualization + returns: + pil_image: PIL.Image with masks overlayed on the image + """ + + def _show_mask(mask, class_name, ax, random_color=False): + mask = mask.cpu() + box = masks_to_boxes(mask.unsqueeze(0))[0] + x0, y0, x1, y1 = box + x = (x0 + x1) / 2 + y = (y0 + y1) / 2 + if random_color: + color = np.concatenate( + [np.random.random(3), np.array([0.6])], axis=0 + ) + else: + color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6]) + h, w = mask.shape[-2:] + mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + ax.imshow(mask_image) + ax.text(x, y, class_name, fontsize="x-small") + + # Create a matplotlib figure + fig, ax = plt.subplots() + ax.imshow(np.array(image)) # Convert to HWC format for plt + ax.set_autoscale_on(False) + for mask, class_name in zip(masks, classes): + _show_mask(mask, class_name, ax=ax, random_color=True) + plt.axis("off") + plt.tight_layout() + + # Save figure to a BytesIO object and convert to PIL.Image + 
buf = io.BytesIO() + plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0) + buf.seek(0) + pil_image = Image.open(buf) + + # Optionally save the PIL image + if save_path is not None: + pil_image.save(save_path) + + plt.close(fig) + return pil_image + +def oneformer_prepare_panoptic_instance_prediction( + segmentation: torch.Tensor, segments_info: dict, oneformer +): + masks = [] + classes = [] + + for segment in segments_info: + id = segment["id"] + label_id = segment["label_id"] + label = oneformer.config.id2label[label_id] + mask = segmentation == id + masks.append(mask.float()) + classes.append(label) + + return masks, classes + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + +def dist_collect(x): + """ collect all tensor from all GPUs + args: + x: shape (mini_batch, ...) + returns: + shape (mini_batch * num_gpu, ...) + """ + x = x.contiguous() + out_list = [torch.zeros_like(x, device=x.device, dtype=x.dtype).contiguous() for _ in range(dist.get_world_size())] + out_list = diff_dist.all_gather(out_list, x) + return torch.cat(out_list, dim=0).contiguous() + +def calculate_contrastive_loss(preds, targets, logit_scale): + batch_size = preds.shape[0] + if is_dist_avail_and_initialized(): + labels = torch.arange(batch_size, dtype=torch.long, device=preds.device) + batch_size * dist.get_rank() + else: + labels = torch.arange(batch_size, dtype=torch.long, device=preds.device) + + preds = F.normalize(preds.flatten(1), dim=-1) + targets = F.normalize(targets.flatten(1), dim=-1) + + if is_dist_avail_and_initialized(): + logits_per_img = preds @ dist_collect(targets).t() + else: + logits_per_img = preds @ targets.t() + + logit_scale = torch.clamp(logit_scale.exp(), max=100) + loss_contrastive = F.cross_entropy(logits_per_img * logit_scale, labels, reduction="none") + return loss_contrastive + +def silog_loss(depth_est, depth_gt, variance_focus=0.5): + mask = (depth_gt > 0).detach() + if mask.sum() == 0: + return torch.tensor(0.0).to(depth_est) + d = torch.log(depth_est[mask]) - torch.log(depth_gt[mask]) + loss = torch.sqrt(torch.pow(d, 2).mean() - + variance_focus * torch.pow(d.mean(), 2)) * 1.0 + return loss + +def make_grid(images, pil_images): + # Assuming each image is the same size + + new_images = [] + new_captions = [] + for image, pil_image in zip(images, pil_images): + new_images.append(image) + pil_image = pil_image.resize((image.size[0], image.size[1])) + new_images.append(pil_image) + new_captions.append("Predicted") + new_captions.append("GT") + + images = new_images + captions = new_captions + + width, height = images[0].size + font_size = 14 + caption_height = font_size + 10 + + # Calculate the size of the final image + images_per_row = min(len(images), 16) # Round up for odd number of images + row_count = (len(images) + 1) // images_per_row + total_width = width * images_per_row + total_height = (height + caption_height) * row_count + + # Create a new blank image + new_image = Image.new("RGB", (total_width, total_height), "white") + + draw = ImageDraw.Draw(new_image) + + for i, (image, caption) in enumerate(zip(images, captions)): + row = i // images_per_row + col = i % images_per_row + x_offset = col * width + y_offset = row * (height + caption_height) + + new_image.paste(image, (x_offset, y_offset)) + text_position = (x_offset + 10, y_offset + height) + draw.text(text_position, caption, fill="red", font_size=font_size) + + return new_image + +def visualize_masks(anns, 
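
`silog_loss()` above implements the scale-invariant log (SILog) error commonly used for depth estimation: with d = log(pred) − log(gt) over valid (gt > 0) pixels, it returns sqrt(mean(d²) − λ·mean(d)²) with λ = `variance_focus` = 0.5. The snippet below recomputes the same quantity on toy tensors so the masking and the formula are explicit; the values are invented for illustration and independent of any training data:

```python
# Recomputing silog_loss() from ola_vlm/ola_utils.py on toy tensors,
# to make the masking and the scale-invariant log formula explicit.
import torch

depth_est = torch.tensor([1.0, 2.0, 4.0, 3.0])
depth_gt  = torch.tensor([1.0, 2.5, 0.0, 3.0])   # zero entries are invalid and masked out

variance_focus = 0.5
mask = depth_gt > 0                               # only supervise valid ground-truth pixels
d = torch.log(depth_est[mask]) - torch.log(depth_gt[mask])
loss = torch.sqrt((d ** 2).mean() - variance_focus * d.mean() ** 2)

print(loss)  # tensor(0.1176), approximately
```
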
rgb_image): + if len(anns) == 0: + return rgb_image + + sorted_anns = sorted(anns, key=lambda x: x['area'], reverse=True) + ax = plt.gca() + ax.set_autoscale_on(False) + + img_array = np.array(rgb_image) + masked_image = np.ones(img_array.shape) + + for ann in sorted_anns: + m = ann['segmentation'] + color_mask = np.random.random(3) + + masked_image[m] = (color_mask * 255).astype(np.uint8) + + img_array = img_array * 0.35 + masked_image * 0.65 + + img_array = img_array.astype(np.uint8) + ax.imshow(img_array) + overlayed_img = Image.fromarray(img_array) + + return overlayed_img \ No newline at end of file diff --git a/ola_vlm/train/llama_flash_attn_monkey_patch.py b/ola_vlm/train/llama_flash_attn_monkey_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..31db2eff8d1c4b3ae645583dfc5e156e818b6f1c --- /dev/null +++ b/ola_vlm/train/llama_flash_attn_monkey_patch.py @@ -0,0 +1,115 @@ +from typing import Optional, Tuple +import warnings + +import torch + +import transformers +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv + +try: + from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func +except ImportError: + from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func +from flash_attn.bert_padding import unpad_input, pad_input + + +def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + warnings.warn( + "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 
+ ) + + bsz, q_len, _ = hidden_states.size() + + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_key_value_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_key_value_heads, self.head_dim) + .transpose(1, 2) + ) # shape: (b, num_heads, s, head_dim) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + if past_key_value is not None: + # reuse k, v + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # Transform the data into the format required by flash attention + qkv = torch.stack([query_states, key_states, value_states], dim=2) + qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] + key_padding_mask = attention_mask + + if key_padding_mask is None: + qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) + cu_q_lens = torch.arange( + 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device + ) + max_s = q_len + output = flash_attn_unpadded_qkvpacked_func( + qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True + ) + output = output.view(bsz, q_len, -1) + else: + qkv = qkv.reshape(bsz, q_len, -1) + qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) + qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) + output_unpad = flash_attn_unpadded_qkvpacked_func( + qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True + ) + output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) + output = pad_input(output_unpad, indices, bsz, q_len) + + return self.o_proj(output), None, past_key_value + + +# Disable the transformation of the attention mask in LlamaModel as the flash attention +# requires the attention mask to be the same as the key_padding_mask +def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length +): + # [bsz, seq_len] + return attention_mask + + +def replace_llama_attn_with_flash_attn(): + cuda_major, cuda_minor = torch.cuda.get_device_capability() + if cuda_major < 8: + warnings.warn( + "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
+ "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" + ) + transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( + _prepare_decoder_attention_mask + ) + transformers.models.llama.modeling_llama.LlamaAttention.forward = forward diff --git a/ola_vlm/train/llama_xformers_attn_monkey_patch.py b/ola_vlm/train/llama_xformers_attn_monkey_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..f8351e41ccd4a64dca237bd8f8be0702b23989dc --- /dev/null +++ b/ola_vlm/train/llama_xformers_attn_monkey_patch.py @@ -0,0 +1,129 @@ +""" +Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments +""" + +import logging +import math +from typing import Optional, Tuple + +import torch +import transformers.models.llama.modeling_llama +from torch import nn + +try: + import xformers.ops +except ImportError: + logging.error("xformers not found! Please install it before trying to use it.") + + +def replace_llama_attn_with_xformers_attn(): + transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward + + +def xformers_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # pylint: disable=duplicate-code + bsz, q_len, _ = hidden_states.size() + + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + ( + query_states, + key_states, + ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # We only apply xformers optimizations if we don't need to output the whole attention matrix + if not output_attentions: + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. + # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. 
+ if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: + # input and output should be of form (bsz, q_len, num_heads, head_dim) + attn_output = xformers.ops.memory_efficient_attention( + query_states, key_states, value_states, attn_bias=None + ) + else: + # input and output should be of form (bsz, q_len, num_heads, head_dim) + attn_output = xformers.ops.memory_efficient_attention( + query_states, + key_states, + value_states, + attn_bias=xformers.ops.LowerTriangularMask(), + ) + attn_weights = None + else: + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) + ) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value diff --git a/ola_vlm/train/llava_trainer.py b/ola_vlm/train/llava_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..b8490cf7e967982cb206822f10108fb05b98c2cd --- /dev/null +++ b/ola_vlm/train/llava_trainer.py @@ -0,0 +1,1022 @@ +import os +import torch +import torch.nn as nn + +from torch.utils.data import Sampler + +from transformers import Trainer +from transformers.trainer import ( + is_sagemaker_mp_enabled, + get_parameter_names, + has_length, + ALL_LAYERNORM_LAYERS, + logger, + _is_peft_model, +) +from typing import List, Optional + +import math +import os +import shutil +import sys +import time +from typing import List, Optional +TRAINER_STATE_NAME = "trainer_state.json" + +# Integrations must be imported before ML frameworks: +# isort: off +from transformers.integrations import ( + hp_params, +) + +# isort: on + +import torch +import torch.distributed as dist +from packaging import version +from torch import nn +from torch.utils.data import RandomSampler + +from transformers import __version__ +from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint +from transformers.pytorch_utils import ( + ALL_LAYERNORM_LAYERS, +) +from transformers.debug_utils import DebugOption, DebugUnderflowOverflow +from transformers.trainer_callback import ( + DefaultFlowCallback, + ExportableState, + ProgressCallback, + TrainerState, +) +from transformers.trainer_pt_utils import ( + LengthGroupedSampler, + get_model_param_count, + get_parameter_names, +) +from transformers.trainer_utils import ( + HPSearchBackend, + TrainOutput, + has_length, + speed_metrics, +) +from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments 
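+# The imports below mirror transformers.trainer internals: LLaVATrainer keeps adapted
+# copies of Trainer's accelerator setup, training step, and inner training loop, so the
+# same helpers (callbacks, DeepSpeed/accelerate glue, TrainerState, speed metrics) are
+# needed here as well.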
+from transformers.utils import ( + is_accelerate_available, + is_apex_available, + is_datasets_available, + is_sagemaker_mp_enabled, + is_torch_xla_available, +) + +DEFAULT_CALLBACKS = [DefaultFlowCallback] +DEFAULT_PROGRESS_CALLBACK = ProgressCallback + + +if is_apex_available(): + from apex import amp + +if is_datasets_available(): + import datasets + +IS_XLA_FSDPV2_POST_2_2 = False + +IS_SAGEMAKER_MP_POST_1_10 = False + + +if is_accelerate_available(): + from accelerate import Accelerator, skip_first_batches + from accelerate import __version__ as accelerate_version + from accelerate.utils import ( + DistributedType, + ) + + DATA_SAMPLERS = [RandomSampler] + if version.parse(accelerate_version) > version.parse("0.23.0"): + from accelerate.data_loader import SeedableRandomSampler + + DATA_SAMPLERS += [SeedableRandomSampler] + + + +def maybe_zero_3(param, ignore_status=False, name=None): + from deepspeed import zero + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + print(name, 'no ignore status') + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): + to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} + to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()} + return to_return + + +def split_to_even_chunks(indices, lengths, num_chunks): + """ + Split a list of indices into `chunks` chunks of roughly equal lengths. + """ + + if len(indices) % num_chunks != 0: + return [indices[i::num_chunks] for i in range(num_chunks)] + + num_indices_per_chunk = len(indices) // num_chunks + + chunks = [[] for _ in range(num_chunks)] + chunks_lengths = [0 for _ in range(num_chunks)] + for index in indices: + shortest_chunk = chunks_lengths.index(min(chunks_lengths)) + chunks[shortest_chunk].append(index) + chunks_lengths[shortest_chunk] += lengths[index] + if len(chunks[shortest_chunk]) == num_indices_per_chunk: + chunks_lengths[shortest_chunk] = float("inf") + + return chunks + + +def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None): + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + assert all(l != 0 for l in lengths), "Should not have zero length." 
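+    # Sign encodes modality: positive lengths are multimodal (image + text) samples and
+    # negative lengths are language-only samples, with abs(l) the token length. For example,
+    # lengths = [120, -80, 45, -200] groups indices {0, 2} and {1, 3} separately, so
+    # megabatches are formed within a single modality (leftovers from both are merged into
+    # one final batch).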
+ if all(l > 0 for l in lengths) or all(l < 0 for l in lengths): + # all samples are in the same modality + return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator) + mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0]) + lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0]) + + mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)] + lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)] + megabatch_size = world_size * batch_size + mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)] + lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)] + + last_mm = mm_megabatches[-1] + last_lang = lang_megabatches[-1] + additional_batch = last_mm + last_lang + megabatches = mm_megabatches[:-1] + lang_megabatches[:-1] + megabatch_indices = torch.randperm(len(megabatches), generator=generator) + megabatches = [megabatches[i] for i in megabatch_indices] + + if len(additional_batch) > 0: + megabatches.append(sorted(additional_batch)) + + return [i for megabatch in megabatches for i in megabatch] + + +def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + indices = torch.randperm(len(lengths), generator=generator) + megabatch_size = world_size * batch_size + megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] + megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] + megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches] + + return [i for megabatch in megabatches for batch in megabatch for i in batch] + + +class LengthGroupedSampler(Sampler): + r""" + Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while + keeping a bit of randomness. 
+ """ + + def __init__( + self, + batch_size: int, + world_size: int, + lengths: Optional[List[int]] = None, + generator=None, + group_by_modality: bool = False, + ): + if lengths is None: + raise ValueError("Lengths must be provided.") + + self.batch_size = batch_size + self.world_size = world_size + self.lengths = lengths + self.generator = generator + self.group_by_modality = group_by_modality + + def __len__(self): + return len(self.lengths) + + def __iter__(self): + if self.group_by_modality: + indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) + else: + indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) + return iter(indices) + + +class LLaVATrainer(Trainer): + + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + if self.train_dataset is None or not has_length(self.train_dataset): + return None + + if self.args.group_by_modality_length: + lengths = self.train_dataset.modality_lengths + return LengthGroupedSampler( + self.args.train_batch_size, + world_size=self.args.world_size * self.args.gradient_accumulation_steps, + lengths=lengths, + group_by_modality=True, + ) + else: + return super()._get_train_sampler() + + + def ocreate_accelerator_and_postprocess(self): + grad_acc_kwargs = {} + if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None: + grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs + + # check if num_steps is attempted to be passed in gradient_accumulation_kwargs + if "num_steps" in grad_acc_kwargs and self.args.gradient_accumulation_steps > 1: + # raise because we do not know which setting is intended. + raise ValueError( + "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`" + "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`." + ) + elif "num_steps" not in grad_acc_kwargs: + # take the gradient_accumulation_steps setting from TrainingArguments. + grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps + + grad_acc_kwargs["sync_with_dataloader"] = False + + from accelerate.utils import ( + GradientAccumulationPlugin, + ) + gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs) + + accelerator_config = self.args.accelerator_config.to_dict() + + if is_accelerate_available("0.28.0"): + from accelerate.utils import DataLoaderConfiguration + + if is_accelerate_available("0.28.0"): + dataloader_config = DataLoaderConfiguration( + split_batches=accelerator_config.pop("split_batches"), + dispatch_batches=accelerator_config.pop("dispatch_batches"), + even_batches=accelerator_config.pop("even_batches"), + use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"), + ) + non_blocking = accelerator_config.pop("non_blocking") + if not is_accelerate_available("0.30.0"): + if non_blocking: + raise ImportError( + "`non_blocking` is only supported in accelerate v0.30.0 and above. Please upgrade accelerate to use this feature." + ) + else: + if non_blocking and not self.args.dataloader_pin_memory: + logger.warning( + "`non_blocking` is enabled but `dataloader_pin_memory` is not. For the best performance, it's recommended to enable both." 
+ ) + dataloader_config.non_blocking = non_blocking + # this would have been updated above, no need for it anymore + accelerator_config.pop("gradient_accumulation_kwargs") + + args = { + "deepspeed_plugin": self.args.deepspeed_plugin, + "gradient_accumulation_plugin": gradient_accumulation_plugin, + } + if is_accelerate_available("0.28.0"): + args["dataloader_config"] = dataloader_config + else: + args.update(accelerator_config) + + # create accelerator object + from .acc import Accelerator + self.accelerator = Accelerator(**args) + # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag + self.gather_function = self.accelerator.gather_for_metrics + + # deepspeed and accelerate flags covering both trainer args and accelerate launcher + self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None + self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None + + # post accelerator creation setup + if self.is_fsdp_enabled: + fsdp_plugin = self.accelerator.state.fsdp_plugin + fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get( + "limit_all_gathers", fsdp_plugin.limit_all_gathers + ) + if is_accelerate_available("0.23.0"): + fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get( + "activation_checkpointing", fsdp_plugin.activation_checkpointing + ) + if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing: + raise ValueError( + "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg " + "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic " + "when using FSDP." + ) + + if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None: + self.propagate_args_to_deepspeed() + + # `save_only_model` can't be used with DeepSpeed/FSDP along with `load_best_model_at_end` + if ( + self.args.save_only_model + and (self.is_deepspeed_enabled or self.is_fsdp_enabled) + and self.args.load_best_model_at_end + ): + wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP" + raise ValueError(f"{wrapper} can't be used with `save_only_model` along with `load_best_model_at_end`.") + + # `auto_find_batch_size` isn't yet supported with DeepSpeed/FSDP + if (self.is_deepspeed_enabled or self.is_fsdp_enabled) and self.args.auto_find_batch_size: + wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP" + raise NotImplementedError(f"`{wrapper}` doesn't support `auto_find_batch_size`.") + + + def otraining_step(self, model: nn.Module, inputs) -> torch.Tensor: + """ + Perform a training step on a batch of inputs. + + Subclass and override to inject custom behavior. + + Args: + model (`nn.Module`): + The model to train. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + + Return: + `torch.Tensor`: The tensor with training loss on this batch. 
+ """ + model.train() + inputs = self._prepare_inputs(inputs) + + from icecream import ic + ic("inputs_prepared") + + with self.compute_loss_context_manager(): + loss = self.compute_loss(model, inputs) + + from icecream import ic + ic("loss_computed") + + del inputs + torch.cuda.empty_cache() + + if self.args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel training + + if self.use_apex: + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + self.accelerator.backward(loss) + + return loss.detach() / self.args.gradient_accumulation_steps + + def o_inner_training_loop( + self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None + ): + from icecream import ic + ic("INNER TRAINING") + self.accelerator.free_memory() + self._train_batch_size = batch_size + if self.args.auto_find_batch_size: + if self.state.train_batch_size != self._train_batch_size: + from accelerate.utils import release_memory + + (self.model_wrapped,) = release_memory(self.model_wrapped) + self.model_wrapped = self.model + + # Check for DeepSpeed *after* the intial pass and modify the config + if self.is_deepspeed_enabled: + # Temporarily unset `self.args.train_batch_size` + original_bs = self.args.per_device_train_batch_size + self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu) + self.propagate_args_to_deepspeed(True) + self.args.per_device_train_batch_size = original_bs + self.state.train_batch_size = self._train_batch_size + logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") + # Data loader and number of training steps + train_dataloader = self.get_train_dataloader() + if self.is_fsdp_xla_v2_enabled: + train_dataloader = tpu_spmd_dataloader(train_dataloader) + + # Setting up training control variables: + # number of training epochs: num_train_epochs + # number of training steps per epoch: num_update_steps_per_epoch + # total number of training steps to execute: max_steps + total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size + + len_dataloader = None + num_train_tokens = None + if has_length(train_dataloader): + len_dataloader = len(train_dataloader) + num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps + num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) + num_examples = self.num_examples(train_dataloader) + if args.max_steps > 0: + max_steps = args.max_steps + num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( + args.max_steps % num_update_steps_per_epoch > 0 + ) + # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's + # the best we can do. + num_train_samples = args.max_steps * total_train_batch_size + if args.include_tokens_per_second: + num_train_tokens = ( + self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps + ) + else: + max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) + num_train_epochs = math.ceil(args.num_train_epochs) + num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs + if args.include_tokens_per_second: + num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs + elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size + max_steps = args.max_steps + # Setting a very large number of epochs so we go as many times as necessary over the iterator. 
+ num_train_epochs = sys.maxsize + num_update_steps_per_epoch = max_steps + num_examples = total_train_batch_size * args.max_steps + num_train_samples = args.max_steps * total_train_batch_size + if args.include_tokens_per_second: + num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps + else: + raise ValueError( + "args.max_steps must be set to a positive value if dataloader does not have a length, was" + f" {args.max_steps}" + ) + + if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: + if self.args.n_gpu > 1: + # nn.DataParallel(model) replicates the model, creating new variables and module + # references registered here no longer work on other gpus, breaking the module + raise ValueError( + "Currently --debug underflow_overflow is not supported under DP. Please use DDP" + " (torchrun or torch.distributed.launch (deprecated))." + ) + else: + debug_overflow = DebugUnderflowOverflow(self.model) # noqa + + delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled + + # We need to reset the scheduler, as its parameters may be different on subsequent calls + if self._created_lr_scheduler: + self.lr_scheduler = None + self._created_lr_scheduler = False + + if self.is_deepspeed_enabled: + self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps) + + if not delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + + from icecream import ic + ic("STATE") + self.state = TrainerState( + stateful_callbacks=[ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ] + ) + self.state.is_hyper_param_search = trial is not None + self.state.train_batch_size = self._train_batch_size + + # Compute absolute values for logging, eval, and save if given as ratio + if args.logging_steps is not None: + if args.logging_steps < 1: + self.state.logging_steps = math.ceil(max_steps * args.logging_steps) + else: + self.state.logging_steps = args.logging_steps + if args.eval_steps is not None: + if args.eval_steps < 1: + self.state.eval_steps = math.ceil(max_steps * args.eval_steps) + else: + self.state.eval_steps = args.eval_steps + if args.save_steps is not None: + if args.save_steps < 1: + self.state.save_steps = math.ceil(max_steps * args.save_steps) + else: + self.state.save_steps = args.save_steps + + # Activate gradient checkpointing if needed + if args.gradient_checkpointing: + if args.gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {} + else: + gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs + + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) + + model = self._wrap_model(self.model_wrapped) + + # as the model is wrapped, don't use `accelerator.prepare` + # this is for unhandled cases such as + # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX + use_accelerator_prepare = True if model is self.model else False + + if delay_optimizer_creation: + if use_accelerator_prepare: + self._fsdp_qlora_plugin_updates() + self.model = self.accelerator.prepare(self.model) + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + # prepare using `accelerator` prepare + if use_accelerator_prepare: + self.model.train() + if hasattr(self.lr_scheduler, "step"): + if self.use_apex: + model = self.accelerator.prepare(self.model) + else: + model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) + 
else: + # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config. + model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( + self.model, self.optimizer, self.lr_scheduler + ) + if self.is_fsdp_enabled: + self.model = self.model_wrapped = model + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + + # ckpt loading + if resume_from_checkpoint is not None: + if self.is_deepspeed_enabled: + deepspeed_load_checkpoint( + self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model) + ) + elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled: + self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped) + + # Check if saved optimizer or scheduler states exist + self._load_optimizer_and_scheduler(resume_from_checkpoint) + + # important: at this point: + # self.model is the Transformers Model + # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), + # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc. + + # Train! + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_examples:,}") + logger.info(f" Num Epochs = {num_train_epochs:,}") + logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") + if self.args.per_device_train_batch_size != self._train_batch_size: + logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_steps:,}") + logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") + + self.state.epoch = 0 + start_time = time.time() + epochs_trained = 0 + steps_trained_in_current_epoch = 0 + steps_trained_progress_bar = None + + # Check if continuing training from a checkpoint + if resume_from_checkpoint is not None and os.path.isfile( + os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) + ): + self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) + self.compare_trainer_and_checkpoint_args(self.args, self.state) + self._load_callback_state() + epochs_trained = self.state.global_step // num_update_steps_per_epoch + if not args.ignore_data_skip: + steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) + steps_trained_in_current_epoch *= args.gradient_accumulation_steps + else: + steps_trained_in_current_epoch = 0 + + logger.info(" Continuing training from checkpoint, will skip to saved global_step") + logger.info(f" Continuing training from epoch {epochs_trained}") + logger.info(f" Continuing training from global step {self.state.global_step}") + if not args.ignore_data_skip: + logger.info( + f" Will skip the first {epochs_trained} epochs then the first" + f" {steps_trained_in_current_epoch} batches in the first epoch." 
+ ) + + # Update the references + self.callback_handler.model = self.model + self.callback_handler.optimizer = self.optimizer + self.callback_handler.lr_scheduler = self.lr_scheduler + self.callback_handler.train_dataloader = train_dataloader + if self.hp_name is not None and self._trial is not None: + # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial + # parameter to Train when using DDP. + self.state.trial_name = self.hp_name(self._trial) + if trial is not None: + assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial + self.state.trial_params = hp_params(assignments) + else: + self.state.trial_params = None + # This should be the same if the state has been saved but in case the training arguments changed, it's safer + # to set this after the load. + self.state.max_steps = max_steps + self.state.num_train_epochs = num_train_epochs + self.state.is_local_process_zero = self.is_local_process_zero() + self.state.is_world_process_zero = self.is_world_process_zero() + + # tr_loss is a tensor to avoid synchronization of TPUs through .item() + tr_loss = torch.tensor(0.0).to(args.device) + # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses + self._total_loss_scalar = 0.0 + self._globalstep_last_logged = self.state.global_step + model.zero_grad() + grad_norm: Optional[float] = None + + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + + total_batched_samples = 0 + from icecream import ic + for epoch in range(epochs_trained, num_train_epochs): + epoch_iterator = train_dataloader + if hasattr(epoch_iterator, "set_epoch"): + epoch_iterator.set_epoch(epoch) + + # Reset the past mems state at the beginning of each epoch if necessary. + if args.past_index >= 0: + self._past = None + + steps_in_epoch = ( + len(epoch_iterator) + if len_dataloader is not None + else args.max_steps * args.gradient_accumulation_steps + ) + self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) + + if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + + rng_to_sync = False + steps_skipped = 0 + if steps_trained_in_current_epoch > 0: + epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch) + steps_skipped = steps_trained_in_current_epoch + steps_trained_in_current_epoch = 0 + rng_to_sync = True + + step = -1 + for step, inputs in enumerate(epoch_iterator): + total_batched_samples += 1 + + if self.args.include_num_input_tokens_seen: + main_input_name = getattr(self.model, "main_input_name", "input_ids") + if main_input_name not in inputs: + logger.warning( + "Tried to track the number of tokens seen, however the current model is " + "not configured properly to know what item is the input. To fix this, add " + "a `main_input_name` attribute to the model class you are using." 
+ ) + else: + input_device = inputs[main_input_name].device + self.state.num_input_tokens_seen += torch.sum( + self.accelerator.gather( + torch.tensor(inputs[main_input_name].numel(), device=input_device, dtype=torch.int64) + ) + ).item() + + if rng_to_sync: + self._load_rng_state(resume_from_checkpoint) + rng_to_sync = False + + # Skip past any already trained steps if resuming training + if steps_trained_in_current_epoch > 0: + steps_trained_in_current_epoch -= 1 + if steps_trained_progress_bar is not None: + steps_trained_progress_bar.update(1) + if steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + continue + elif steps_trained_progress_bar is not None: + steps_trained_progress_bar.close() + steps_trained_progress_bar = None + + if step % args.gradient_accumulation_steps == 0: + self.control = self.callback_handler.on_step_begin(args, self.state, self.control) + + with self.accelerator.accumulate(model): + ic(step, "before_step", dist.get_rank(), step) + tr_loss_step = self.training_step(model, inputs) + ic(step, "after_step") + if ( + args.logging_nan_inf_filter + and not is_torch_xla_available() + and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) + ): + # if loss is nan or inf simply add the average of previous logged losses + tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) + else: + if tr_loss.device != tr_loss_step.device: + raise ValueError( + f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}" + ) + tr_loss += tr_loss_step + + self.current_flos += float(self.floating_point_ops(inputs)) + + is_last_step_and_steps_less_than_grad_acc = ( + steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch + ) + + from icecream import ic + ic(total_batched_samples, dist.get_rank()) + + if ( + total_batched_samples % args.gradient_accumulation_steps == 0 + or + # last step in epoch but step is always smaller than gradient_accumulation_steps + is_last_step_and_steps_less_than_grad_acc + ): + # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered + # in accelerate. So, explicitly enable sync gradients to True in that case. 
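+                    # If an epoch has fewer batches than gradient_accumulation_steps, the
+                    # accumulate() context would never mark a sync point, gradients would stay
+                    # unsynchronized, and the wrapped optimizer could skip its update; forcing
+                    # sync_gradients=True on the final batch guarantees the step still happens.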
+ from icecream import ic + ic("pre_sync", dist.get_rank()) + if is_last_step_and_steps_less_than_grad_acc: + self.accelerator.gradient_state._set_sync_gradients(True) + from icecream import ic + ic("post_sync", dist.get_rank()) + + # Gradient clipping + if args.max_grad_norm is not None and args.max_grad_norm > 0: + # deepspeed does its own clipping + + from icecream import ic + ic("pre-clip", dist.get_rank()) + if is_sagemaker_mp_enabled() and args.fp16: + _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm) + elif self.use_apex: + # Revert to normal clipping otherwise, handling Apex or full precision + _grad_norm = nn.utils.clip_grad_norm_( + amp.master_params(self.optimizer), + args.max_grad_norm, + ) + else: + _grad_norm = self.accelerator.clip_grad_norm_( + model.parameters(), + args.max_grad_norm, + ) + from icecream import ic + ic("post_clip", dist.get_rank()) + + if ( + is_accelerate_available() + and self.accelerator.distributed_type == DistributedType.DEEPSPEED + ): + grad_norm = model.get_global_grad_norm() + # In some cases the grad norm may not return a float + if hasattr(grad_norm, "item"): + grad_norm = grad_norm.item() + else: + grad_norm = _grad_norm + + from icecream import ic + ic(grad_norm) + # Optimizer step + self.optimizer.step() + from icecream import ic + ic("post opt step", dist.get_rank()) + optimizer_was_run = not self.accelerator.optimizer_step_was_skipped + if optimizer_was_run: + # Delay optimizer scheduling until metrics are generated + if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step() + + from icecream import ic + ic("pre zero grad", dist.get_rank()) + model.zero_grad() + self.state.global_step += 1 + self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + from icecream import ic + ic("post control", dist.get_rank()) + + self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval) + from icecream import ic + ic("post log", dist.get_rank()) + else: + self.control = self.callback_handler.on_substep_end(args, self.state, self.control) + ic("after callback", dist.get_rank()) + + if self.control.should_epoch_stop or self.control.should_training_stop: + # PyTorch/XLA relies on the data loader to insert the mark_step for + # each step. Since we are breaking the loop early, we need to manually + # insert the mark_step here. + break + if step < 0: + logger.warning( + "There seems to be not a single sample in your epoch_iterator, stopping training at step" + f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" + f" num_steps ({max_steps}) higher than the number of available samples." + ) + self.control.should_training_stop = True + + self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) + self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval) + + if self.control.should_training_stop: + break + + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of training + delattr(self, "_past") + + logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") + if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: + # Wait for everyone to get here so we are sure the model has been saved by process 0. 
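+            # dist.barrier() blocks every rank until rank 0 has finished writing the best
+            # checkpoint, so _load_best_model() below never reads an incomplete save.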
+ if args.parallel_mode == ParallelMode.DISTRIBUTED: + dist.barrier() + + self._load_best_model() + + # add remaining tr_loss + self._total_loss_scalar += tr_loss.item() + effective_global_step = max(self.state.global_step, 0.001) # Avoid ZeroDivisionError + train_loss = self._total_loss_scalar / effective_global_step + + metrics = speed_metrics( + "train", + start_time, + num_samples=num_train_samples, + num_steps=self.state.max_steps, + num_tokens=num_train_tokens, + ) + self.store_flos() + metrics["total_flos"] = self.state.total_flos + metrics["train_loss"] = train_loss + + self.is_in_train = False + + self._memory_tracker.stop_and_update_metrics(metrics) + + self.log(metrics) + + run_dir = self._get_output_dir(trial) + checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) + + # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save. + if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: + for checkpoint in checkpoints_sorted: + if not os.path.samefile(checkpoint, self.state.best_model_checkpoint): + logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") + shutil.rmtree(checkpoint) + + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + + # Wait for the checkpoint to be uploaded. + self._finish_current_push() + + # After training we make sure to retrieve back the original forward pass method + # for the embedding layer by removing the forward post hook. + if self.neftune_noise_alpha is not None: + self._deactivate_neftune(self.model) + + return TrainOutput(self.state.global_step, train_loss, metrics) + + def create_optimizer(self): + """ + Setup the optimizer. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method in a subclass. 
+ """ + if is_sagemaker_mp_enabled(): + return super().create_optimizer() + + opt_model = self.model + + if self.optimizer is None: + decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) + decay_parameters = [name for name in decay_parameters if "bias" not in name] + + if self.args.mm_vision_lr is not None: + def include_vision_params(name): + return "vision_tower" not in name + else: + def include_vision_params(name): + return True + + if self.args.mm_projector_lr is not None: + projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name] + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad) + ], + "weight_decay": self.args.weight_decay, + }, + { + "params": [ + p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad) + ], + "weight_decay": 0.0, + }, + { + "params": [ + p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad) + ], + "weight_decay": self.args.weight_decay, + "lr": self.args.mm_projector_lr, + }, + { + "params": [ + p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad) + ], + "weight_decay": 0.0, + "lr": self.args.mm_projector_lr, + }, + ] + else: + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and include_vision_params(n)) + ], + "weight_decay": self.args.weight_decay, + }, + { + "params": [ + p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad and include_vision_params(n)) + ], + "weight_decay": 0.0, + }, + ] + + if self.args.mm_vision_lr is not None: + vision_tower_parameters = [name for name, _ in opt_model.named_parameters() if "vision_tower" in name] + optimizer_grouped_parameters.extend([ + { + "params": [ + p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in vision_tower_parameters and p.requires_grad) + ], + "weight_decay": self.args.weight_decay, + "lr": self.args.mm_vision_lr, + }, + { + "params": [ + p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in vision_tower_parameters and p.requires_grad) + ], + "weight_decay": 0.0, + "lr": self.args.mm_vision_lr, + }, + ]) + + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) + + self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) + if optimizer_cls.__name__ == "Adam8bit": + import bitsandbytes + + manager = bitsandbytes.optim.GlobalOptimManager.get_instance() + + skipped = 0 + for module in opt_model.modules(): + if isinstance(module, nn.Embedding): + skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) + logger.info(f"skipped {module}: {skipped/2**20}M params") + manager.register_module_override(module, "weight", {"optim_bits": 32}) + logger.debug(f"bitsandbytes: will optimize {module} in fp32") + logger.info(f"skipped: {skipped/2**20}M params") + + return self.optimizer + + def _save_checkpoint(self, model, trial, metrics=None): + if getattr(self.args, 'tune_mm_mlp_adapter', False): + from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + + run_dir = 
self._get_output_dir(trial=trial) + output_dir = os.path.join(run_dir, checkpoint_folder) + + # Only save Adapter + keys_to_match = ['mm_projector', 'vision_resampler'] + if getattr(self.args, "use_im_start_end", False): + keys_to_match.extend(['embed_tokens', 'embed_in']) + + weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match) + + if self.args.local_rank == 0 or self.args.local_rank == -1: + self.model.config.save_pretrained(output_dir) + torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) + # else: + super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics) + + def _save(self, output_dir: Optional[str] = None, state_dict=None): + # if getattr(self.args, 'tune_mm_mlp_adapter', False): + # pass + # else: + super(LLaVATrainer, self)._save(output_dir, state_dict) \ No newline at end of file diff --git a/ola_vlm/train/old/probe_train.py b/ola_vlm/train/old/probe_train.py new file mode 100644 index 0000000000000000000000000000000000000000..b3956210c6fa44045596abe37bb49939217da1e6 --- /dev/null +++ b/ola_vlm/train/old/probe_train.py @@ -0,0 +1,1055 @@ +# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: +# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import copy +from dataclasses import dataclass, field +import json +import logging +import pathlib +from typing import Dict, Optional, Sequence, List + +import torch + +import transformers +import tokenizers + +from ola_vlm.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from torch.utils.data import Dataset +from ola_vlm.train.llava_trainer import LLaVATrainer + +from llava import conversation as conversation_lib +from ola_vlm.model import * +from ola_vlm.mm_utils import tokenizer_image_token + +from PIL import Image + + +local_rank = None + + +def rank0_print(*args): + if local_rank == 0: + print(*args) + + +from packaging import version +IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14') + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + version: Optional[str] = field(default="v0") + freeze_backbone: bool = field(default=False) + tune_mm_mlp_adapter: bool = field(default=False) + use_s2: bool = field(default=False) + s2_scales: Optional[str] = field(default="336,1008") + vision_tower: Optional[str] = field(default=None) + mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer + pretrain_mm_mlp_adapter: Optional[str] = field(default=None) + mm_projector_type: Optional[str] = field(default='linear') + mm_use_im_start_end: bool = field(default=False) + mm_use_im_patch_token: bool = field(default=True) + mm_patch_merge_type: Optional[str] = field(default='flat') + mm_vision_select_feature: Optional[str] = field(default="patch") + + # probe + image_generator: Optional[str] = field(default="runwayml/stable-diffusion-v1-5") + probe_depth: Optional[int] = 1 + probe_dim_head: Optional[int] = 32 + probe_num_heads: Optional[int] = 4 + probe_num_tokens: Optional[int] = 77 + probe_output_dim: Optional[int] = 768 + probe_ff_mult: Optional[int] = 1 + + +@dataclass +class DataArguments: + data_path: str = field(default=None, + metadata={"help": "Path to the training data."}) + lazy_preprocess: bool = False + is_multimodal: bool = False + image_folder: Optional[str] = field(default=None) + image_aspect_ratio: str = 'square' + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + remove_unused_columns: bool = field(default=False) + freeze_mm_mlp_adapter: bool = field(default=False) + mpt_attn_impl: Optional[str] = field(default="triton") + model_max_length: int = field( + default=512, + metadata={ + "help": + "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + double_quant: bool = field( + default=True, + metadata={"help": "Compress the quantization statistics through double quantization."} + ) + quant_type: str = field( + default="nf4", + metadata={"help": "Quantization data type to use. 
Should be one of `fp4` or `nf4`."} + ) + bits: int = field( + default=16, + metadata={"help": "How many bits to use."} + ) + lora_enable: bool = False + lora_r: int = 64 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_weight_path: str = "" + lora_bias: str = "none" + mm_projector_lr: Optional[float] = None + group_by_modality_length: bool = field(default=False) + + +def maybe_zero_3(param, ignore_status=False, name=None): + from deepspeed import zero + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +# Borrowed from peft.utils.get_peft_model_state_dict +def get_peft_state_maybe_zero_3(named_params, bias): + if bias == "none": + to_return = {k: t for k, t in named_params if "lora_" in k} + elif bias == "all": + to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + maybe_lora_bias = {} + lora_bias_names = set() + for k, t in named_params: + if "lora_" in k: + to_return[k] = t + bias_name = k.split("lora_")[0] + "bias" + lora_bias_names.add(bias_name) + elif "bias" in k: + maybe_lora_bias[k] = t + for k, t in maybe_lora_bias: + if bias_name in lora_bias_names: + to_return[bias_name] = t + else: + raise NotImplementedError + to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} + return to_return + + +def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): + to_return = {k: t for k, t in named_params if "lora_" not in k} + if require_grad_only: + to_return = {k: t for k, t in to_return.items() if t.requires_grad} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): + to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def find_all_linear_names(model): + cls = torch.nn.Linear + lora_module_names = set() + multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler'] + for name, module in model.named_modules(): + if any(mm_keyword in name for mm_keyword in multimodal_keywords): + continue + if isinstance(module, cls): + names = name.split('.') + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + if 'lm_head' in lora_module_names: # needed for 16-bit + lora_module_names.remove('lm_head') + return list(lora_module_names) + + +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, + output_dir: str): + """Collects the state dict and dump to disk.""" + + if getattr(trainer.args, "tune_mm_mlp_adapter", False): + # Only save Adapter + keys_to_match = ['mm_projector'] + if getattr(trainer.args, "use_im_start_end", False): + keys_to_match.extend(['embed_tokens', 'embed_in']) + + weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match) + trainer.model.config.save_pretrained(output_dir) + + current_folder = output_dir.split('/')[-1] + parent_folder = os.path.dirname(output_dir) + if trainer.args.local_rank == 0 or 
trainer.args.local_rank == -1: + if current_folder.startswith('checkpoint-'): + mm_projector_folder = os.path.join(parent_folder, "mm_projector") + os.makedirs(mm_projector_folder, exist_ok=True) + torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin')) + else: + torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) + return + + if trainer.deepspeed: + torch.cuda.synchronize() + trainer.save_model(output_dir) + return + + state_dict = trainer.model.state_dict() + if trainer.args.should_save: + cpu_state_dict = { + key: value.cpu() + for key, value in state_dict.items() + } + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def smart_tokenizer_and_embedding_resize( + special_tokens_dict: Dict, + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, +): + """Resize tokenizer and embedding. + + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. + """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + +def _tokenize_fn(strings: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer) -> Dict: + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ) for text in strings + ] + input_ids = labels = [ + tokenized.input_ids[0] for tokenized in tokenized_list + ] + input_ids_lens = labels_lens = [ + tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() + for tokenized in tokenized_list + ] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + + +def _mask_targets(target, tokenized_lens, speakers): + # cur_idx = 0 + cur_idx = tokenized_lens[0] + tokenized_lens = tokenized_lens[1:] + target[:cur_idx] = IGNORE_INDEX + for tokenized_len, speaker in zip(tokenized_lens, speakers): + if speaker == "human": + target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX + cur_idx += tokenized_len + + +def _add_speaker_and_signal(header, source, get_conversation=True): + """Add speaker and start/end signal on each round.""" + BEGIN_SIGNAL = "### " + END_SIGNAL = "\n" + conversation = header + for sentence in source: + from_str = sentence["from"] + if from_str.lower() == "human": + from_str = conversation_lib.default_conversation.roles[0] + elif from_str.lower() == "gpt": + from_str = conversation_lib.default_conversation.roles[1] + else: + from_str = 'unknown' + sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + + sentence["value"] + END_SIGNAL) + if get_conversation: + conversation += sentence["value"] + conversation += BEGIN_SIGNAL + return conversation + + +def preprocess_multimodal( + sources: Sequence[str], + data_args: DataArguments +) -> Dict: + is_multimodal = data_args.is_multimodal + if not is_multimodal: + return sources + + for source in sources: + for sentence in source: + if DEFAULT_IMAGE_TOKEN in sentence['value']: + 
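+                # Normalize placement of the image placeholder: strip <image> from wherever
+                # the annotator put it and re-insert it at the start of the turn; mmtag-style
+                # templates and mm_use_im_start_end additionally wrap the token with their
+                # respective image start/end markers.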
sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() + sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] + sentence['value'] = sentence['value'].strip() + if "mmtag" in conversation_lib.default_conversation.version: + sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '' + DEFAULT_IMAGE_TOKEN + '') + replace_token = DEFAULT_IMAGE_TOKEN + if data_args.mm_use_im_start_end: + replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) + + return sources + + +def preprocess_phi_3( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + return dict( + input_ids=input_ids, + ) + +def preprocess_llama_3( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack( + [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + assert conv.sep_style == conversation_lib.SeparatorStyle.MPT + + # Mask targets + sep = conv.sep + conv.roles[1] + + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep) + re_rounds = [conv.sep.join(rounds[:3])] + for conv_idx in range(3, len(rounds), 2): + re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx + 2])) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + + for i, rou in enumerate(re_rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = 
len(tokenizer(parts[0]).input_ids) - 2 + + # if i > 0: + # round_len -= 1 + # instruction_len -= 1 + + target[cur_len: cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_llama_2( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2 + + # Mask targets + sep = "[/INST] " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
+ f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_v1( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + return dict( + input_ids=input_ids + ) + + +def preprocess( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + """ + Given a list of sources, each is a conversation list. This transform: + 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; + 2. Concatenate conversations together; + 3. Tokenize the concatenated conversation; + 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. + """ + if conversation_lib.default_conversation.version == "llama3": + return preprocess_llama_3(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version == "phi3": + return preprocess_phi_3(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2: + return preprocess_llama_2(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version.startswith("v1"): + return preprocess_v1(sources, tokenizer, has_image=has_image) + # add end signal and concatenate together + conversations = [] + for source in sources: + header = f"{conversation_lib.default_conversation.system}\n\n" + conversation = _add_speaker_and_signal(header, source) + conversations.append(conversation) + # tokenize conversations + def get_tokenize_len(prompts): + return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts] + + if has_image: + input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] + else: + conversations_tokenized = _tokenize_fn(conversations, tokenizer) + input_ids = conversations_tokenized["input_ids"] + + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + if has_image: + tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source]) + else: + tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"] + speakers = [sentence["from"] for sentence in source] + _mask_targets(target, tokenized_lens, speakers) + + return dict(input_ids=input_ids, labels=targets) + +def parse_json(file): + with open(file) as f: + data = json.load(f) + return data + +def prepare_coco(json_file): + from tqdm import tqdm + + coco_data = parse_json(json_file) + + id_to_filename = {image["id"]: image["file_name"] for image in coco_data["images"]} + processed_image_ids = set() + list_data_dict = [] + + for 
annotation in tqdm(coco_data["annotations"]): + image_id = annotation["image_id"] + if image_id in processed_image_ids: + continue + caption = annotation["caption"] + file_name = id_to_filename[image_id] + processed_image_ids.add(image_id) + + question = "Describe the image in two lines.\n" + conversations = [ + {"from": "human", "value": question}, + {"from": "gpt", "value": ""} + ] + + list_data_dict.append( + { + "conversations": conversations, + "image": file_name, + "caption": caption, + } + ) + + return list_data_dict + + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, data_path: str, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments): + super(LazySupervisedDataset, self).__init__() + list_data_dict = prepare_coco(data_path) + + rank0_print("Formatting inputs...Skip in lazy mode") + self.tokenizer = tokenizer + self.list_data_dict = list_data_dict + self.data_args = data_args + + def __len__(self): + return len(self.list_data_dict) + + @property + def lengths(self): + length_list = [] + for sample in self.list_data_dict: + img_tokens = 128 if 'image' in sample else 0 + length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens) + return length_list + + @property + def modality_lengths(self): + length_list = [] + for sample in self.list_data_dict: + cur_len = sum(len(conv['value'].split()) for conv in sample['conversations']) + cur_len = cur_len if 'image' in sample else -cur_len + length_list.append(cur_len) + return length_list + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + sources = self.list_data_dict[i] + if isinstance(i, int): + sources = [sources] + assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME + if 'image' in sources[0]: + image_file = self.list_data_dict[i]['image'] + image_folder = self.data_args.image_folder + processor = self.data_args.image_processor + pil_image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') + if self.data_args.image_aspect_ratio == 'pad': + def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + image = expand2square(pil_image, tuple(int(x*255) for x in processor.image_mean)) + image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + else: + image = processor.preprocess(pil_image, return_tensors='pt')['pixel_values'][0] + sources = preprocess_multimodal( + copy.deepcopy([e["conversations"] for e in sources]), + self.data_args) + else: + sources = copy.deepcopy([e["conversations"] for e in sources]) + data_dict = preprocess( + sources, + self.tokenizer, + has_image=('image' in self.list_data_dict[i])) + if isinstance(i, int): + data_dict = dict(input_ids=data_dict["input_ids"][0]) + + # image exist in the data + if 'image' in self.list_data_dict[i]: + data_dict['image'] = image + data_dict["pil_image"] = pil_image + data_dict["caption"] = self.list_data_dict[i]['caption'] + elif self.data_args.is_multimodal: + # image does not exist in the data, but the model is multimodal + try: + crop_size = self.data_args.image_processor.crop_size + except: + crop_size = 
self.data_args.image_processor.size + data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width']) + return data_dict + + +@dataclass +class DataCollatorForSupervisedDataset(object): + """Collate examples for supervised fine-tuning.""" + + tokenizer: transformers.PreTrainedTokenizer + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + input_ids = [instance["input_ids"] for instance in instances] + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id) + + input_ids = input_ids[:, :self.tokenizer.model_max_length] + batch = dict( + input_ids=input_ids, + attention_mask=input_ids.ne(self.tokenizer.pad_token_id), + ) + + if 'image' in instances[0]: + images = [instance['image'] for instance in instances] + if all(x is not None and x.shape == images[0].shape for x in images): + batch['images'] = torch.stack(images) + else: + batch['images'] = images + + if 'pil_image' in instances[0]: + pil_images = [instance['pil_image'] for instance in instances] + batch['pil_images'] = pil_images + + if 'caption' in instances[0]: + captions = [instance['caption'] for instance in instances] + batch['captions'] = captions + + return batch + + +def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, + data_args) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + train_dataset = LazySupervisedDataset(tokenizer=tokenizer, + data_path=data_args.data_path, + data_args=data_args) + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + return dict(train_dataset=train_dataset, + eval_dataset=None, + data_collator=data_collator) + + +def train(attn_implementation=None): + global local_rank + + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + local_rank = training_args.local_rank + compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) + + bnb_model_from_pretrained_args = {} + if training_args.bits in [4, 8]: + from transformers import BitsAndBytesConfig + bnb_model_from_pretrained_args.update(dict( + device_map={"": training_args.device}, + load_in_4bit=training_args.bits == 4, + load_in_8bit=training_args.bits == 8, + quantization_config=BitsAndBytesConfig( + load_in_4bit=training_args.bits == 4, + load_in_8bit=training_args.bits == 8, + llm_int8_skip_modules=["mm_projector"], + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=compute_dtype, + bnb_4bit_use_double_quant=training_args.double_quant, + bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'} + ) + )) + + if model_args.vision_tower is not None: + if 'phi' in model_args.model_name_or_path.lower(): + model = SherlockProbeLlavaPhi3ForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + else: + model = SherlockProbeLlavaLlamaForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + else: + if 'phi' in model_args.model_name_or_path.lower(): + model = transformers.Phi3ForCausalLM.from_pretrained( + 
model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + else: + model = transformers.LlamaForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + model.config.use_cache = False + + model.model.requires_grad_(False) + + if training_args.bits in [4, 8]: + from peft import prepare_model_for_kbit_training + model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) + model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) + + if training_args.gradient_checkpointing: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if training_args.lora_enable: + from peft import LoraConfig, get_peft_model + lora_config = LoraConfig( + r=training_args.lora_r, + lora_alpha=training_args.lora_alpha, + target_modules=find_all_linear_names(model), + lora_dropout=training_args.lora_dropout, + bias=training_args.lora_bias, + task_type="CAUSAL_LM", + ) + if training_args.bits == 16: + if training_args.bf16: + model.to(torch.bfloat16) + if training_args.fp16: + model.to(torch.float16) + rank0_print("Adding LoRA adapters...") + model = get_peft_model(model, lora_config) + + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=False, + ) + + tokenizer.pad_token = tokenizer.unk_token + if tokenizer.pad_token_id is None: + smart_tokenizer_and_embedding_resize( + special_tokens_dict=dict(pad_token=""), + tokenizer=tokenizer, + model=model, + ) + + if model_args.version in conversation_lib.conv_templates: + conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version] + else: + conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"] + + vision_tower = model.get_vision_tower() + + if vision_tower is None: + model.get_model().initialize_vision_modules( + model_args=model_args, + fsdp=training_args.fsdp + ) + vision_tower = model.get_vision_tower() + + if not vision_tower.is_loaded: + vision_tower.load_model(device_map={"": training_args.device}) + vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device) + + data_args.image_processor = vision_tower.image_processor + data_args.is_multimodal = True + + model.config.image_aspect_ratio = data_args.image_aspect_ratio + model.config.tokenizer_padding_side = tokenizer.padding_side + model.config.tokenizer_model_max_length = tokenizer.model_max_length + + model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter + model.requires_grad_(False) + + model.config.image_gen = { + "probe_depth": model_args.probe_depth, + "probe_dim_head": model_args.probe_dim_head, + "probe_num_heads": model_args.probe_num_heads, + "probe_num_tokens": model_args.probe_num_tokens, + "probe_output_dim": 
model_args.probe_output_dim, + "probe_ff_mult": model_args.probe_ff_mult, + } + model.config.image_generator = model_args.image_generator + + model.init_probes(model.config) + model.init_image_generator(model.config) + + for p in model.lm_head.parameters(): + p.requires_grad = False + + for p in model.get_model().mm_projector.parameters(): + p.requires_grad = False + + model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter + for p in model.get_model().mm_projector.parameters(): + p.requires_grad = False + + for r in model.resamplers: + for p in r.parameters(): + p.requires_grad = True + + try: + for r in model.img_resamplers: + for p in r.parameters(): + p.requires_grad = True + except: + pass + + + # import torch.distributed as dist + # from icecream import ic + # if dist.get_rank() == 0: + # for n, p in model.named_parameters(): + # if p.requires_grad: + # ic(n) + + if training_args.bits in [4, 8]: + model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device) + + model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end + model.config.mm_projector_lr = training_args.mm_projector_lr + training_args.use_im_start_end = model_args.mm_use_im_start_end + model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token + model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer) + + if training_args.bits in [4, 8]: + from peft.tuners.lora import LoraLayer + for name, module in model.named_modules(): + if isinstance(module, LoraLayer): + if training_args.bf16: + module = module.to(torch.bfloat16) + if 'norm' in name: + module = module.to(torch.float32) + if 'lm_head' in name or 'embed_tokens' in name: + if hasattr(module, 'weight'): + if training_args.bf16 and module.weight.dtype == torch.float32: + module = module.to(torch.bfloat16) + + data_module = make_supervised_data_module(tokenizer=tokenizer, + data_args=data_args) + trainer = LLaVATrainer(model=model, + tokenizer=tokenizer, + args=training_args, + **data_module) + + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + trainer.save_state() + + model.config.use_cache = True + + if training_args.lora_enable: + state_dict = get_peft_state_maybe_zero_3( + model.named_parameters(), training_args.lora_bias + ) + non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( + model.named_parameters() + ) + if training_args.local_rank == 0 or training_args.local_rank == -1: + model.config.save_pretrained(training_args.output_dir) + model.save_pretrained(training_args.output_dir, state_dict=state_dict) + torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin')) + else: + safe_save_model_for_hf_trainer(trainer=trainer, + output_dir=training_args.output_dir) + + +if __name__ == "__main__": + train() diff --git a/ola_vlm/train/old/probe_train_mem.py b/ola_vlm/train/old/probe_train_mem.py new file mode 100644 index 0000000000000000000000000000000000000000..28575cf0d4fe96fe1747215fb3d08d922b8b372a --- /dev/null +++ b/ola_vlm/train/old/probe_train_mem.py @@ -0,0 +1,7 @@ +from ola_vlm.train.probe_train import train + +if __name__ == "__main__": + try: + train(attn_implementation="flash_attention_2") + except: + train(attn_implementation="eager") diff --git a/ola_vlm/train/old/sherlock_train.py b/ola_vlm/train/old/sherlock_train.py new file mode 100644 index 
0000000000000000000000000000000000000000..06dd357b516abf1b59dcee2565cb5d8747a084da --- /dev/null +++ b/ola_vlm/train/old/sherlock_train.py @@ -0,0 +1,1103 @@ +# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: +# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import copy +from dataclasses import dataclass, field +import json +import logging +import pathlib +from typing import Dict, Optional, Sequence, List + +import torch + +import transformers +import tokenizers + +from ola_vlm.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from torch.utils.data import Dataset +from ola_vlm.train.llava_trainer import LLaVATrainer + +from llava import conversation as conversation_lib +from ola_vlm.model import * +from ola_vlm.mm_utils import tokenizer_image_token + +from PIL import Image + + +local_rank = None + + +def rank0_print(*args): + if local_rank == 0: + print(*args) + + +from packaging import version +IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14') + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + version: Optional[str] = field(default="v0") + freeze_backbone: bool = field(default=False) + tune_mm_mlp_adapter: bool = field(default=False) + use_s2: bool = field(default=False) + s2_scales: Optional[str] = field(default="336,1008") + vision_tower: Optional[str] = field(default=None) + mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer + pretrain_mm_mlp_adapter: Optional[str] = field(default=None) + mm_projector_type: Optional[str] = field(default='linear') + mm_use_im_start_end: bool = field(default=False) + mm_use_im_patch_token: bool = field(default=True) + mm_patch_merge_type: Optional[str] = field(default='flat') + mm_vision_select_feature: Optional[str] = field(default="patch") + + # resamplers + image_generator: Optional[str] = field(default="runwayml/stable-diffusion-v1-5") + probe_depth: Optional[int] = 1 + probe_dim_head: Optional[int] = 32 + probe_num_heads: Optional[int] = 4 + probe_num_tokens: Optional[int] = 77 + probe_output_dim: Optional[int] = 768 + probe_ff_mult: Optional[int] = 1 + layer_indices: Optional[str] = "8-12-16,24,30" + img_loss_weight: Optional[float] = 0.5 + + +@dataclass +class DataArguments: + data_path: str = field(default=None, + metadata={"help": "Path to the training data."}) + lazy_preprocess: bool = False + is_multimodal: bool = False + image_folder: Optional[str] = field(default=None) + image_aspect_ratio: str = 'square' + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + remove_unused_columns: bool = field(default=False) + 
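+    # The remaining fields mirror the LLaVA-style training knobs: adapter freezing,
+    # k-bit loading (`bits`, `quant_type`, `double_quant` feed the BitsAndBytesConfig
+    # built in train()), and optional LoRA (`lora_*` feed the peft LoraConfig).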
freeze_mm_mlp_adapter: bool = field(default=False) + mpt_attn_impl: Optional[str] = field(default="triton") + model_max_length: int = field( + default=512, + metadata={ + "help": + "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + double_quant: bool = field( + default=True, + metadata={"help": "Compress the quantization statistics through double quantization."} + ) + quant_type: str = field( + default="nf4", + metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."} + ) + bits: int = field( + default=16, + metadata={"help": "How many bits to use."} + ) + lora_enable: bool = False + lora_r: int = 64 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_weight_path: str = "" + lora_bias: str = "none" + mm_projector_lr: Optional[float] = None + group_by_modality_length: bool = field(default=False) + + +def maybe_zero_3(param, ignore_status=False, name=None): + from deepspeed import zero + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +# Borrowed from peft.utils.get_peft_model_state_dict +def get_peft_state_maybe_zero_3(named_params, bias): + if bias == "none": + to_return = {k: t for k, t in named_params if "lora_" in k} + elif bias == "all": + to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + maybe_lora_bias = {} + lora_bias_names = set() + for k, t in named_params: + if "lora_" in k: + to_return[k] = t + bias_name = k.split("lora_")[0] + "bias" + lora_bias_names.add(bias_name) + elif "bias" in k: + maybe_lora_bias[k] = t + for k, t in maybe_lora_bias: + if bias_name in lora_bias_names: + to_return[bias_name] = t + else: + raise NotImplementedError + to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} + return to_return + + +def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): + to_return = {k: t for k, t in named_params if "lora_" not in k} + if require_grad_only: + to_return = {k: t for k, t in to_return.items() if t.requires_grad} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): + to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def find_all_linear_names(model): + cls = torch.nn.Linear + lora_module_names = set() + multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler'] + for name, module in model.named_modules(): + if any(mm_keyword in name for mm_keyword in multimodal_keywords): + continue + if isinstance(module, cls): + names = name.split('.') + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + if 'lm_head' in lora_module_names: # needed for 16-bit + lora_module_names.remove('lm_head') + return list(lora_module_names) + + +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, + output_dir: str): + """Collects the state dict and dump to disk.""" + + if 
getattr(trainer.args, "tune_mm_mlp_adapter", False): + # Only save Adapter + keys_to_match = ['mm_projector'] + if getattr(trainer.args, "use_im_start_end", False): + keys_to_match.extend(['embed_tokens', 'embed_in']) + + weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match) + trainer.model.config.save_pretrained(output_dir) + + current_folder = output_dir.split('/')[-1] + parent_folder = os.path.dirname(output_dir) + if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: + if current_folder.startswith('checkpoint-'): + mm_projector_folder = os.path.join(parent_folder, "mm_projector") + os.makedirs(mm_projector_folder, exist_ok=True) + torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin')) + else: + torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) + + if trainer.deepspeed: + torch.cuda.synchronize() + trainer.save_model(output_dir) + return + + state_dict = trainer.model.state_dict() + if trainer.args.should_save: + cpu_state_dict = { + key: value.cpu() + for key, value in state_dict.items() + } + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def smart_tokenizer_and_embedding_resize( + special_tokens_dict: Dict, + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, +): + """Resize tokenizer and embedding. + + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. + """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + +def _tokenize_fn(strings: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer) -> Dict: + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ) for text in strings + ] + input_ids = labels = [ + tokenized.input_ids[0] for tokenized in tokenized_list + ] + input_ids_lens = labels_lens = [ + tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() + for tokenized in tokenized_list + ] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + + +def _mask_targets(target, tokenized_lens, speakers): + # cur_idx = 0 + cur_idx = tokenized_lens[0] + tokenized_lens = tokenized_lens[1:] + target[:cur_idx] = IGNORE_INDEX + for tokenized_len, speaker in zip(tokenized_lens, speakers): + if speaker == "human": + target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX + cur_idx += tokenized_len + + +def _add_speaker_and_signal(header, source, get_conversation=True): + """Add speaker and start/end signal on each round.""" + BEGIN_SIGNAL = "### " + END_SIGNAL = "\n" + conversation = header + for sentence in source: + from_str = sentence["from"] + if from_str.lower() == "human": + from_str = conversation_lib.default_conversation.roles[0] + elif from_str.lower() == "gpt": + from_str = conversation_lib.default_conversation.roles[1] + else: + from_str = 
'unknown' + sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + + sentence["value"] + END_SIGNAL) + if get_conversation: + conversation += sentence["value"] + conversation += BEGIN_SIGNAL + return conversation + + +def preprocess_multimodal( + sources: Sequence[str], + data_args: DataArguments +) -> Dict: + is_multimodal = data_args.is_multimodal + if not is_multimodal: + return sources + + for source in sources: + for sentence in source: + if DEFAULT_IMAGE_TOKEN in sentence['value']: + sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() + sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] + sentence['value'] = sentence['value'].strip() + if "mmtag" in conversation_lib.default_conversation.version: + sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '' + DEFAULT_IMAGE_TOKEN + '') + replace_token = DEFAULT_IMAGE_TOKEN + if data_args.mm_use_im_start_end: + replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) + + return sources + + +def preprocess_phi_3( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + assert conv.sep_style == conversation_lib.SeparatorStyle.MPT + + # Mask targets + sep = conv.sep + conv.roles[1] + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep) + re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt + for conv_idx in range(3, len(rounds), 2): + re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(re_rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + if i > 0: + round_len -= 2 + instruction_len -= 2 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
+ f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + +def preprocess_llama_3( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack( + [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + assert conv.sep_style == conversation_lib.SeparatorStyle.MPT + + # Mask targets + sep = conv.sep + conv.roles[1] + + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep) + re_rounds = [conv.sep.join(rounds[:3])] + for conv_idx in range(3, len(rounds), 2): + re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx + 2])) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + + for i, rou in enumerate(re_rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + # if i > 0: + # round_len -= 1 + # instruction_len -= 1 + + target[cur_len: cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
+ f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_llama_2( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2 + + # Mask targets + sep = "[/INST] " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
+ f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_v1( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.TWO + + # Mask targets + sep = conv.sep + conv.roles[1] + ": " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14: + round_len -= 1 + instruction_len -= 1 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" (ignored)" + ) + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + """ + Given a list of sources, each is a conversation list. This transform: + 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; + 2. Concatenate conversations together; + 3. Tokenize the concatenated conversation; + 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. 
+ """ + if conversation_lib.default_conversation.version == "llama3": + return preprocess_llama_3(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version == "phi3": + return preprocess_phi_3(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2: + return preprocess_llama_2(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version.startswith("v1"): + return preprocess_v1(sources, tokenizer, has_image=has_image) + # add end signal and concatenate together + conversations = [] + for source in sources: + header = f"{conversation_lib.default_conversation.system}\n\n" + conversation = _add_speaker_and_signal(header, source) + conversations.append(conversation) + # tokenize conversations + def get_tokenize_len(prompts): + return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts] + + if has_image: + input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] + else: + conversations_tokenized = _tokenize_fn(conversations, tokenizer) + input_ids = conversations_tokenized["input_ids"] + + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + if has_image: + tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source]) + else: + tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"] + speakers = [sentence["from"] for sentence in source] + _mask_targets(target, tokenized_lens, speakers) + + return dict(input_ids=input_ids, labels=targets) + + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, data_path: str, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments): + super(LazySupervisedDataset, self).__init__() + list_data_dict = json.load(open(data_path, "r")) + + rank0_print("Formatting inputs...Skip in lazy mode") + self.tokenizer = tokenizer + self.list_data_dict = list_data_dict + self.data_args = data_args + + def __len__(self): + return len(self.list_data_dict) + + @property + def lengths(self): + length_list = [] + for sample in self.list_data_dict: + img_tokens = 128 if 'image' in sample else 0 + length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens) + return length_list + + @property + def modality_lengths(self): + length_list = [] + for sample in self.list_data_dict: + cur_len = sum(len(conv['value'].split()) for conv in sample['conversations']) + cur_len = cur_len if 'image' in sample else -cur_len + length_list.append(cur_len) + return length_list + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + sources = self.list_data_dict[i] + if isinstance(i, int): + sources = [sources] + assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME + if 'image' in sources[0]: + image_file = self.list_data_dict[i]['image'] + image_folder = self.data_args.image_folder + processor = self.data_args.image_processor + pil_image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') + image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') + if self.data_args.image_aspect_ratio == 'pad': + def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) 
// 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + image = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) + image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + else: + image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + sources = preprocess_multimodal( + copy.deepcopy([e["conversations"] for e in sources]), + self.data_args) + else: + sources = copy.deepcopy([e["conversations"] for e in sources]) + data_dict = preprocess( + sources, + self.tokenizer, + has_image=('image' in self.list_data_dict[i])) + if isinstance(i, int): + data_dict = dict(input_ids=data_dict["input_ids"][0], + labels=data_dict["labels"][0]) + + # image exist in the data + if 'image' in self.list_data_dict[i]: + data_dict['image'] = image + data_dict["pil_image"] = pil_image + elif self.data_args.is_multimodal: + # image does not exist in the data, but the model is multimodal + try: + crop_size = self.data_args.image_processor.crop_size + except: + crop_size = self.data_args.image_processor.size + data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width']) + data_dict['pil_image'] = Image.new('RGB', (crop_size['width'], crop_size['height']), color='black') + return data_dict + + +@dataclass +class DataCollatorForSupervisedDataset(object): + """Collate examples for supervised fine-tuning.""" + + tokenizer: transformers.PreTrainedTokenizer + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + input_ids, labels = tuple([instance[key] for instance in instances] + for key in ("input_ids", "labels")) + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id) + labels = torch.nn.utils.rnn.pad_sequence(labels, + batch_first=True, + padding_value=IGNORE_INDEX) + input_ids = input_ids[:, :self.tokenizer.model_max_length] + labels = labels[:, :self.tokenizer.model_max_length] + batch = dict( + input_ids=input_ids, + labels=labels, + attention_mask=input_ids.ne(self.tokenizer.pad_token_id), + ) + + if 'image' in instances[0]: + images = [instance['image'] for instance in instances] + if all(x is not None and x.shape == images[0].shape for x in images): + batch['images'] = torch.stack(images) + else: + batch['images'] = images + + if 'pil_image' in instances[0]: + pil_images = [instance['pil_image'] for instance in instances] + batch['pil_images'] = pil_images + + return batch + + +def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, + data_args) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + train_dataset = LazySupervisedDataset(tokenizer=tokenizer, + data_path=data_args.data_path, + data_args=data_args) + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + return dict(train_dataset=train_dataset, + eval_dataset=None, + data_collator=data_collator) + + +def train(attn_implementation=None): + global local_rank + + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + local_rank = training_args.local_rank + compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) + + bnb_model_from_pretrained_args = {} + if training_args.bits in [4, 8]: + from transformers import BitsAndBytesConfig + 
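+        # k-bit path: load the backbone through bitsandbytes, keep mm_projector out of
+        # the int8 modules (llm_int8_skip_modules), and take the NF4/FP4 choice and the
+        # double-quantization flag from TrainingArguments; bnb_4bit_compute_dtype
+        # follows the fp16/bf16 flags via compute_dtype.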
bnb_model_from_pretrained_args.update(dict( + device_map={"": training_args.device}, + load_in_4bit=training_args.bits == 4, + load_in_8bit=training_args.bits == 8, + quantization_config=BitsAndBytesConfig( + load_in_4bit=training_args.bits == 4, + load_in_8bit=training_args.bits == 8, + llm_int8_skip_modules=["mm_projector"], + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=compute_dtype, + bnb_4bit_use_double_quant=training_args.double_quant, + bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'} + ) + )) + + if model_args.vision_tower is not None: + if 'phi' in model_args.model_name_or_path.lower(): + model = SherlockLlavaPhi3ForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + else: + model = LlavaLlamaForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + else: + if 'phi' in model_args.model_name_or_path.lower(): + model = transformers.Phi3ForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + else: + model = transformers.LlamaForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + model.config.use_cache = False + + if model_args.freeze_backbone: + model.model.requires_grad_(False) + + if training_args.bits in [4, 8]: + from peft import prepare_model_for_kbit_training + model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) + model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) + + if training_args.gradient_checkpointing: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if training_args.lora_enable: + from peft import LoraConfig, get_peft_model + lora_config = LoraConfig( + r=training_args.lora_r, + lora_alpha=training_args.lora_alpha, + target_modules=find_all_linear_names(model), + lora_dropout=training_args.lora_dropout, + bias=training_args.lora_bias, + task_type="CAUSAL_LM", + ) + if training_args.bits == 16: + if training_args.bf16: + model.to(torch.bfloat16) + if training_args.fp16: + model.to(torch.float16) + rank0_print("Adding LoRA adapters...") + model = get_peft_model(model, lora_config) + + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=False, + ) + + tokenizer.pad_token = tokenizer.unk_token + if tokenizer.pad_token_id is None: + smart_tokenizer_and_embedding_resize( + special_tokens_dict=dict(pad_token=""), + tokenizer=tokenizer, + model=model, + ) + + if 
model_args.version in conversation_lib.conv_templates: + conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version] + else: + conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"] + + if model_args.vision_tower is not None: + model.get_model().initialize_vision_modules( + model_args=model_args, + fsdp=training_args.fsdp + ) + + vision_tower = model.get_vision_tower() + vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device) + + data_args.image_processor = vision_tower.image_processor + data_args.is_multimodal = True + + model.config.image_grid_pinpoints = [[336,672], [672,336], [672,672], [1008,336], [336,1008]] + model.config.image_aspect_ratio = data_args.image_aspect_ratio + model.config.tokenizer_padding_side = tokenizer.padding_side + model.config.tokenizer_model_max_length = tokenizer.model_max_length + + model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter + if model_args.tune_mm_mlp_adapter: + model.requires_grad_(False) + for p in model.get_model().mm_projector.parameters(): + p.requires_grad = True + + model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter + if training_args.freeze_mm_mlp_adapter: + for p in model.get_model().mm_projector.parameters(): + p.requires_grad = False + + if training_args.bits in [4, 8]: + model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device) + + model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end + model.config.mm_projector_lr = training_args.mm_projector_lr + training_args.use_im_start_end = model_args.mm_use_im_start_end + model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token + model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer) + + model.config.use_s2 = model_args.use_s2 + model.config.s2_scales = model_args.s2_scales + + model.config.image_gen = { + "probe_depth": model_args.probe_depth, + "probe_dim_head": model_args.probe_dim_head, + "probe_num_heads": model_args.probe_num_heads, + "probe_num_tokens": model_args.probe_num_tokens, + "probe_output_dim": model_args.probe_output_dim, + "probe_ff_mult": model_args.probe_ff_mult, + "layer_indices": model_args.layer_indices, + "img_loss_weight": model_args.img_loss_weight, + } + model.config.image_generator = model_args.image_generator + + model.init_probes(model.config) + model.init_image_generator(model.config) + + # import torch.distributed as dist + # from icecream import ic + # if dist.get_rank() == 0: + # for n, p in model.named_parameters(): + # if p.requires_grad: + # ic(n) + + if training_args.bits in [4, 8]: + from peft.tuners.lora import LoraLayer + for name, module in model.named_modules(): + if isinstance(module, LoraLayer): + if training_args.bf16: + module = module.to(torch.bfloat16) + if 'norm' in name: + module = module.to(torch.float32) + if 'lm_head' in name or 'embed_tokens' in name: + if hasattr(module, 'weight'): + if training_args.bf16 and module.weight.dtype == torch.float32: + module = module.to(torch.bfloat16) + + data_module = make_supervised_data_module(tokenizer=tokenizer, + data_args=data_args) + trainer = LLaVATrainer(model=model, + tokenizer=tokenizer, + args=training_args, + **data_module) + + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + trainer.save_state() + + model.config.use_cache = 
True + + if training_args.lora_enable: + state_dict = get_peft_state_maybe_zero_3( + model.named_parameters(), training_args.lora_bias + ) + non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( + model.named_parameters() + ) + if training_args.local_rank == 0 or training_args.local_rank == -1: + model.config.save_pretrained(training_args.output_dir) + model.save_pretrained(training_args.output_dir, state_dict=state_dict) + torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin')) + else: + safe_save_model_for_hf_trainer(trainer=trainer, + output_dir=training_args.output_dir) + + +if __name__ == "__main__": + train() diff --git a/ola_vlm/train/old/sherlock_train_mem.py b/ola_vlm/train/old/sherlock_train_mem.py new file mode 100644 index 0000000000000000000000000000000000000000..a4c2e7a3fa7bd0adc3f4b96479584fef45f6cc2d --- /dev/null +++ b/ola_vlm/train/old/sherlock_train_mem.py @@ -0,0 +1,5 @@ +from ola_vlm.train.sherlock_train import train + +if __name__ == "__main__": + train(attn_implementation="flash_attention_2") + # train(attn_implementation="eager") diff --git a/ola_vlm/train/probe_dsg_train.py b/ola_vlm/train/probe_dsg_train.py new file mode 100644 index 0000000000000000000000000000000000000000..98483538d713139402fa9f397aa6b9e323190a4c --- /dev/null +++ b/ola_vlm/train/probe_dsg_train.py @@ -0,0 +1,1262 @@ +# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: +# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
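+# +# probe_dsg_train.py trains a single probing head on top of a frozen base VLM: the head type is picked via ModelArguments.mode ("gen", "seg", or "depth"), the backbone, lm_head, and mm_projector are kept frozen, and only the parameters created by model.init_heads() receive gradients. prepare_coco() turns a COCO-style annotation file into minimal "Describe the image in two lines." prompts for the image inputs. +# A sketch of a launch command (launcher and paths are illustrative placeholders, not taken from this patch): +#   deepspeed ola_vlm/train/probe_dsg_train_mem.py --model_name_or_path <base-vlm-checkpoint> --vision_tower <clip-vit-checkpoint> --mode seg --data_path <coco_annotations.json> --image_folder <coco_root> --bf16 True --output_dir <out_dir> +#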
+ +import os +import copy +from dataclasses import dataclass, field +import json +import logging +import pathlib +from typing import Dict, Optional, Sequence, List + +import numpy as np +import torch + +import transformers +import tokenizers + +from ola_vlm.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from torch.utils.data import Dataset +from ola_vlm.train.llava_trainer import LLaVATrainer + +from llava import conversation as conversation_lib +from ola_vlm.model import * +from ola_vlm.mm_utils import process_highres_image, process_anyres_image, process_highres_image_crop_split, tokenizer_image_token + +from PIL import Image, ImageFile +from transformers import set_seed + +set_seed(42) + +# Enable loading of truncated images +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +local_rank = None + + +def rank0_print(*args): + if local_rank == 0: + print(*args) + + +from packaging import version +IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14') + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + version: Optional[str] = field(default="v0") + freeze_backbone: bool = field(default=False) + tune_mm_mlp_adapter: bool = field(default=False) + use_s2: bool = field(default=False) + s2_scales: Optional[str] = field(default="336,1008") + vision_tower: Optional[str] = field(default=None) + mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer + pretrain_mm_mlp_adapter: Optional[str] = field(default=None) + mm_projector_type: Optional[str] = field(default='linear') + mm_use_im_start_end: bool = field(default=False) + mm_use_im_patch_token: bool = field(default=True) + mm_patch_merge_type: Optional[str] = field(default='flat') + mm_vision_select_feature: Optional[str] = field(default="patch") + + # visual interpretors + image_generator: Optional[str] = field(default="stabilityai/stable-diffusion-2-1-unclip") + image_segmentor: Optional[str] = field(default="shi-labs/oneformer_coco_swin_large") # sam_vit_l_0b3195.pth + depth_estimator: Optional[str] = field(default="depth_anything_v2_vitl.pth") + + mode: Optional[str] = field(default="seg") + + # gen + img_head_depth: Optional[int] = 1 + img_head_dim_head: Optional[int] = 32 + img_head_num_heads: Optional[int] = 4 + img_head_num_tokens: Optional[int] = 1 + img_head_output_dim: Optional[int] = 1024 + img_head_ff_mult: Optional[int] = 1 + + # seg + seg_head_depth: Optional[int] = 1 + seg_head_dim_head: Optional[int] = 32 + seg_head_num_heads: Optional[int] = 4 + seg_head_num_tokens: Optional[int] = 576 + seg_head_output_dim: Optional[int] = 1536 # 256 + seg_head_ff_mult: Optional[int] = 1 + seg_teacher: Optional[str] = "oneformer" # "sam" + + # depth + depth_head_depth: Optional[int] = 1 + depth_head_dim_head: Optional[int] = 32 + depth_head_num_heads: Optional[int] = 4 + depth_head_num_tokens: Optional[int] = 576 + depth_head_output_dim: Optional[int] = 1024 + depth_head_ff_mult: Optional[int] = 1 + + +@dataclass +class DataArguments: + data_path: str = field(default=None, + metadata={"help": "Path to the training data."}) + lazy_preprocess: bool = False + is_multimodal: bool = False + image_folder: Optional[str] = field(default=None) + image_aspect_ratio: str = 'square' + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + 
remove_unused_columns: bool = field(default=False) + freeze_mm_mlp_adapter: bool = field(default=False) + mpt_attn_impl: Optional[str] = field(default="triton") + model_max_length: int = field( + default=512, + metadata={ + "help": + "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + double_quant: bool = field( + default=True, + metadata={"help": "Compress the quantization statistics through double quantization."} + ) + quant_type: str = field( + default="nf4", + metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."} + ) + bits: int = field( + default=16, + metadata={"help": "How many bits to use."} + ) + lora_enable: bool = False + lora_r: int = 64 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_weight_path: str = "" + lora_bias: str = "none" + mm_projector_lr: Optional[float] = None + mm_vision_lr: Optional[float] = None + group_by_modality_length: bool = field(default=False) + + +def maybe_zero_3(param, ignore_status=False, name=None): + from deepspeed import zero + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +# Borrowed from peft.utils.get_peft_model_state_dict +def get_peft_state_maybe_zero_3(named_params, bias): + if bias == "none": + to_return = {k: t for k, t in named_params if "lora_" in k} + elif bias == "all": + to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + maybe_lora_bias = {} + lora_bias_names = set() + for k, t in named_params: + if "lora_" in k: + to_return[k] = t + bias_name = k.split("lora_")[0] + "bias" + lora_bias_names.add(bias_name) + elif "bias" in k: + maybe_lora_bias[k] = t + for k, t in maybe_lora_bias: + if bias_name in lora_bias_names: + to_return[bias_name] = t + else: + raise NotImplementedError + to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} + return to_return + + +def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): + to_return = {k: t for k, t in named_params if "lora_" not in k} + if require_grad_only: + to_return = {k: t for k, t in to_return.items() if t.requires_grad} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): + to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def find_all_linear_names(model): + cls = torch.nn.Linear + lora_module_names = set() + multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler'] + for name, module in model.named_modules(): + if any(mm_keyword in name for mm_keyword in multimodal_keywords): + continue + if isinstance(module, cls): + names = name.split('.') + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + if 'lm_head' in lora_module_names: # needed for 16-bit + lora_module_names.remove('lm_head') + return list(lora_module_names) + + +def safe_save_model_for_hf_trainer(trainer: 
transformers.Trainer, + output_dir: str): + """Collects the state dict and dump to disk.""" + + if trainer.deepspeed: + torch.cuda.synchronize() + trainer.save_model(output_dir) + return + + state_dict = trainer.model.state_dict() + if trainer.args.should_save: + cpu_state_dict = { + key: value.cpu() + for key, value in state_dict.items() + } + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def smart_tokenizer_and_embedding_resize( + special_tokens_dict: Dict, + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, +): + """Resize tokenizer and embedding. + + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. + """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + +def _tokenize_fn(strings: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer) -> Dict: + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ) for text in strings + ] + input_ids = labels = [ + tokenized.input_ids[0] for tokenized in tokenized_list + ] + input_ids_lens = labels_lens = [ + tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() + for tokenized in tokenized_list + ] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + + +def _mask_targets(target, tokenized_lens, speakers): + # cur_idx = 0 + cur_idx = tokenized_lens[0] + tokenized_lens = tokenized_lens[1:] + target[:cur_idx] = IGNORE_INDEX + for tokenized_len, speaker in zip(tokenized_lens, speakers): + if speaker == "human": + target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX + cur_idx += tokenized_len + + +def _add_speaker_and_signal(header, source, get_conversation=True): + """Add speaker and start/end signal on each round.""" + BEGIN_SIGNAL = "### " + END_SIGNAL = "\n" + conversation = header + for sentence in source: + from_str = sentence["from"] + if from_str.lower() == "human": + from_str = conversation_lib.default_conversation.roles[0] + elif from_str.lower() == "gpt": + from_str = conversation_lib.default_conversation.roles[1] + else: + from_str = 'unknown' + sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + + sentence["value"] + END_SIGNAL) + if get_conversation: + conversation += sentence["value"] + conversation += BEGIN_SIGNAL + return conversation + + +def preprocess_multimodal( + sources: Sequence[str], + data_args: DataArguments +) -> Dict: + is_multimodal = data_args.is_multimodal + if not is_multimodal: + return sources + + for source in sources: + for sentence in source: + if DEFAULT_IMAGE_TOKEN in sentence['value']: + sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() + sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] + sentence['value'] = sentence['value'].strip() + if "mmtag" in conversation_lib.default_conversation.version: + sentence['value'] = 
sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '' + DEFAULT_IMAGE_TOKEN + '') + replace_token = DEFAULT_IMAGE_TOKEN + if data_args.mm_use_im_start_end: + replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) + + return sources + + +def preprocess_phi_3( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + assert conv.sep_style == conversation_lib.SeparatorStyle.MPT + + return dict( + input_ids=input_ids, + labels=targets, + ) + +def preprocess_llama_3( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack( + [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + assert conv.sep_style == conversation_lib.SeparatorStyle.MPT + + # Mask targets + sep = conv.sep + conv.roles[1] + + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep) + re_rounds = [conv.sep.join(rounds[:3])] + for conv_idx in range(3, len(rounds), 2): + re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx + 2])) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + + for i, rou in enumerate(re_rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + # if i > 0: + # round_len -= 1 + # instruction_len -= 1 + + target[cur_len: cur_len + instruction_len] = IGNORE_INDEX + + cur_len += 
round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_llama_2( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2 + + # Mask targets + sep = "[/INST] " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + +def preprocess_qwen( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False, + system_message: str = "You are a helpful assistant." 
+ ) -> Dict: + # roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"} + roles = {"human": "user", "gpt": "assistant"} + + # Add image tokens to tokenizer as a special tokens + # Use a deepcopy of tokenizer so that we don't modify on the tokenizer + tokenizer = copy.deepcopy(tokenizer) + # When there is actually an image, we add the image tokens as a special token + if has_image: + tokenizer.add_tokens([""], special_tokens=True) + + image_token_index = tokenizer.convert_tokens_to_ids("") + im_start, im_end = tokenizer.additional_special_tokens_ids + # unmask_tokens = ["<|im_start|>", "<|im_start|>", "\n"] + unmask_tokens_idx = [198, im_start, im_end] + nl_tokens = tokenizer("\n").input_ids + + # Reset Qwen chat templates so that it won't include system message every time we apply + chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + tokenizer.chat_template = chat_template + + # _system = tokenizer("system").input_ids + nl_tokens + # _user = tokenizer("user").input_ids + nl_tokens + # _assistant = tokenizer("assistant").input_ids + nl_tokens + + # Apply prompt templates + input_ids, targets = [], [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != roles["human"]: + source = source[1:] + + input_id, target = [], [] + + # New version, use apply chat template + # Build system message for each sentence + input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}]) + target += [IGNORE_INDEX] * len(input_id) + + for conv in source: + # Make sure llava data can load + try: + role = conv["role"] + content = conv["content"] + except: + role = conv["from"] + content = conv["value"] + + role = roles.get(role, role) + + conv = [{"role" : role, "content" : content}] + encode_id = tokenizer.apply_chat_template(conv) + input_id += encode_id + if role in ["user", "system"]: + target += [IGNORE_INDEX] * len(encode_id) + else: + target += encode_id + + assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}" + for idx, encode_id in enumerate(input_id): + if encode_id in unmask_tokens_idx: + target[idx] = encode_id + if encode_id == image_token_index: + input_id[idx] = IMAGE_TOKEN_INDEX + input_ids.append(input_id) + targets.append(target) + input_ids = torch.tensor(input_ids, dtype=torch.long) + targets = torch.tensor(targets, dtype=torch.long) + + return dict( + input_ids=input_ids, # tensor(bs x seq_len) + labels=targets, # tensor(bs x seq_len) + ) + +def preprocess_v1( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + 
padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.TWO + + # Mask targets + sep = conv.sep + conv.roles[1] + ": " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14: + round_len -= 1 + instruction_len -= 1 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" (ignored)" + ) + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + """ + Given a list of sources, each is a conversation list. This transform: + 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; + 2. Concatenate conversations together; + 3. Tokenize the concatenated conversation; + 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. 
+ """ + if conversation_lib.default_conversation.version == "llama3": + return preprocess_llama_3(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version == "phi3": + return preprocess_phi_3(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version == "qwen": + return preprocess_qwen(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2: + return preprocess_llama_2(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version.startswith("v1"): + return preprocess_v1(sources, tokenizer, has_image=has_image) + # add end signal and concatenate together + conversations = [] + for source in sources: + header = f"{conversation_lib.default_conversation.system}\n\n" + conversation = _add_speaker_and_signal(header, source) + conversations.append(conversation) + # tokenize conversations + def get_tokenize_len(prompts): + return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts] + + if has_image: + input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] + else: + conversations_tokenized = _tokenize_fn(conversations, tokenizer) + input_ids = conversations_tokenized["input_ids"] + + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + if has_image: + tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source]) + else: + tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"] + speakers = [sentence["from"] for sentence in source] + _mask_targets(target, tokenized_lens, speakers) + + return dict(input_ids=input_ids, labels=targets) + +def parse_json(file): + with open(file) as f: + data = json.load(f) + return data + +def prepare_coco(json_file): + from tqdm import tqdm + + coco_data = parse_json(json_file) + + id_to_filename = {image["id"]: image["file_name"] for image in coco_data["images"]} + processed_image_ids = set() + list_data_dict = [] + + for annotation in tqdm(coco_data["annotations"]): + image_id = annotation["image_id"] + if image_id in processed_image_ids: + continue + file_name = id_to_filename[image_id] + processed_image_ids.add(image_id) + + question = "Describe the image in two lines.\n" + conversations = [ + {"from": "human", "value": question}, + {"from": "gpt", "value": "n"} + ] + + list_data_dict.append( + { + "conversations": conversations, + "image": "train2017/" + file_name, + } + ) + + return list_data_dict + + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, data_path: str, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments): + super(LazySupervisedDataset, self).__init__() + list_data_dict = prepare_coco(data_path) + + rank0_print("Formatting inputs...Skip in lazy mode") + self.tokenizer = tokenizer + self.list_data_dict = list_data_dict + self.data_args = data_args + + def __len__(self): + return len(self.list_data_dict) + + @property + def lengths(self): + length_list = [] + for sample in self.list_data_dict: + img_tokens = 128 if 'image' in sample else 0 + length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens) + return length_list + + @property + def modality_lengths(self): + length_list = [] + for sample in self.list_data_dict: + cur_len = sum(len(conv['value'].split()) for conv in sample['conversations']) + cur_len = 
cur_len if 'image' in sample else -cur_len + length_list.append(cur_len) + return length_list + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + sources = self.list_data_dict[i] + if isinstance(i, int): + sources = [sources] + assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME + + if 'image' in sources[0]: + image_file = self.list_data_dict[i]['image'] + image_folder = self.data_args.image_folder + processor = self.data_args.image_processor + try: + crop_size = self.data_args.image_processor.crop_size + except: + crop_size = self.data_args.image_processor.size + pil_image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') + image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') + image_size = image.size + + if self.data_args.image_aspect_ratio == "highres": + image = process_highres_image(image, self.data_args.image_processor, self.data_args.image_grid_pinpoints) + elif self.data_args.image_aspect_ratio == "anyres" or "anyres_max" in self.data_args.image_aspect_ratio: + image = process_anyres_image(image, self.data_args.image_processor, self.data_args.image_grid_pinpoints) + elif self.data_args.image_aspect_ratio == "crop_split": + image = process_highres_image_crop_split(image, self.data_args) + elif self.data_args.image_aspect_ratio == 'pad': + def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + image = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) + image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + else: + image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + sources = preprocess_multimodal( + copy.deepcopy([e["conversations"] for e in sources]), + self.data_args) + else: + sources = copy.deepcopy([e["conversations"] for e in sources]) + + data_dict = preprocess( + sources, + self.tokenizer, + has_image=('image' in self.list_data_dict[i])) + if isinstance(i, int): + data_dict = dict(input_ids=data_dict["input_ids"][0], + labels=data_dict["labels"][0]) + + # image exist in the data + if 'image' in self.list_data_dict[i]: + data_dict['image'] = image + data_dict['pil_image'] = pil_image + data_dict['image_size'] = image_size + elif self.data_args.is_multimodal: + # image does not exist in the data, but the model is multimodal + try: + crop_size = self.data_args.image_processor.crop_size + except: + crop_size = self.data_args.image_processor.size + data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width']) + data_dict['pil_image'] = Image.new('RGB', (crop_size['width'], crop_size['height']), color='black') + data_dict['image_size'] = (crop_size['width'], crop_size['height']) + + return data_dict + + +@dataclass +class DataCollatorForSupervisedDataset(object): + """Collate examples for supervised fine-tuning.""" + + tokenizer: transformers.PreTrainedTokenizer + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + input_ids, labels = tuple([instance[key] for instance in instances] + for key in ("input_ids", "labels")) + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id) + 
labels = torch.nn.utils.rnn.pad_sequence(labels, + batch_first=True, + padding_value=IGNORE_INDEX) + input_ids = input_ids[:, :self.tokenizer.model_max_length] + labels = labels[:, :self.tokenizer.model_max_length] + batch = dict( + input_ids=input_ids, + labels=labels, + attention_mask=input_ids.ne(self.tokenizer.pad_token_id), + ) + + if 'image' in instances[0]: + images = [instance['image'] for instance in instances] + if all(x is not None and x.shape == images[0].shape for x in images): + batch['images'] = torch.stack(images) + else: + batch['images'] = images + batch['image_sizes'] = [instance['image_size'] for instance in instances] + + if 'pil_image' in instances[0]: + pil_images = [instance['pil_image'] for instance in instances] + batch['pil_images'] = pil_images + + return batch + + +def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, + data_args) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + train_dataset = LazySupervisedDataset(tokenizer=tokenizer, + data_path=data_args.data_path, + data_args=data_args) + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + return dict(train_dataset=train_dataset, + eval_dataset=None, + data_collator=data_collator) + + +def train(attn_implementation=None): + global local_rank + + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + local_rank = training_args.local_rank + compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) + + bnb_model_from_pretrained_args = {} + if training_args.bits in [4, 8]: + from transformers import BitsAndBytesConfig + bnb_model_from_pretrained_args.update(dict( + device_map={"": training_args.device}, + load_in_4bit=training_args.bits == 4, + load_in_8bit=training_args.bits == 8, + quantization_config=BitsAndBytesConfig( + load_in_4bit=training_args.bits == 4, + load_in_8bit=training_args.bits == 8, + llm_int8_skip_modules=["mm_projector"], + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=compute_dtype, + bnb_4bit_use_double_quant=training_args.double_quant, + bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'} + ) + )) + + if model_args.vision_tower is not None: + if 'phi' in model_args.model_name_or_path.lower(): + model = ProbeDSGLlavaPhi3ForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + elif 'qwen' in model_args.model_name_or_path.lower(): + model = ProbeDSGLlavaQwen2ForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + else: + model = ProbeDSGLlavaLlamaForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + else: + if 'phi' in model_args.model_name_or_path.lower(): + model = transformers.Phi3ForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + 
torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + elif 'qwen2' in model_args.model_name_or_path.lower(): + model = transformers.Qwen2ForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + else: + model = transformers.LlamaForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + model.config.use_cache = False + + if model_args.freeze_backbone: + model.model.requires_grad_(False) + + if training_args.bits in [4, 8]: + from peft import prepare_model_for_kbit_training + model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) + model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) + + if training_args.gradient_checkpointing: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if training_args.lora_enable: + from peft import LoraConfig, get_peft_model + lora_config = LoraConfig( + r=training_args.lora_r, + lora_alpha=training_args.lora_alpha, + target_modules=find_all_linear_names(model), + lora_dropout=training_args.lora_dropout, + bias=training_args.lora_bias, + task_type="CAUSAL_LM", + ) + if training_args.bits == 16: + if training_args.bf16: + model.to(torch.bfloat16) + if training_args.fp16: + model.to(torch.float16) + rank0_print("Adding LoRA adapters...") + model = get_peft_model(model, lora_config) + + if "qwen" in model_args.model_name_or_path.lower(): + tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right") + else: + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=False, + ) + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.unk_token + if tokenizer.pad_token_id is None: + smart_tokenizer_and_embedding_resize( + special_tokens_dict=dict(pad_token=""), + tokenizer=tokenizer, + model=model, + ) + + if model_args.version in conversation_lib.conv_templates: + conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version] + else: + conversation_lib.default_conversation = conversation_lib.conv_templates["llava_phi_3"] + + + vision_tower = model.get_vision_tower() + + if vision_tower is None: + model.get_model().initialize_vision_modules( + model_args=model_args, + fsdp=training_args.fsdp + ) + vision_tower = model.get_vision_tower() + + if not vision_tower.is_loaded: + vision_tower.load_model() + vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device) + + data_args.image_processor = vision_tower.image_processor + data_args.is_multimodal = True + data_args.mode = model_args.mode + + if not hasattr(model.config, "image_grid_pinpoints"): + 
model.config.image_grid_pinpoints = [[336,672], [672,336], [672,672], [1008,336], [336,1008]] + model.config.image_aspect_ratio = data_args.image_aspect_ratio + else: + data_args.image_aspect_ratio = model.config.image_aspect_ratio + data_args.image_grid_pinpoints = model.config.image_grid_pinpoints + + model.config.tokenizer_padding_side = tokenizer.padding_side + model.config.tokenizer_model_max_length = tokenizer.model_max_length + + model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter + + if training_args.bits in [4, 8]: + model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device) + + model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end + model.config.mm_projector_lr = training_args.mm_projector_lr + training_args.use_im_start_end = model_args.mm_use_im_start_end + model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token + model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer) + + model.config.use_s2 = model_args.use_s2 + model.config.s2_scales = model_args.s2_scales + model.config.probe_mode = model_args.mode + + model.requires_grad_(False) + + if model.config.probe_mode == "gen": + model.config.image_gen = { + "depth": model_args.img_head_depth, + "dim_head": model_args.img_head_dim_head, + "num_heads": model_args.img_head_num_heads, + "num_tokens": model_args.img_head_num_tokens, + "output_dim": model_args.img_head_output_dim, + "ff_mult": model_args.img_head_ff_mult, + } + model.config.image_generator = model_args.image_generator + elif model.config.probe_mode == "seg": + model.config.image_seg = { + "depth": model_args.seg_head_depth, + "dim_head": model_args.seg_head_dim_head, + "num_heads": model_args.seg_head_num_heads, + "num_tokens": model_args.seg_head_num_tokens, + "output_dim": model_args.seg_head_output_dim, + "ff_mult": model_args.seg_head_ff_mult, + "seg_teacher": model_args.seg_teacher, + } + model.config.image_segmentor = model_args.image_segmentor + elif model.config.probe_mode == "depth": + model.config.image_depth = { + "depth": model_args.depth_head_depth, + "dim_head": model_args.depth_head_dim_head, + "num_heads": model_args.depth_head_num_heads, + "num_tokens": model_args.depth_head_num_tokens, + "output_dim": model_args.depth_head_output_dim, + "ff_mult": model_args.depth_head_ff_mult, + } + model.config.depth_estimator = model_args.depth_estimator + + model.init_heads(model.config) + + for p in model.lm_head.parameters(): + p.requires_grad = False + + for p in model.get_model().mm_projector.parameters(): + p.requires_grad = False + + import torch.distributed as dist + from icecream import ic + if dist.get_rank() == 0: + gen_heads = 0 + depth_heads = 0 + seg_heads = 0 + for n, p in model.named_parameters(): + if p.requires_grad: + if "gen_head" in n: + gen_heads += p.numel() + elif "depth_head" in n: + depth_heads += p.numel() + elif "seg_head" in n: + seg_heads += p.numel() + ic(n) + ic(depth_heads, gen_heads, seg_heads) + + if training_args.bits in [4, 8]: + from peft.tuners.lora import LoraLayer + for name, module in model.named_modules(): + if isinstance(module, LoraLayer): + if training_args.bf16: + module = module.to(torch.bfloat16) + if 'norm' in name: + module = module.to(torch.float32) + if 'lm_head' in name or 'embed_tokens' in name: + if hasattr(module, 'weight'): + if training_args.bf16 and module.weight.dtype == torch.float32: + module = module.to(torch.bfloat16) + + data_module = 
make_supervised_data_module(tokenizer=tokenizer, + data_args=data_args) + trainer = LLaVATrainer(model=model, + tokenizer=tokenizer, + args=training_args, + **data_module) + + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + trainer.save_state() + + model.config.use_cache = True + + if training_args.lora_enable: + state_dict = get_peft_state_maybe_zero_3( + model.named_parameters(), training_args.lora_bias + ) + non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( + model.named_parameters() + ) + if training_args.local_rank == 0 or training_args.local_rank == -1: + model.config.save_pretrained(training_args.output_dir) + model.save_pretrained(training_args.output_dir, state_dict=state_dict) + torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin')) + else: + safe_save_model_for_hf_trainer(trainer=trainer, + output_dir=training_args.output_dir) + + +if __name__ == "__main__": + train() diff --git a/ola_vlm/train/probe_dsg_train_mem.py b/ola_vlm/train/probe_dsg_train_mem.py new file mode 100644 index 0000000000000000000000000000000000000000..d83955e174dbe0a75a392d1dae16c71a5968bafa --- /dev/null +++ b/ola_vlm/train/probe_dsg_train_mem.py @@ -0,0 +1,8 @@ +from ola_vlm.train.probe_dsg_train import train +import torch.multiprocessing as mp + +if __name__ == "__main__": + # try: + # train(attn_implementation="flash_attention_2") + # except: + train(attn_implementation="eager") \ No newline at end of file diff --git a/ola_vlm/train/sherlock_dsg_train.py b/ola_vlm/train/sherlock_dsg_train.py new file mode 100644 index 0000000000000000000000000000000000000000..a3e2cce69555f223d6b98b228265cf22152794e2 --- /dev/null +++ b/ola_vlm/train/sherlock_dsg_train.py @@ -0,0 +1,1683 @@ +# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: +# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
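+# +# sherlock_dsg_train.py is the full auxiliary-head training script. Compared to the probe variant above, its ModelArguments add per-task layer indices (layer_indices, e.g. "d8-14_s10-16_g12-18") and loss weights (loss_weights), a contrastive_loss_weight, optional DINOv2 feature supervision (use_dinov2, dinov2_layers, dinov2_loss_weight), task-token options (num_task_tokens, task_token_format), and switches to unfreeze the vision tower or the whole model (unfreeze_mm_vision_tower, unfreeze_whole_model). +#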
+ +import os +import copy +from dataclasses import dataclass, field +import json +import logging +import pathlib +from typing import Dict, Optional, Sequence, List + +import numpy as np +import torch + +import transformers +import tokenizers + +from ola_vlm.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from torch.utils.data import Dataset +from ola_vlm.train.llava_trainer import LLaVATrainer + +from ola_vlm import conversation as conversation_lib +from ola_vlm.model import * +from ola_vlm.mm_utils import tokenizer_image_token + +from PIL import Image, ImageFile +from transformers import set_seed + +set_seed(42) + +# Enable loading of truncated images +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +local_rank = None + + +def rank0_print(*args): + if local_rank == 0: + print(*args) + + +from packaging import version +IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14') + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + version: Optional[str] = field(default="v0") + freeze_backbone: bool = field(default=False) + tune_mm_mlp_adapter: bool = field(default=False) + unfreeze_mm_vision_tower: bool = field(default=False) + unfreeze_whole_model: bool = field(default=False) + use_s2: bool = field(default=False) + s2_scales: Optional[str] = field(default="336,1008") + vision_tower: Optional[str] = field(default=None) + mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer + pretrain_mm_mlp_adapter: Optional[str] = field(default=None) + mm_projector_type: Optional[str] = field(default='linear') + mm_use_im_start_end: bool = field(default=False) + mm_use_im_patch_token: bool = field(default=True) + mm_patch_merge_type: Optional[str] = field(default='flat') + mm_vision_select_feature: Optional[str] = field(default="patch") + + attn_mask_type: Optional[str] = field(default="causal") + + contrastive_loss_weight: Optional[float] = field(default=0.1) + + # visual interpretors + image_generator: Optional[str] = field(default="stabilityai/stable-diffusion-2-1-unclip") + image_segmentor: Optional[str] = field(default="shi-labs/oneformer_coco_swin_large") # sam_vit_l_0b3195.pth + depth_estimator: Optional[str] = field(default="depth_anything_v2_vitl.pth") + + mode: Optional[str] = field(default="depth-seg-gen") + num_task_tokens: Optional[int] = 0 + task_token_format: Optional[str] = "expand_emb" + sample_tokens: Optional[bool] = False + pass_text_to_aux: Optional[bool] = False + + # dinov2 + use_dinov2: Optional[bool] = False + dinov2_model: Optional[str] = "/mnt/projects4jw/jiteshjain_sherlock/dinov2-large-res336" + dinov2_dim: Optional[str] = 1024 + dinov2_layers: Optional[str] = "8-12" + dinov2_loss_weight: Optional[float] = 0.25 + + use_contrastive: Optional[bool] = True + use_ce: Optional[bool] = False + layer_indices: Optional[str] = "d8-14_s10-16_g12-18" + loss_weights: Optional[str] = "d0.5_s0.5_g0.5" + + # gen + img_head_depth: Optional[int] = 1 + img_head_dim_head: Optional[int] = 32 + img_head_num_heads: Optional[int] = 4 + img_head_num_tokens: Optional[int] = 1 + img_head_output_dim: Optional[int] = 1024 + img_head_ff_mult: Optional[int] = 1 + + # seg + seg_head_depth: Optional[int] = 1 + seg_head_dim_head: Optional[int] = 32 + seg_head_num_heads: Optional[int] = 4 + seg_head_num_tokens: Optional[int] = 576 + seg_head_output_dim: Optional[int] = 1536 # 256 + seg_head_ff_mult: Optional[int] = 1 + 
seg_teacher: Optional[str] = "oneformer" # "sam" + + # depth + depth_head_depth: Optional[int] = 1 + depth_head_dim_head: Optional[int] = 32 + depth_head_num_heads: Optional[int] = 4 + depth_head_num_tokens: Optional[int] = 576 + depth_head_output_dim: Optional[int] = 1024 + depth_head_ff_mult: Optional[int] = 1 + use_intermediate_depth: Optional[bool] = False + + freeze_task_token: Optional[bool] = field(default=False) + freeze_aux_heads: Optional[bool] = field(default=False) + use_reference_model: Optional[bool] = field(default=False) + +@dataclass +class DataArguments: + data_path: str = field(default=None, + metadata={"help": "Path to the training data."}) + lazy_preprocess: bool = False + is_multimodal: bool = False + image_folder: Optional[str] = field(default=None) + depth_folder: Optional[str] = field(default=None) + unclip_folder: Optional[str] = field(default=None) + seg_folder: Optional[str] = field(default=None) + image_aspect_ratio: str = 'square' + use_cost: bool = False + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + remove_unused_columns: bool = field(default=False) + freeze_mm_mlp_adapter: bool = field(default=False) + mpt_attn_impl: Optional[str] = field(default="triton") + model_max_length: int = field( + default=512, + metadata={ + "help": + "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + double_quant: bool = field( + default=True, + metadata={"help": "Compress the quantization statistics through double quantization."} + ) + quant_type: str = field( + default="nf4", + metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."} + ) + bits: int = field( + default=16, + metadata={"help": "How many bits to use."} + ) + lora_enable: bool = False + lora_r: int = 64 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_weight_path: str = "" + lora_bias: str = "none" + mm_projector_lr: Optional[float] = None + mm_vision_lr: Optional[float] = None + group_by_modality_length: bool = field(default=False) + + +def maybe_zero_3(param, ignore_status=False, name=None): + from deepspeed import zero + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +# Borrowed from peft.utils.get_peft_model_state_dict +def get_peft_state_maybe_zero_3(named_params, bias): + if bias == "none": + to_return = {k: t for k, t in named_params if "lora_" in k} + elif bias == "all": + to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + maybe_lora_bias = {} + lora_bias_names = set() + for k, t in named_params: + if "lora_" in k: + to_return[k] = t + bias_name = k.split("lora_")[0] + "bias" + lora_bias_names.add(bias_name) + elif "bias" in k: + maybe_lora_bias[k] = t + for k, t in maybe_lora_bias: + if bias_name in lora_bias_names: + to_return[bias_name] = t + else: + raise NotImplementedError + to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} + return to_return + + +def get_peft_state_non_lora_maybe_zero_3(named_params, 
require_grad_only=True): + to_return = {k: t for k, t in named_params if "lora_" not in k} + if require_grad_only: + to_return = {k: t for k, t in to_return.items() if t.requires_grad} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): + to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def find_all_linear_names(model): + cls = torch.nn.Linear + lora_module_names = set() + multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler'] + for name, module in model.named_modules(): + if any(mm_keyword in name for mm_keyword in multimodal_keywords): + continue + if isinstance(module, cls): + names = name.split('.') + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + if 'lm_head' in lora_module_names: # needed for 16-bit + lora_module_names.remove('lm_head') + return list(lora_module_names) + + +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, + output_dir: str): + """Collects the state dict and dump to disk.""" + + if getattr(trainer.args, "tune_mm_mlp_adapter", False): + # Only save Adapter + keys_to_match = ['mm_projector'] + if getattr(trainer.args, "use_im_start_end", False): + keys_to_match.extend(['embed_tokens', 'embed_in']) + + weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match) + trainer.model.config.save_pretrained(output_dir) + + current_folder = output_dir.split('/')[-1] + parent_folder = os.path.dirname(output_dir) + if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: + if current_folder.startswith('checkpoint-'): + mm_projector_folder = os.path.join(parent_folder, "mm_projector") + os.makedirs(mm_projector_folder, exist_ok=True) + torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin')) + else: + torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) + + if trainer.deepspeed: + torch.cuda.synchronize() + trainer.save_model(output_dir) + return + + state_dict = trainer.model.state_dict() + if trainer.args.should_save: + cpu_state_dict = { + key: value.cpu() + for key, value in state_dict.items() + } + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def smart_tokenizer_and_embedding_resize( + special_tokens_dict: Dict, + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, +): + """Resize tokenizer and embedding. + + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 
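+    Newly added tokens are initialized to the mean of the existing input and output embedding rows, respectively.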
+ """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + +def _tokenize_fn(strings: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer) -> Dict: + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ) for text in strings + ] + input_ids = labels = [ + tokenized.input_ids[0] for tokenized in tokenized_list + ] + input_ids_lens = labels_lens = [ + tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() + for tokenized in tokenized_list + ] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + + +def _mask_targets(target, tokenized_lens, speakers): + # cur_idx = 0 + cur_idx = tokenized_lens[0] + tokenized_lens = tokenized_lens[1:] + target[:cur_idx] = IGNORE_INDEX + for tokenized_len, speaker in zip(tokenized_lens, speakers): + if speaker == "human": + target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX + cur_idx += tokenized_len + + +def _add_speaker_and_signal(header, source, get_conversation=True): + """Add speaker and start/end signal on each round.""" + BEGIN_SIGNAL = "### " + END_SIGNAL = "\n" + conversation = header + for sentence in source: + from_str = sentence["from"] + if from_str.lower() == "human": + from_str = conversation_lib.default_conversation.roles[0] + elif from_str.lower() == "gpt": + from_str = conversation_lib.default_conversation.roles[1] + else: + from_str = 'unknown' + sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + + sentence["value"] + END_SIGNAL) + if get_conversation: + conversation += sentence["value"] + conversation += BEGIN_SIGNAL + return conversation + + +def preprocess_multimodal( + sources: Sequence[str], + data_args: DataArguments +) -> Dict: + is_multimodal = data_args.is_multimodal + if not is_multimodal: + return sources + + for source in sources: + for sentence in source: + if DEFAULT_IMAGE_TOKEN in sentence['value']: + sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() + sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] + sentence['value'] = sentence['value'].strip() + if "mmtag" in conversation_lib.default_conversation.version: + sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '' + DEFAULT_IMAGE_TOKEN + '') + replace_token = DEFAULT_IMAGE_TOKEN + if data_args.mm_use_im_start_end: + replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) + + return sources + + +def preprocess_phi_3( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if 
it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + assert conv.sep_style == conversation_lib.SeparatorStyle.MPT + + # Mask targets + sep = conv.sep + conv.roles[1] + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep) + re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt + for conv_idx in range(3, len(rounds), 2): + re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(re_rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + if i > 0: + round_len -= 2 + instruction_len -= 2 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
+ f" (ignored)" + ) + return dict( + input_ids=input_ids, + labels=targets, + ) + +def preprocess_llama_3( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack( + [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + assert conv.sep_style == conversation_lib.SeparatorStyle.MPT + + # Mask targets + sep = conv.sep + conv.roles[1] + + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep) + re_rounds = [conv.sep.join(rounds[:3])] + for conv_idx in range(3, len(rounds), 2): + re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx + 2])) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + + for i, rou in enumerate(re_rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + # if i > 0: + # round_len -= 1 + # instruction_len -= 1 + + target[cur_len: cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
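# Sketch of the round regrouping used by preprocess_phi_3 / preprocess_llama_3
# above before masking: the rendered prompt is split on the separator, the
# first three chunks (system + user + assistant) are kept as one round, and
# the remainder is regrouped into user/assistant pairs. Pure-Python toy with
# a made-up separator; the real code splits on conv.sep from the template.
sep = "<|sep|>"
rounds = ["SYSTEM", "USER 1", "GPT 1", "USER 2", "GPT 2"]

re_rounds = [sep.join(rounds[:3])]                  # system + user + gpt
for idx in range(3, len(rounds), 2):
    re_rounds.append(sep.join(rounds[idx:idx + 2])) # user + gpt

print(re_rounds)
# ['SYSTEM<|sep|>USER 1<|sep|>GPT 1', 'USER 2<|sep|>GPT 2']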
+ f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + +def preprocess_llama_2( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2 + + # Mask targets + sep = "[/INST] " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
+ f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_v1( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.TWO + + # Mask targets + sep = conv.sep + conv.roles[1] + ": " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14: + round_len -= 1 + instruction_len -= 1 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" (ignored)" + ) + return dict( + input_ids=input_ids, + labels=targets, + ) + +def preprocess_qwen( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False, + system_message: str = "You are a helpful assistant." 
+ ) -> Dict: + # roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"} + roles = {"human": "user", "gpt": "assistant"} + + # Add image tokens to tokenizer as a special tokens + # Use a deepcopy of tokenizer so that we don't modify on the tokenizer + tokenizer = copy.deepcopy(tokenizer) + # When there is actually an image, we add the image tokens as a special token + if has_image: + tokenizer.add_tokens([""], special_tokens=True) + + image_token_index = tokenizer.convert_tokens_to_ids("") + im_start, im_end = tokenizer.additional_special_tokens_ids + # unmask_tokens = ["<|im_start|>", "<|im_start|>", "\n"] + unmask_tokens_idx = [198, im_start, im_end] + nl_tokens = tokenizer("\n").input_ids + + # Reset Qwen chat templates so that it won't include system message every time we apply + chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + tokenizer.chat_template = chat_template + + # _system = tokenizer("system").input_ids + nl_tokens + # _user = tokenizer("user").input_ids + nl_tokens + # _assistant = tokenizer("assistant").input_ids + nl_tokens + + # Apply prompt templates + input_ids, targets = [], [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != roles["human"]: + source = source[1:] + + input_id, target = [], [] + + # New version, use apply chat template + # Build system message for each sentence + input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}]) + target += [IGNORE_INDEX] * len(input_id) + + for conv in source: + # Make sure llava data can load + try: + role = conv["role"] + content = conv["content"] + except: + role = conv["from"] + content = conv["value"] + + role = roles.get(role, role) + + conv = [{"role" : role, "content" : content}] + encode_id = tokenizer.apply_chat_template(conv) + input_id += encode_id + if role in ["user", "system"]: + target += [IGNORE_INDEX] * len(encode_id) + else: + target += encode_id + + assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}" + for idx, encode_id in enumerate(input_id): + if encode_id in unmask_tokens_idx: + target[idx] = encode_id + if encode_id == image_token_index: + input_id[idx] = IMAGE_TOKEN_INDEX + input_ids.append(input_id) + targets.append(target) + input_ids = torch.tensor(input_ids, dtype=torch.long) + targets = torch.tensor(targets, dtype=torch.long) + + return dict( + input_ids=input_ids, # tensor(bs x seq_len) + labels=targets, # tensor(bs x seq_len) + ) + + +def preprocess( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + """ + Given a list of sources, each is a conversation list. This transform: + 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; + 2. Concatenate conversations together; + 3. Tokenize the concatenated conversation; + 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. 
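# Self-contained sketch of step 4 above: labels start as a copy of input_ids
# and every span belonging to the header or a human turn is overwritten with
# IGNORE_INDEX (-100 in LLaVA-style constants), so only assistant tokens carry
# loss. Simplified relative to _mask_targets, which keeps the first two
# tokens of each human speaker signal; the lengths here are made up.
import torch

IGNORE_INDEX = -100
input_ids = torch.arange(1, 13)            # 12 toy token ids
labels = input_ids.clone()

tokenized_lens = [3, 4, 5]                 # header, human turn, gpt turn
speakers = ["human", "gpt"]

cur = tokenized_lens[0]
labels[:cur] = IGNORE_INDEX                # mask the header
for length, speaker in zip(tokenized_lens[1:], speakers):
    if speaker == "human":
        labels[cur:cur + length] = IGNORE_INDEX
    cur += length

print(labels.tolist())
# [-100, -100, -100, -100, -100, -100, -100, 8, 9, 10, 11, 12]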
+ """ + if conversation_lib.default_conversation.version == "llama3": + return preprocess_llama_3(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version == "phi3": + return preprocess_phi_3(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version == "qwen": + return preprocess_qwen(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2: + return preprocess_llama_2(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version.startswith("v1"): + return preprocess_v1(sources, tokenizer, has_image=has_image) + # add end signal and concatenate together + conversations = [] + for source in sources: + header = f"{conversation_lib.default_conversation.system}\n\n" + conversation = _add_speaker_and_signal(header, source) + conversations.append(conversation) + # tokenize conversations + def get_tokenize_len(prompts): + return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts] + + if has_image: + input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] + else: + conversations_tokenized = _tokenize_fn(conversations, tokenizer) + input_ids = conversations_tokenized["input_ids"] + + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + if has_image: + tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source]) + else: + tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"] + speakers = [sentence["from"] for sentence in source] + _mask_targets(target, tokenized_lens, speakers) + + return dict(input_ids=input_ids, labels=targets) + + +def read_jsonl(path): + list_data_dict = [] + + with open(path, "r") as file: + for line in file: + d = json.loads(line) + list_data_dict.append(d) + return list_data_dict + + +def _obtain_seg_texts(file_path): + def _remove_specific_word(text, word_to_remove): + import re + tokens = re.findall(r'\b\w+\b|[,.]', text) + result_tokens = [] + word_found = False + + for i, token in enumerate(tokens): + if token == word_to_remove: + if not word_found: + # Keep the first occurrence and mark it as found + result_tokens.append(token) + word_found = True + else: + # Remove any preceding punctuation if it's just before this word + if i > 0 and tokens[i-1] in {',', '.'}: + result_tokens.pop() + else: + result_tokens.append(token) + + # Join tokens and clean up spaces before punctuation + result_text = ' '.join(result_tokens) + result_text = re.sub(r'\s([,.](?:\s|$))', r'\1', result_text) + return result_text + + with open(file_path) as f: + lines = f.readlines() + + seg_labels = {} + for line in lines: + key = line.split("")[1].strip("\n") + label = line.split("")[2].strip("\n") + label = _remove_specific_word(label, "wall") + label = _remove_specific_word(label, "window") + seg_labels[key] = label + + return seg_labels + +from ola_vlm.ola_utils import PANOPTIC_QUESTIONS, SEMANTIC_QUESTIONS, INSTANCE_QUESTIONS +import random +def get_object_data_split(data_args): + list_data_dict = [] + for bucket in ["train"]: + panoptic_labels = _obtain_seg_texts(os.path.join(data_args.image_folder, "coco", "panoptic.txt")) + semantic_labels = _obtain_seg_texts(os.path.join(data_args.image_folder, "coco", "semantic.txt")) + instance_labels = _obtain_seg_texts(os.path.join(data_args.image_folder, "coco", "instance.txt")) + + for key in panoptic_labels.keys(): + 
assert key in semantic_labels.keys() and key in instance_labels.keys(), "Instance, semantic, and panoptic labels should have the same keys." + prob_task = np.random.uniform(0,1.) + question_prob = np.random.uniform(0,1.) + if prob_task < 0.33: + answer = semantic_labels[key] + if question_prob > 0.90: + question = "What objects can be seen in the image?" + else: + question = random.choice(SEMANTIC_QUESTIONS) + elif prob_task < 0.66: + answer = instance_labels[key] + if question_prob > 0.90: + question = "What objects can be seen in the image?" + else: + question = random.choice(INSTANCE_QUESTIONS) + else: + answer = panoptic_labels[key] + if question_prob > 0.90: + question = "What objects can be seen in the image?" + else: + question = random.choice(PANOPTIC_QUESTIONS) + + question += "\n" + conversations = [ + { + "from": "human", + "value": question + }, + { + "from": "gpt", + "value": answer + }, + ] + list_data_dict.append( + { + "conversations": conversations, + "image": "coco/" + bucket + "2017/" + key, + } + ) + + random.shuffle(list_data_dict) + return list_data_dict + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, data_path: str, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments): + super(LazySupervisedDataset, self).__init__() + if "jsonl" in data_path: + list_data_dict = read_jsonl(data_path) + else: + list_data_dict = json.load(open(data_path, "r")) + + rank0_print("Formatting inputs...Skip in lazy mode") + self.tokenizer = tokenizer + self.list_data_dict = list_data_dict + self.data_args = data_args + + if data_args.use_cost: + cost_list_data = get_object_data_split(data_args) + self.list_data_dict.extend(cost_list_data) + + def __len__(self): + return len(self.list_data_dict) + + @property + def lengths(self): + length_list = [] + for sample in self.list_data_dict: + img_tokens = 128 if 'image' in sample else 0 + length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens) + return length_list + + @property + def modality_lengths(self): + length_list = [] + for sample in self.list_data_dict: + cur_len = sum(len(conv['value'].split()) for conv in sample['conversations']) + cur_len = cur_len if 'image' in sample else -cur_len + length_list.append(cur_len) + return length_list + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + sources = self.list_data_dict[i] + if isinstance(i, int): + sources = [sources] + assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME + + if 'image' in sources[0]: + image_file = self.list_data_dict[i]['image'] + image_folder = self.data_args.image_folder + processor = self.data_args.image_processor + try: + crop_size = self.data_args.image_processor.crop_size + except: + crop_size = self.data_args.image_processor.size + + try: + image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') + pil_image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') + except Exception as e: + from icecream import ic + ic("----------------------------------") + ic("OS ERROROROROROROROROROROOR") + ic("OS ERROROROROROROROROROROOR") + ic(image_file) + ic(e) + ic("OS ERROROROROROROROROROROOR") + ic("OS ERROROROROROROROROROROOR") + ic("===================================") + return self.__getitem__(0) + + if self.data_args.image_aspect_ratio == 'pad': + def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + 
result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + image = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) + image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + else: + image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + sources = preprocess_multimodal( + copy.deepcopy([e["conversations"] for e in sources]), + self.data_args) + else: + sources = copy.deepcopy([e["conversations"] for e in sources]) + + data_dict = preprocess( + sources, + self.tokenizer, + has_image=('image' in self.list_data_dict[i])) + if isinstance(i, int): + data_dict = dict(input_ids=data_dict["input_ids"][0], + labels=data_dict["labels"][0]) + + # image exist in the data + if 'image' in self.list_data_dict[i]: + data_dict['image'] = image + data_dict['pil_image'] = pil_image + data_dict['seg_mask'] = 1 + data_dict['depth_mask'] = 1 + data_dict['gen_mask'] = 1 + elif self.data_args.is_multimodal: + try: + crop_size = self.data_args.image_processor.crop_size + except: + crop_size = self.data_args.image_processor.size + data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width']) + data_dict['pil_image'] = Image.new('RGB', (crop_size['width'], crop_size['height']), color='black') + data_dict['seg_mask'] = 0 + data_dict['depth_mask'] = 0 + data_dict['gen_mask'] = 0 + + return data_dict + + +@dataclass +class DataCollatorForSupervisedDataset(object): + """Collate examples for supervised fine-tuning.""" + + tokenizer: transformers.PreTrainedTokenizer + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + input_ids, labels = tuple([instance[key] for instance in instances] + for key in ("input_ids", "labels")) + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id) + labels = torch.nn.utils.rnn.pad_sequence(labels, + batch_first=True, + padding_value=IGNORE_INDEX) + input_ids = input_ids[:, :self.tokenizer.model_max_length] + labels = labels[:, :self.tokenizer.model_max_length] + batch = dict( + input_ids=input_ids, + labels=labels, + attention_mask=input_ids.ne(self.tokenizer.pad_token_id), + ) + + if 'image' in instances[0]: + images = [instance['image'] for instance in instances] + if all(x is not None and x.shape == images[0].shape for x in images): + batch['images'] = torch.stack(images) + else: + batch['images'] = images + + if 'pil_image' in instances[0]: + pil_images = [instance['pil_image'] for instance in instances] + batch['pil_images'] = pil_images + + seg_mask = [instance['seg_mask'] for instance in instances] + batch['seg_mask'] = torch.tensor(seg_mask) + + depth_mask = [instance['depth_mask'] for instance in instances] + batch['depth_mask'] = torch.tensor(depth_mask) + + gen_mask = [instance['gen_mask'] for instance in instances] + batch['gen_mask'] = torch.tensor(gen_mask) + + return batch + + +def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, + data_args) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + train_dataset = LazySupervisedDataset(tokenizer=tokenizer, + data_path=data_args.data_path, + data_args=data_args) + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + return dict(train_dataset=train_dataset, + eval_dataset=None, 
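# Sketch of the batching strategy in DataCollatorForSupervisedDataset above:
# input_ids are right-padded with the tokenizer's pad id, labels with
# IGNORE_INDEX so padding never contributes to the loss, and the attention
# mask is simply "not a pad token". Toy tensors; pad_token_id = 0 is an
# assumption for the example.
import torch
from torch.nn.utils.rnn import pad_sequence

IGNORE_INDEX = -100
PAD_TOKEN_ID = 0

input_ids = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
labels = [torch.tensor([IGNORE_INDEX, 6, 7]), torch.tensor([IGNORE_INDEX, 9])]

batch_ids = pad_sequence(input_ids, batch_first=True, padding_value=PAD_TOKEN_ID)
batch_labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
attention_mask = batch_ids.ne(PAD_TOKEN_ID)

print(batch_ids.tolist())       # [[5, 6, 7], [8, 9, 0]]
print(batch_labels.tolist())    # [[-100, 6, 7], [-100, 9, -100]]
print(attention_mask.tolist())  # [[True, True, True], [True, True, False]]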
+ data_collator=data_collator) + +def add_special_tokens( + special_tokens: List, + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, +): + """Resize tokenizer and embedding. + + Initialize new token embeddings to follow the distribution of existing embeddings. + """ + # Add special tokens to tokenizer + num_new_tokens = tokenizer.add_tokens(special_tokens, special_tokens=True) + # Resize the token embeddings in the model + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + # Get input embeddings and compute global mean and std over all dimensions + input_embeddings = model.get_input_embeddings().weight.data + input_mean, input_std = input_embeddings.mean(), input_embeddings.std() + + # Initialize new input embeddings with the same distribution as existing ones + input_embeddings[-num_new_tokens:] = torch.nn.init.normal_( + torch.empty(num_new_tokens, input_embeddings.size(1)), + mean=input_mean.item(), + std=input_std.item() + ) + + # Check if model has output embeddings and initialize them similarly + if model.get_output_embeddings() is not None: + output_embeddings = model.get_output_embeddings().weight.data + output_mean, output_std = output_embeddings.mean(), output_embeddings.std() + + # Initialize new output embeddings with the same distribution as existing ones + output_embeddings[-num_new_tokens:] = torch.nn.init.normal_( + torch.empty(num_new_tokens, output_embeddings.size(1)), + mean=output_mean.item(), + std=output_std.item() + ) + +def train(attn_implementation=None): + global local_rank + + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + local_rank = training_args.local_rank + compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) + + bnb_model_from_pretrained_args = {} + if training_args.bits in [4, 8]: + from transformers import BitsAndBytesConfig + bnb_model_from_pretrained_args.update(dict( + device_map={"": training_args.device}, + load_in_4bit=training_args.bits == 4, + load_in_8bit=training_args.bits == 8, + quantization_config=BitsAndBytesConfig( + load_in_4bit=training_args.bits == 4, + load_in_8bit=training_args.bits == 8, + llm_int8_skip_modules=["mm_projector"], + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=compute_dtype, + bnb_4bit_use_double_quant=training_args.double_quant, + bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'} + ) + )) + + if model_args.vision_tower is not None: + if 'phi' in model_args.model_name_or_path.lower(): + model = OlaLlavaPhi3ForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + elif 'qwen' in model_args.model_name_or_path.lower(): + model = OlaLlavaQwenForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + else: + model = OlaLlavaLlamaForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + 
**bnb_model_from_pretrained_args + ) + else: + if 'phi' in model_args.model_name_or_path.lower(): + model = transformers.Phi3ForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + elif 'qwen2' in model_args.model_name_or_path.lower(): + model = transformers.Qwen2ForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + else: + model = transformers.LlamaForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + model.config.use_cache = False + + if model_args.freeze_backbone: + model.model.requires_grad_(False) + + if training_args.bits in [4, 8]: + from peft import prepare_model_for_kbit_training + model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) + model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) + + if training_args.gradient_checkpointing: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if training_args.lora_enable: + from peft import LoraConfig, get_peft_model + lora_config = LoraConfig( + r=training_args.lora_r, + lora_alpha=training_args.lora_alpha, + target_modules=find_all_linear_names(model), + lora_dropout=training_args.lora_dropout, + bias=training_args.lora_bias, + task_type="CAUSAL_LM", + ) + if training_args.bits == 16: + if training_args.bf16: + model.to(torch.bfloat16) + if training_args.fp16: + model.to(torch.float16) + rank0_print("Adding LoRA adapters...") + model = get_peft_model(model, lora_config) + + if "qwen" in model_args.model_name_or_path.lower(): + tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right") + else: + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=False, + ) + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.unk_token + if tokenizer.pad_token_id is None: + smart_tokenizer_and_embedding_resize( + special_tokens_dict=dict(pad_token=""), + tokenizer=tokenizer, + model=model, + ) + + if model_args.version in conversation_lib.conv_templates: + conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version] + else: + conversation_lib.default_conversation = conversation_lib.conv_templates["llava_phi_3"] + + if "sherlock" in model_args.model_name_or_path: + vision_tower = model.get_vision_tower() + + if vision_tower is None: + model.get_model().initialize_vision_modules( + model_args=model_args, + fsdp=training_args.fsdp + ) + vision_tower = model.get_vision_tower() + + if not vision_tower.is_loaded: + 
vision_tower.load_model() + vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device) + + elif model_args.vision_tower is not None: + model.get_model().initialize_vision_modules( + model_args=model_args, + fsdp=training_args.fsdp + ) + + vision_tower = model.get_vision_tower() + vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device) + + data_args.image_processor = vision_tower.image_processor + data_args.is_multimodal = True + + model.config.image_grid_pinpoints = [[336,672], [672,336], [672,672], [1008,336], [336,1008]] + model.config.image_aspect_ratio = data_args.image_aspect_ratio + model.config.tokenizer_padding_side = tokenizer.padding_side + model.config.tokenizer_model_max_length = tokenizer.model_max_length + + model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter + if model_args.tune_mm_mlp_adapter: + model.requires_grad_(False) + for p in model.get_model().mm_projector.parameters(): + p.requires_grad = True + + model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter + if training_args.freeze_mm_mlp_adapter: + for p in model.get_model().mm_projector.parameters(): + p.requires_grad = False + + if training_args.bits in [4, 8]: + model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device) + + model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end + model.config.mm_projector_lr = training_args.mm_projector_lr + training_args.use_im_start_end = model_args.mm_use_im_start_end + model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token + model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer) + + if model_args.unfreeze_mm_vision_tower: + model.requires_grad_(False) + for p in model.get_model().mm_projector.parameters(): + p.requires_grad = True + model.get_model().vision_tower.requires_grad_(True) + else: + model.get_model().vision_tower.requires_grad_(False) + + model.config.use_s2 = model_args.use_s2 + model.config.s2_scales = model_args.s2_scales + + if "sherlock" not in model_args.model_name_or_path.split("/")[-1]: + aux_mode = model_args.mode + model.config.aux_mode = model_args.mode + model.config.contrastive_loss_weight = model_args.contrastive_loss_weight + model.config.num_task_tokens = model_args.num_task_tokens + model.config.task_token_format = model_args.task_token_format + model.config.pass_text_to_aux = model_args.pass_text_to_aux + model.config.use_contrastive = model_args.use_contrastive + model.config.use_ce = model_args.use_ce + + layer_indices = model_args.layer_indices + + pattern = r'[a-zA-Z]\d+(?:-\d+)?' 
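# Illustration of how model_args.layer_indices is unpacked by the loop that
# follows: the regex pulls out chunks such as "d18", "s18-20" or "g20", and
# the leading letter routes the digits to the depth ("d"), segmentation ("s")
# or generation ("g") head config. The sample string is an assumption about
# the argument format, not a value taken from this diff.
import re

pattern = r'[a-zA-Z]\d+(?:-\d+)?'
layer_indices = "d18_s18-20_g20"           # hypothetical argument value

depth, seg, gen = "0", "0", "0"
for match in re.findall(pattern, layer_indices):
    if match.startswith("d"):
        depth = match[1:]
    elif match.startswith("s"):
        seg = match[1:]
    elif match.startswith("g"):
        gen = match[1:]

print(depth, seg, gen)                     # 18 18-20 20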
+ + import re + # Extract matching substrings from each string + matches = re.findall(pattern, layer_indices) + + depth_layer_indices = "0" + seg_layer_indices = "0" + img_layer_indices = "0" + + for match in matches: + if match.startswith('d'): + depth_layer_indices = match[1:] + elif match.startswith('s'): + seg_layer_indices = match[1:] + elif match.startswith('g'): + img_layer_indices = match[1:] + + loss_weights = model_args.loss_weights + + pattern = r'[a-zA-Z]\d+\.\d+' + matches = re.findall(pattern, loss_weights) + + img_loss_weight = 0.5 + seg_loss_weight = 0.5 + depth_loss_weight = 0.5 + + for match in matches: + if match.startswith('d'): + depth_loss_weight = float(match[1:]) + elif match.startswith('s'): + seg_loss_weight = float(match[1:]) + elif match.startswith('g'): + img_loss_weight = float(match[1:]) + + model.config.image_gen = { + "depth": model_args.img_head_depth, + "dim_head": model_args.img_head_dim_head, + "num_heads": model_args.img_head_num_heads, + "num_tokens": model_args.img_head_num_tokens, + "output_dim": model_args.img_head_output_dim, + "ff_mult": model_args.img_head_ff_mult, + "img_layer_indices": img_layer_indices, + "img_loss_weight": img_loss_weight, + } + model.config.image_generator = model_args.image_generator + + model.config.image_seg = { + "depth": model_args.seg_head_depth, + "dim_head": model_args.seg_head_dim_head, + "num_heads": model_args.seg_head_num_heads, + "num_tokens": model_args.seg_head_num_tokens, + "output_dim": model_args.seg_head_output_dim, + "ff_mult": model_args.seg_head_ff_mult, + "seg_layer_indices": seg_layer_indices, + "seg_loss_weight": seg_loss_weight, + "seg_teacher": model_args.seg_teacher, + } + model.config.image_segmentor = model_args.image_segmentor + + model.config.image_depth = { + "depth": model_args.depth_head_depth, + "dim_head": model_args.depth_head_dim_head, + "num_heads": model_args.depth_head_num_heads, + "num_tokens": model_args.depth_head_num_tokens, + "output_dim": model_args.depth_head_output_dim, + "ff_mult": model_args.depth_head_ff_mult, + "depth_layer_indices": depth_layer_indices, + "depth_loss_weight": depth_loss_weight, + "use_intermediate_depth": model_args.use_intermediate_depth, + } + model.config.depth_estimator = model_args.depth_estimator + model.config.sample_tokens = model_args.sample_tokens + num_task_tokens = model_args.num_task_tokens + + if model_args.use_dinov2: + model.config.dinov2_feats = { + "model": model_args.dinov2_model, + "dinov2_layer_indices": model_args.dinov2_layers, + "dim": model_args.dinov2_dim, + "dinov2_loss_weight": model_args.dinov2_loss_weight, + } + + model.config.num_task_tokens = model_args.num_task_tokens + model.config.task_token_format = model_args.task_token_format + if model_args.num_task_tokens > 0: + if model_args.task_token_format == "text": + if "depth" in aux_mode: + special_depth_tokens = [f"" for i in range(num_task_tokens)] + special_depth_tokens_str = "".join(special_depth_tokens) + add_special_tokens( + special_tokens=special_depth_tokens, + tokenizer=tokenizer, + model=model, + ) + model.config.depth_tokens = tokenizer(special_depth_tokens_str).input_ids[1:] + if "seg" in aux_mode: + special_seg_tokens = [f"" for i in range(num_task_tokens)] + special_seg_tokens_str = "".join(special_seg_tokens) + add_special_tokens( + special_tokens=special_seg_tokens, + tokenizer=tokenizer, + model=model, + ) + model.config.seg_tokens = tokenizer(special_seg_tokens_str).input_ids[1:] + if "gen" in aux_mode: + special_gen_tokens = [f"" for i in 
range(num_task_tokens)] + special_gen_tokens_str = "".join(special_gen_tokens) + add_special_tokens( + special_tokens=special_gen_tokens, + tokenizer=tokenizer, + model=model, + ) + model.config.gen_tokens = tokenizer(special_gen_tokens_str).input_ids[1:] + + model.get_model().initialize_special_tokens(model.config) + + model.init_heads(model.config) + model.init_target_models(model.config) + elif model_args.unfreeze_whole_model: + model.requires_grad_(True) + elif model_args.unfreeze_mm_vision_tower: + if "depth" in model_args.mode: + for p in model.image_depth_heads.parameters(): + p.requires_grad = True + if "gen" in model_args.mode: + for p in model.image_gen_heads.parameters(): + p.requires_grad = True + if "seg" in model_args.mode: + for p in model.image_seg_heads.parameters(): + p.requires_grad = True + if "emb" in model.config.task_token_format and model.config.num_task_tokens > 0: + if "gen" in aux_mode: + model.get_model().special_gen_tokens.requires_grad_(True) + if "seg" in aux_mode: + model.get_model().special_seg_tokens.requires_grad_(True) + if "depth" in aux_mode: + model.get_model().special_depth_tokens.requires_grad_(True) + elif not model_args.tune_mm_mlp_adapter: + if "emb" in model.config.task_token_format and model.config.num_task_tokens > 0: + if "gen" in model.config.aux_mode: + model.get_model().special_gen_tokens.requires_grad_(False) + if "seg" in model.config.aux_mode: + model.get_model().special_seg_tokens.requires_grad_(False) + if "depth" in model.config.aux_mode: + model.get_model().special_depth_tokens.requires_grad_(False) + + + loss_weights = model_args.loss_weights + + import re + pattern = r'[a-zA-Z]\d+\.\d+' + matches = re.findall(pattern, loss_weights) + + img_loss_weight = 0.5 + seg_loss_weight = 0.5 + depth_loss_weight = 0.5 + + for match in matches: + if match.startswith('d'): + depth_loss_weight = float(match[1:]) + elif match.startswith('s'): + seg_loss_weight = float(match[1:]) + elif match.startswith('g'): + img_loss_weight = float(match[1:]) + + model.config.image_seg["seg_loss_weight"] = seg_loss_weight + model.config.image_gen["img_loss_weight"] = img_loss_weight + model.config.image_depth["depth_loss_weight"] = depth_loss_weight + + if model_args.use_reference_model: + model.init_reference_model() + + for name, p in model.named_parameters(): + if "sam." in name or "da_v2_head." in name or "dinov2_model." in name or "gen_encoder." in name or "dav2_backbone." in name or "oneformer." 
in name: + p.requires_grad = False + + model.img_gen_loss_weight = img_loss_weight + model.img_seg_loss_weight = seg_loss_weight + model.img_depth_loss_weight = depth_loss_weight + + if model_args.num_task_tokens > 0: + if "emb" in model.config.task_token_format and model_args.freeze_task_token: + if "gen" in model.config.aux_mode: + model.get_model().special_gen_tokens.requires_grad_(False) + if "seg" in model.config.aux_mode: + model.get_model().special_seg_tokens.requires_grad_(False) + if "depth" in model.config.aux_mode: + model.get_model().special_depth_tokens.requires_grad_(False) + else: + if "gen" in model.config.aux_mode: + model.get_model().special_gen_tokens.requires_grad_(True) + if "seg" in model.config.aux_mode: + model.get_model().special_seg_tokens.requires_grad_(True) + if "depth" in model.config.aux_mode: + model.get_model().special_depth_tokens.requires_grad_(True) + + if model_args.freeze_aux_heads: + model.get_model().vision_tower.requires_grad_(False) + if "depth" in model.config.aux_mode: + for p in model.image_depth_heads.parameters(): + p.requires_grad = False + model.depth_logit_scale.requires_grad_(False) + if "gen" in model.config.aux_mode: + for p in model.image_gen_heads.parameters(): + p.requires_grad = False + model.gen_logit_scale.requires_grad_(False) + if "seg" in model.config.aux_mode: + for p in model.image_seg_heads.parameters(): + p.requires_grad = False + model.seg_logit_scale.requires_grad_(False) + + import torch.distributed as dist + from icecream import ic + if dist.get_rank() == 0: + gen_heads = 0 + depth_heads = 0 + seg_heads = 0 + for n, p in model.named_parameters(): + if p.requires_grad: + if "gen_head" in n: + gen_heads += p.numel() + elif "depth_head" in n: + depth_heads += p.numel() + elif "seg_head" in n: + seg_heads += p.numel() + ic(n) + ic(depth_heads, gen_heads, seg_heads) + + if training_args.bits in [4, 8]: + from peft.tuners.lora import LoraLayer + for name, module in model.named_modules(): + if isinstance(module, LoraLayer): + if training_args.bf16: + module = module.to(torch.bfloat16) + if 'norm' in name: + module = module.to(torch.float32) + if 'lm_head' in name or 'embed_tokens' in name: + if hasattr(module, 'weight'): + if training_args.bf16 and module.weight.dtype == torch.float32: + module = module.to(torch.bfloat16) + + data_module = make_supervised_data_module(tokenizer=tokenizer, + data_args=data_args) + trainer = LLaVATrainer(model=model, + tokenizer=tokenizer, + args=training_args, + **data_module) + + print('starting training...', local_rank) + + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + trainer.save_state() + + model.config.use_cache = True + + if training_args.lora_enable: + state_dict = get_peft_state_maybe_zero_3( + model.named_parameters(), training_args.lora_bias + ) + non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( + model.named_parameters() + ) + if training_args.local_rank == 0 or training_args.local_rank == -1: + model.config.save_pretrained(training_args.output_dir) + model.save_pretrained(training_args.output_dir, state_dict=state_dict) + torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin')) + else: + safe_save_model_for_hf_trainer(trainer=trainer, + output_dir=training_args.output_dir) + + +if __name__ == "__main__": + train() diff --git a/ola_vlm/train/sherlock_dsg_train_mem.py b/ola_vlm/train/sherlock_dsg_train_mem.py new file mode 100644 index 
0000000000000000000000000000000000000000..a4c32d57e3f17f0e68f0b2e2975fd37fdd270fee --- /dev/null +++ b/ola_vlm/train/sherlock_dsg_train_mem.py @@ -0,0 +1,9 @@ +from ola_vlm.train.sherlock_dsg_train import train +import torch.multiprocessing as mp + +if __name__ == "__main__": + # mp.set_start_method('spawn') + # try: + # train(attn_implementation="flash_attention_2") + # except: + train(attn_implementation="eager") \ No newline at end of file diff --git a/ola_vlm/train/train.py b/ola_vlm/train/train.py new file mode 100644 index 0000000000000000000000000000000000000000..b364ffb534c2c96f1f04a3b5f8e60dbc59e4b5f4 --- /dev/null +++ b/ola_vlm/train/train.py @@ -0,0 +1,1392 @@ +# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: +# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import copy +from dataclasses import dataclass, field +import json +import logging +import pathlib +from typing import Dict, Optional, Sequence, List + +import torch + +import transformers +import tokenizers + +from ola_vlm.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from torch.utils.data import Dataset +from ola_vlm.train.llava_trainer import LLaVATrainer + +from llava import conversation as conversation_lib +from ola_vlm.model import * +from ola_vlm.mm_utils import tokenizer_image_token + +from PIL import Image, ImageFile +from transformers import set_seed + +set_seed(42) + +# Enable loading of truncated images +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +local_rank = None + + +def rank0_print(*args): + if local_rank == 0: + print(*args) + + +from packaging import version +IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14') + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + version: Optional[str] = field(default="v0") + freeze_backbone: bool = field(default=False) + tune_mm_mlp_adapter: bool = field(default=False) + unfreeze_mm_vision_tower: bool = field(default=False) + unfreeze_whole_model: bool = field(default=False) + use_s2: bool = field(default=False) + s2_scales: Optional[str] = field(default="336,1008") + vision_tower: Optional[str] = field(default=None) + mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer + pretrain_mm_mlp_adapter: Optional[str] = field(default=None) + mm_projector_type: Optional[str] = field(default='linear') + mm_use_im_start_end: bool = field(default=False) + mm_use_im_patch_token: bool = field(default=True) + mm_patch_merge_type: Optional[str] = field(default='flat') + mm_vision_select_feature: Optional[str] = field(default="patch") + + attn_mask_type: Optional[str] = field(default="causal") + freeze_task_token: Optional[bool] = field(default=True) + + +@dataclass +class 
DataArguments: + data_path: str = field(default=None, + metadata={"help": "Path to the training data."}) + lazy_preprocess: bool = False + is_multimodal: bool = False + image_folder: Optional[str] = field(default=None) + image_aspect_ratio: str = 'square' + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + remove_unused_columns: bool = field(default=False) + freeze_mm_mlp_adapter: bool = field(default=False) + mpt_attn_impl: Optional[str] = field(default="triton") + model_max_length: int = field( + default=512, + metadata={ + "help": + "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + double_quant: bool = field( + default=True, + metadata={"help": "Compress the quantization statistics through double quantization."} + ) + quant_type: str = field( + default="nf4", + metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."} + ) + bits: int = field( + default=16, + metadata={"help": "How many bits to use."} + ) + lora_enable: bool = False + lora_r: int = 64 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_weight_path: str = "" + lora_bias: str = "none" + mm_projector_lr: Optional[float] = None + mm_vision_lr: Optional[float] = None + group_by_modality_length: bool = field(default=False) + + +def maybe_zero_3(param, ignore_status=False, name=None): + from deepspeed import zero + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +# Borrowed from peft.utils.get_peft_model_state_dict +def get_peft_state_maybe_zero_3(named_params, bias): + if bias == "none": + to_return = {k: t for k, t in named_params if "lora_" in k} + elif bias == "all": + to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + maybe_lora_bias = {} + lora_bias_names = set() + for k, t in named_params: + if "lora_" in k: + to_return[k] = t + bias_name = k.split("lora_")[0] + "bias" + lora_bias_names.add(bias_name) + elif "bias" in k: + maybe_lora_bias[k] = t + for k, t in maybe_lora_bias: + if bias_name in lora_bias_names: + to_return[bias_name] = t + else: + raise NotImplementedError + to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} + return to_return + + +def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): + to_return = {k: t for k, t in named_params if "lora_" not in k} + if require_grad_only: + to_return = {k: t for k, t in to_return.items() if t.requires_grad} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): + to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def find_all_linear_names(model): + cls = torch.nn.Linear + lora_module_names = set() + multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler'] + for name, module 
in model.named_modules(): + if any(mm_keyword in name for mm_keyword in multimodal_keywords): + continue + if isinstance(module, cls): + names = name.split('.') + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + if 'lm_head' in lora_module_names: # needed for 16-bit + lora_module_names.remove('lm_head') + return list(lora_module_names) + + +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, + output_dir: str): + """Collects the state dict and dump to disk.""" + + if getattr(trainer.args, "tune_mm_mlp_adapter", False): + # Only save Adapter + keys_to_match = ['mm_projector'] + if getattr(trainer.args, "use_im_start_end", False): + keys_to_match.extend(['embed_tokens', 'embed_in']) + + weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match) + trainer.model.config.save_pretrained(output_dir) + + current_folder = output_dir.split('/')[-1] + parent_folder = os.path.dirname(output_dir) + if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: + if current_folder.startswith('checkpoint-'): + mm_projector_folder = os.path.join(parent_folder, "mm_projector") + os.makedirs(mm_projector_folder, exist_ok=True) + torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin')) + else: + torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) + + if trainer.deepspeed: + torch.cuda.synchronize() + + # print("Loading ckpts...") + # from safetensors import safe_open + # checkpoint_path = "/mnt/projects4jw/jiteshjain_sherlock/models/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup-res768-phi3-4k-mini-v-pretrain" + # safetensors_paths = [os.path.join(checkpoint_path, f) for f in os.listdir(checkpoint_path) if f.endswith('.safetensors')] + + # state_dict1 = {} + # for path in safetensors_paths: + # with safe_open(path, framework="pt", device="cpu") as f: + # for key in f.keys(): + # state_dict1[key] = f.get_tensor(key) + # # model.load_state_dict(state_dict, strict=False) + # state_dict = trainer.accelerator.get_state_dict(trainer.deepspeed) + # from tqdm import tqdm + # for param_name in tqdm(state_dict.keys()): + # if "vision_tower" in param_name: + # checkpoint_value = state_dict1[param_name] + # model_value = state_dict[param_name] + + # from icecream import ic + # if not torch.equal(checkpoint_value.to(device=model_value.device, dtype=model_value.dtype), model_value): + # ic(param_name, checkpoint_value.mean(), model_value.mean()) + + trainer.save_model(output_dir) + return + + state_dict = trainer.model.state_dict() + if trainer.args.should_save: + cpu_state_dict = { + key: value.cpu() + for key, value in state_dict.items() + } + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def smart_tokenizer_and_embedding_resize( + special_tokens_dict: Dict, + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, +): + """Resize tokenizer and embedding. + + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 
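# Standalone sketch of what find_all_linear_names above collects: the leaf
# names of every nn.Linear in the language model, skipping the multimodal
# modules (mm_projector / vision_tower / vision_resampler) and dropping
# lm_head, so LoRA is only attached to the LLM's own projections. Toy module
# for illustration only.
import torch.nn as nn

class ToyLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(8, 8)
        self.mm_projector = nn.Sequential(nn.Linear(8, 8))
        self.lm_head = nn.Linear(8, 16)

model = ToyLM()
skip = ("mm_projector", "vision_tower", "vision_resampler")
names = set()
for name, module in model.named_modules():
    if any(k in name for k in skip) or not isinstance(module, nn.Linear):
        continue
    names.add(name.split(".")[-1])
names.discard("lm_head")                   # excluded, per the "needed for 16-bit" note above

print(sorted(names))                       # ['q_proj']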
+ """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + +def _tokenize_fn(strings: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer) -> Dict: + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ) for text in strings + ] + input_ids = labels = [ + tokenized.input_ids[0] for tokenized in tokenized_list + ] + input_ids_lens = labels_lens = [ + tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() + for tokenized in tokenized_list + ] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + + +def _mask_targets(target, tokenized_lens, speakers): + # cur_idx = 0 + cur_idx = tokenized_lens[0] + tokenized_lens = tokenized_lens[1:] + target[:cur_idx] = IGNORE_INDEX + for tokenized_len, speaker in zip(tokenized_lens, speakers): + if speaker == "human": + target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX + cur_idx += tokenized_len + + +def _add_speaker_and_signal(header, source, get_conversation=True): + """Add speaker and start/end signal on each round.""" + BEGIN_SIGNAL = "### " + END_SIGNAL = "\n" + conversation = header + for sentence in source: + from_str = sentence["from"] + if from_str.lower() == "human": + from_str = conversation_lib.default_conversation.roles[0] + elif from_str.lower() == "gpt": + from_str = conversation_lib.default_conversation.roles[1] + else: + from_str = 'unknown' + sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + + sentence["value"] + END_SIGNAL) + if get_conversation: + conversation += sentence["value"] + conversation += BEGIN_SIGNAL + return conversation + + +def preprocess_multimodal( + sources: Sequence[str], + data_args: DataArguments +) -> Dict: + is_multimodal = data_args.is_multimodal + if not is_multimodal: + return sources + + for source in sources: + for sentence in source: + if DEFAULT_IMAGE_TOKEN in sentence['value']: + sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() + sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] + sentence['value'] = sentence['value'].strip() + if "mmtag" in conversation_lib.default_conversation.version: + sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '' + DEFAULT_IMAGE_TOKEN + '') + replace_token = DEFAULT_IMAGE_TOKEN + if data_args.mm_use_im_start_end: + replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) + + return sources + + +def preprocess_phi_3( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if 
it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + assert conv.sep_style == conversation_lib.SeparatorStyle.MPT + + # Mask targets + sep = conv.sep + conv.roles[1] + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep) + re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt + for conv_idx in range(3, len(rounds), 2): + re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(re_rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + if i > 0: + round_len -= 2 + instruction_len -= 2 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
+ f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + +def preprocess_llama_3( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack( + [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + assert conv.sep_style == conversation_lib.SeparatorStyle.MPT + + # Mask targets + sep = conv.sep + conv.roles[1] + + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep) + re_rounds = [conv.sep.join(rounds[:3])] + for conv_idx in range(3, len(rounds), 2): + re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx + 2])) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + + for i, rou in enumerate(re_rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + # if i > 0: + # round_len -= 1 + # instruction_len -= 1 + + target[cur_len: cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + +def preprocess_qwen( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False, + system_message: str = "You are a helpful assistant." 
+) -> Dict:
+    # roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
+    roles = {"human": "user", "gpt": "assistant"}
+
+    # Add the image token to the tokenizer as a special token
+    # Use a deepcopy of the tokenizer so that we don't modify the original
+    tokenizer = copy.deepcopy(tokenizer)
+    # When there is actually an image, we add the image token as a special token
+    if has_image:
+        tokenizer.add_tokens(["<image>"], special_tokens=True)
+
+    image_token_index = tokenizer.convert_tokens_to_ids("<image>")
+    im_start, im_end = tokenizer.additional_special_tokens_ids
+    # unmask_tokens = ["<|im_start|>", "<|im_start|>", "\n"]
+    unmask_tokens_idx = [198, im_start, im_end]
+    nl_tokens = tokenizer("\n").input_ids
+
+    # Reset the Qwen chat template so that it won't include a system message every time it is applied
+    chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+    tokenizer.chat_template = chat_template
+
+    # _system = tokenizer("system").input_ids + nl_tokens
+    # _user = tokenizer("user").input_ids + nl_tokens
+    # _assistant = tokenizer("assistant").input_ids + nl_tokens
+
+    # Apply prompt templates
+    input_ids, targets = [], []
+    for i, source in enumerate(sources):
+        if roles[source[0]["from"]] != roles["human"]:
+            source = source[1:]
+
+        input_id, target = [], []
+
+        # New version, use apply chat template
+        # Build system message for each sentence
+        input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}])
+        target += [IGNORE_INDEX] * len(input_id)
+
+        for conv in source:
+            # Make sure llava data can load
+            try:
+                role = conv["role"]
+                content = conv["content"]
+            except:
+                role = conv["from"]
+                content = conv["value"]
+
+            role = roles.get(role, role)
+
+            conv = [{"role" : role, "content" : content}]
+            encode_id = tokenizer.apply_chat_template(conv)
+            input_id += encode_id
+            if role in ["user", "system"]:
+                target += [IGNORE_INDEX] * len(encode_id)
+            else:
+                target += encode_id
+
+        assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
+        for idx, encode_id in enumerate(input_id):
+            if encode_id in unmask_tokens_idx:
+                target[idx] = encode_id
+            if encode_id == image_token_index:
+                input_id[idx] = IMAGE_TOKEN_INDEX
+        input_ids.append(input_id)
+        targets.append(target)
+    input_ids = torch.tensor(input_ids, dtype=torch.long)
+    targets = torch.tensor(targets, dtype=torch.long)
+
+    return dict(
+        input_ids=input_ids,  # tensor(bs x seq_len)
+        labels=targets,  # tensor(bs x seq_len)
+    )
+
+
+# def preprocess_llama_3(
+#     sources,
+#     tokenizer: transformers.PreTrainedTokenizer,
+#     has_image: bool = False,
+#     max_len=2048,
+#     system_message: str = "You are a helpful language and vision assistant.
You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.", +# ) -> Dict: +# # roles = {"human": "<|start_header_id|>user<|end_header_id|>", "gpt": "<|start_header_id|>assistant<|end_header_id|>"} +# roles = {"human": "user", "gpt": "assistant"} + +# # Add image tokens to tokenizer as a special tokens +# # Use a deepcopy of tokenizer so that we don't modify on the tokenizer +# tokenizer = copy.deepcopy(tokenizer) +# # When there is actually an image, we add the image tokens as a special token +# if has_image: +# tokenizer.add_tokens([""], special_tokens=True) +# image_token_index = tokenizer.convert_tokens_to_ids("") +# bos_token_id = tokenizer.convert_tokens_to_ids("<|begin_of_text|>") +# start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>") +# end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>") +# eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>") + +# unmask_tokens = ["<|begin_of_text|>", "<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>", "\n\n"] +# unmask_tokens_idx = [tokenizer.convert_tokens_to_ids(tok) for tok in unmask_tokens] + +# # After update, calling tokenizer of llama3 will +# # auto add bos id for the tokens. ヽ(ο½€βŒ’Β΄)οΎ‰ +# def safe_tokenizer_llama3(text): +# input_ids = tokenizer(text).input_ids +# if input_ids[0] == bos_token_id: +# input_ids = input_ids[1:] +# return input_ids + +# nl_tokens = tokenizer.convert_tokens_to_ids("\n\n") +# # Apply prompt templates +# input_ids, targets = [], [] +# for i, source in enumerate(sources): +# if roles[source[0]["from"]] != roles["human"]: +# source = source[1:] + +# input_id, target = [], [] + +# # New version, use apply chat template +# # Build system message for each sentence +# input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}]) +# target += [IGNORE_INDEX] * len(input_id) + +# for conv in source: +# # Make sure llava data can load +# try: +# role = conv["role"] +# content = conv["content"] +# except: +# role = conv["from"] +# content = conv["value"] + +# role = roles.get(role, role) + +# conv = [{"role" : role, "content" : content}] +# # First is bos token we don't need here +# encode_id = tokenizer.apply_chat_template(conv)[1:] +# input_id += encode_id +# if role in ["user", "system"]: +# target += [IGNORE_INDEX] * len(encode_id) +# else: +# target += encode_id + + + +# assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}" +# for idx, encode_id in enumerate(input_id): +# if encode_id in unmask_tokens_idx: +# target[idx] = encode_id +# if encode_id == image_token_index: +# input_id[idx] = IMAGE_TOKEN_INDEX +# input_ids.append(input_id) +# targets.append(target) +# input_ids = torch.tensor(input_ids, dtype=torch.long) +# targets = torch.tensor(targets, dtype=torch.long) + +# return dict( +# input_ids=input_ids, # tensor(bs x seq_len) +# labels=targets, # tensor(bs x seq_len) +# ) + + +def preprocess_llama_2( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], 
f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2 + + # Mask targets + sep = "[/INST] " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_v1( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.TWO + + # Mask targets + sep = conv.sep + conv.roles[1] + ": " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14: + round_len -= 1 + instruction_len -= 1 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + 
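+        # If the reconstructed round lengths do not line up with the number of
+        # non-pad tokens, the label mask is unreliable, so the whole sample is
+        # dropped from the loss (all labels set to IGNORE_INDEX) and a warning
+        # is printed.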
+ if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" (ignored)" + ) + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + """ + Given a list of sources, each is a conversation list. This transform: + 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; + 2. Concatenate conversations together; + 3. Tokenize the concatenated conversation; + 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. + """ + if conversation_lib.default_conversation.version == "llama3": + return preprocess_llama_3(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version == "phi3": + return preprocess_phi_3(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version == "qwen": + return preprocess_qwen(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2: + return preprocess_llama_2(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version.startswith("v1"): + return preprocess_v1(sources, tokenizer, has_image=has_image) + # add end signal and concatenate together + conversations = [] + for source in sources: + header = f"{conversation_lib.default_conversation.system}\n\n" + conversation = _add_speaker_and_signal(header, source) + conversations.append(conversation) + # tokenize conversations + def get_tokenize_len(prompts): + return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts] + + if has_image: + input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] + else: + conversations_tokenized = _tokenize_fn(conversations, tokenizer) + input_ids = conversations_tokenized["input_ids"] + + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + if has_image: + tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source]) + else: + tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"] + speakers = [sentence["from"] for sentence in source] + _mask_targets(target, tokenized_lens, speakers) + + return dict(input_ids=input_ids, labels=targets) + +def read_jsonl(path): + list_data_dict = [] + + with open(path, "r") as file: + for line in file: + d = json.loads(line) + list_data_dict.append(d) + return list_data_dict + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, data_path: str, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments): + super(LazySupervisedDataset, self).__init__() + if "jsonl" in data_path: + list_data_dict = read_jsonl(data_path) + else: + list_data_dict = json.load(open(data_path, "r")) + + rank0_print("Formatting inputs...Skip in lazy mode") + self.tokenizer = tokenizer + self.list_data_dict = list_data_dict + self.data_args = data_args + + def __len__(self): + return len(self.list_data_dict) + + @property + def lengths(self): + length_list = [] + for sample in self.list_data_dict: + img_tokens = 128 if 'image' in sample else 0 + length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens) + return length_list + + @property + def modality_lengths(self): + 
length_list = [] + for sample in self.list_data_dict: + cur_len = sum(len(conv['value'].split()) for conv in sample['conversations']) + cur_len = cur_len if 'image' in sample else -cur_len + length_list.append(cur_len) + return length_list + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + sources = self.list_data_dict[i] + if isinstance(i, int): + sources = [sources] + assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME + if 'image' in sources[0]: + image_file = self.list_data_dict[i]['image'] + image_folder = self.data_args.image_folder + processor = self.data_args.image_processor + + try: + crop_size = self.data_args.image_processor.crop_size + except: + crop_size = self.data_args.image_processor.size + + try: + image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') + except Exception as e: + from icecream import ic + ic("----------------------------------") + ic("OS ERROROROROROROROROROROOR") + ic("OS ERROROROROROROROROROROOR") + ic(image_file) + ic(e) + ic("OS ERROROROROROROROROROROOR") + ic("OS ERROROROROROROROROROROOR") + ic("===================================") + return self.__getitem__(0) + + # image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') + if self.data_args.image_aspect_ratio == 'pad': + def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + image = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) + image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + else: + image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + sources = preprocess_multimodal( + copy.deepcopy([e["conversations"] for e in sources]), + self.data_args) + else: + sources = copy.deepcopy([e["conversations"] for e in sources]) + data_dict = preprocess( + sources, + self.tokenizer, + has_image=('image' in self.list_data_dict[i])) + if isinstance(i, int): + data_dict = dict(input_ids=data_dict["input_ids"][0], + labels=data_dict["labels"][0]) + + # image exist in the data + if 'image' in self.list_data_dict[i]: + data_dict['image'] = image + elif self.data_args.is_multimodal: + # image does not exist in the data, but the model is multimodal + try: + crop_size = self.data_args.image_processor.crop_size + except: + crop_size = self.data_args.image_processor.size + data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width']) + return data_dict + + +@dataclass +class DataCollatorForSupervisedDataset(object): + """Collate examples for supervised fine-tuning.""" + + tokenizer: transformers.PreTrainedTokenizer + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + input_ids, labels = tuple([instance[key] for instance in instances] + for key in ("input_ids", "labels")) + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id) + labels = torch.nn.utils.rnn.pad_sequence(labels, + batch_first=True, + padding_value=IGNORE_INDEX) + input_ids = input_ids[:, :self.tokenizer.model_max_length] + labels = labels[:, :self.tokenizer.model_max_length] + batch = dict( + input_ids=input_ids, + labels=labels, + 
attention_mask=input_ids.ne(self.tokenizer.pad_token_id), + ) + + if 'image' in instances[0]: + images = [instance['image'] for instance in instances] + if all(x is not None and x.shape == images[0].shape for x in images): + batch['images'] = torch.stack(images) + else: + batch['images'] = images + + return batch + + +def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, + data_args) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + train_dataset = LazySupervisedDataset(tokenizer=tokenizer, + data_path=data_args.data_path, + data_args=data_args) + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + return dict(train_dataset=train_dataset, + eval_dataset=None, + data_collator=data_collator) + +def add_special_tokens( + special_tokens: List, + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, +): + """Resize tokenizer and embedding. + + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. + """ + num_new_tokens = tokenizer.add_tokens(special_tokens, special_tokens=True) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + +def train(attn_implementation=None): + global local_rank + + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + local_rank = training_args.local_rank + compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) + + bnb_model_from_pretrained_args = {} + if training_args.bits in [4, 8]: + from transformers import BitsAndBytesConfig + bnb_model_from_pretrained_args.update(dict( + device_map={"": training_args.device}, + load_in_4bit=training_args.bits == 4, + load_in_8bit=training_args.bits == 8, + quantization_config=BitsAndBytesConfig( + load_in_4bit=training_args.bits == 4, + load_in_8bit=training_args.bits == 8, + llm_int8_skip_modules=["mm_projector"], + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=compute_dtype, + bnb_4bit_use_double_quant=training_args.double_quant, + bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'} + ) + )) + + if model_args.vision_tower is not None: + if 'phi' in model_args.model_name_or_path.lower(): + model = LlavaPhi3ForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + elif 'qwen2' in model_args.model_name_or_path.lower(): + model = LlavaQwenForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + else: + model = LlavaLlamaForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + 
attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + else: + if 'phi' in model_args.model_name_or_path.lower(): + model = transformers.Phi3ForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + elif 'qwen2' in model_args.model_name_or_path.lower(): + model = transformers.Qwen2ForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + else: + model = transformers.LlamaForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + **bnb_model_from_pretrained_args + ) + model.config.use_cache = False + + if model_args.freeze_backbone: + model.model.requires_grad_(False) + + if training_args.bits in [4, 8]: + from peft import prepare_model_for_kbit_training + model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) + model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) + + if training_args.gradient_checkpointing: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if training_args.lora_enable: + from peft import LoraConfig, get_peft_model + lora_config = LoraConfig( + r=training_args.lora_r, + lora_alpha=training_args.lora_alpha, + target_modules=find_all_linear_names(model), + lora_dropout=training_args.lora_dropout, + bias=training_args.lora_bias, + task_type="CAUSAL_LM", + ) + if training_args.bits == 16: + if training_args.bf16: + model.to(torch.bfloat16) + if training_args.fp16: + model.to(torch.float16) + rank0_print("Adding LoRA adapters...") + model = get_peft_model(model, lora_config) + + if "qwen" in model_args.model_name_or_path.lower(): + tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right") + else: + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=False, + ) + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.unk_token + if tokenizer.pad_token_id is None: + smart_tokenizer_and_embedding_resize( + special_tokens_dict=dict(pad_token=""), + tokenizer=tokenizer, + model=model, + ) + + if model_args.version in conversation_lib.conv_templates: + conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version] + else: + conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"] + + if model_args.vision_tower is not None: + vision_tower = model.get_vision_tower() + + if vision_tower is None: + model.get_model().initialize_vision_modules( + model_args=model_args, + 
fsdp=training_args.fsdp + ) + vision_tower = model.get_vision_tower() + + if not vision_tower.is_loaded: + vision_tower.load_model() + vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device) + + data_args.image_processor = vision_tower.image_processor + data_args.is_multimodal = True + + model.config.image_grid_pinpoints = [[336,672], [672,336], [672,672], [1008,336], [336,1008]] + model.config.image_aspect_ratio = data_args.image_aspect_ratio + model.config.tokenizer_padding_side = tokenizer.padding_side + model.config.tokenizer_model_max_length = tokenizer.model_max_length + + model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter + if model_args.tune_mm_mlp_adapter: + model.requires_grad_(False) + for p in model.get_model().mm_projector.parameters(): + p.requires_grad = True + + model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter + if training_args.freeze_mm_mlp_adapter: + for p in model.get_model().mm_projector.parameters(): + p.requires_grad = False + + if training_args.bits in [4, 8]: + model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device) + + if model_args.unfreeze_mm_vision_tower: + model.requires_grad_(False) + model.get_model().vision_tower.requires_grad_(True) + for p in model.get_model().mm_projector.parameters(): + p.requires_grad = True + else: + model.get_model().vision_tower.requires_grad_(False) + + if model_args.unfreeze_whole_model: + model.requires_grad_(True) + + model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end + model.config.mm_projector_lr = training_args.mm_projector_lr + training_args.use_im_start_end = model_args.mm_use_im_start_end + model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token + model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer) + model.config.use_s2 = model_args.use_s2 + model.config.s2_scales = model_args.s2_scales + + if hasattr(model.config, "task_token_format"): + if "emb" in model.config.task_token_format and model.config.num_task_tokens > 0: + requires_grad_value = not model_args.freeze_task_token + if "gen" in model.config.aux_mode: + model.get_model().special_gen_tokens.requires_grad_(requires_grad_value) + if "seg" in model.config.aux_mode: + model.get_model().special_seg_tokens.requires_grad_(requires_grad_value) + if "depth" in model.config.aux_mode: + model.get_model().special_depth_tokens.requires_grad_(requires_grad_value) + + # from safetensors import safe_open + # checkpoint_path = model_args.model_name_or_path + # safetensors_paths = [os.path.join(checkpoint_path, f) for f in os.listdir(checkpoint_path) if f.endswith('.safetensors')] + + # state_dict = {} + # for path in safetensors_paths: + # with safe_open(path, framework="pt", device="cpu") as f: + # for key in f.keys(): + # state_dict[key] = f.get_tensor(key) + # # model.load_state_dict(state_dict, strict=False) + # from tqdm import tqdm + # for param_name, param in tqdm(model.named_parameters()): + # if "vision_tower" in param_name: + # checkpoint_value = state_dict[param_name] + # model_value = model.state_dict()[param_name] + + # from icecream import ic + # if not torch.equal(checkpoint_value.to(device=model_value.device, dtype=model_value.dtype), model_value): + # ic(param_name, checkpoint_value.mean(), model_value.mean()) + + import torch.distributed as dist + from icecream import ic + if dist.get_rank() == 0: + for n, p in model.named_parameters(): 
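+            # Debug aid: on rank 0, list every parameter that is still trainable
+            # so the freeze/unfreeze configuration can be verified before training.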
+ if p.requires_grad: + ic(n) + + if training_args.bits in [4, 8]: + from peft.tuners.lora import LoraLayer + for name, module in model.named_modules(): + if isinstance(module, LoraLayer): + if training_args.bf16: + module = module.to(torch.bfloat16) + if 'norm' in name: + module = module.to(torch.float32) + if 'lm_head' in name or 'embed_tokens' in name: + if hasattr(module, 'weight'): + if training_args.bf16 and module.weight.dtype == torch.float32: + module = module.to(torch.bfloat16) + + data_module = make_supervised_data_module(tokenizer=tokenizer, + data_args=data_args) + trainer = LLaVATrainer(model=model, + tokenizer=tokenizer, + args=training_args, + **data_module) + + print('starting training...', local_rank) + + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + trainer.save_state() + + model.config.use_cache = True + + if training_args.lora_enable: + state_dict = get_peft_state_maybe_zero_3( + model.named_parameters(), training_args.lora_bias + ) + non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( + model.named_parameters() + ) + if training_args.local_rank == 0 or training_args.local_rank == -1: + model.config.save_pretrained(training_args.output_dir) + model.save_pretrained(training_args.output_dir, state_dict=state_dict) + torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin')) + else: + safe_save_model_for_hf_trainer(trainer=trainer, + output_dir=training_args.output_dir) + + +if __name__ == "__main__": + train() diff --git a/ola_vlm/train/train_mem.py b/ola_vlm/train/train_mem.py new file mode 100644 index 0000000000000000000000000000000000000000..2b91b9cf8dcf998c06e966d68dac957bf6eafb5e --- /dev/null +++ b/ola_vlm/train/train_mem.py @@ -0,0 +1,7 @@ +from ola_vlm.train.train import train + +if __name__ == "__main__": + # try: + # train(attn_implementation="flash_attention_2") + # except: + train(attn_implementation="eager") diff --git a/ola_vlm/train/train_xformers.py b/ola_vlm/train/train_xformers.py new file mode 100644 index 0000000000000000000000000000000000000000..42e8ca70c03ad1a2ab6e203e8dbb4166ad5af22f --- /dev/null +++ b/ola_vlm/train/train_xformers.py @@ -0,0 +1,13 @@ +# Make it more memory efficient by monkey patching the LLaMA model with xformers attention. + +# Need to call this before importing transformers. +from ola_vlm.train.llama_xformers_attn_monkey_patch import ( + replace_llama_attn_with_xformers_attn, +) + +replace_llama_attn_with_xformers_attn() + +from ola_vlm.train.train import train + +if __name__ == "__main__": + train() diff --git a/ola_vlm/utils.py b/ola_vlm/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3ba4d2439f1ea6088831618982de803863cc5543 --- /dev/null +++ b/ola_vlm/utils.py @@ -0,0 +1,132 @@ +import datetime +import logging +import logging.handlers +import os +import sys + +import requests +import torch.distributed as dist +from ola_vlm.constants import LOGDIR + +server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" +moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 
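+# `handler` below is the shared log-file handler; build_logger() creates it
+# lazily on the first call and attaches it to every registered logger.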
+ +handler = None + +def rank0_print(*args): + if dist.is_initialized(): + if dist.get_rank() == 0: + print(f"Rank {dist.get_rank()}: ", *args) + else: + print(*args) + +def build_logger(logger_name, logger_filename): + global handler + + formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Set the format of root handlers + if not logging.getLogger().handlers: + logging.basicConfig(level=logging.INFO) + logging.getLogger().handlers[0].setFormatter(formatter) + + # Redirect stdout and stderr to loggers + stdout_logger = logging.getLogger("stdout") + stdout_logger.setLevel(logging.INFO) + sl = StreamToLogger(stdout_logger, logging.INFO) + sys.stdout = sl + + stderr_logger = logging.getLogger("stderr") + stderr_logger.setLevel(logging.ERROR) + sl = StreamToLogger(stderr_logger, logging.ERROR) + sys.stderr = sl + + # Get logger + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + + # Add a file handler for all loggers + if handler is None: + os.makedirs(LOGDIR, exist_ok=True) + filename = os.path.join(LOGDIR, logger_filename) + handler = logging.handlers.TimedRotatingFileHandler( + filename, when='D', utc=True, encoding='UTF-8') + handler.setFormatter(formatter) + + for name, item in logging.root.manager.loggerDict.items(): + if isinstance(item, logging.Logger): + item.addHandler(handler) + + return logger + + +class StreamToLogger(object): + """ + Fake file-like stream object that redirects writes to a logger instance. + """ + def __init__(self, logger, log_level=logging.INFO): + self.terminal = sys.stdout + self.logger = logger + self.log_level = log_level + self.linebuf = '' + + def __getattr__(self, attr): + return getattr(self.terminal, attr) + + def write(self, buf): + temp_linebuf = self.linebuf + buf + self.linebuf = '' + for line in temp_linebuf.splitlines(True): + # From the io.TextIOWrapper docs: + # On output, if newline is None, any '\n' characters written + # are translated to the system default line separator. + # By default sys.stdout.write() expects '\n' newlines and then + # translates them so this is still cross platform. + if line[-1] == '\n': + self.logger.log(self.log_level, line.rstrip()) + else: + self.linebuf += line + + def flush(self): + if self.linebuf != '': + self.logger.log(self.log_level, self.linebuf.rstrip()) + self.linebuf = '' + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def violates_moderation(text): + """ + Check whether the text violates OpenAI moderation API. 
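+    Returns False if the moderation request fails or returns an unexpected payload.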
+ """ + url = "https://api.openai.com/v1/moderations" + headers = {"Content-Type": "application/json", + "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} + text = text.replace("\n", "") + data = "{" + '"input": ' + f'"{text}"' + "}" + data = data.encode("utf-8") + try: + ret = requests.post(url, headers=headers, data=data, timeout=5) + flagged = ret.json()["results"][0]["flagged"] + except requests.exceptions.RequestException as e: + flagged = False + except KeyError as e: + flagged = False + + return flagged + + +def pretty_print_semaphore(semaphore): + if semaphore is None: + return "None" + return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" diff --git a/requirements.txt b/requirements.txt index cfc5b09a68217c6eba8d711a8c995c765049d339..8c515d51efcd10dea82e48b558fa60b46d4b9b2f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,26 @@ -huggingface_hub==0.25.2 \ No newline at end of file +torch==2.2.0 +torchvision==0.17.0 +tokenizers==0.19.1 +sentencepiece==0.1.99 +shortuuid +peft +bitsandbytes +open_clip_torch +diffdist +pydantic==2.8.2 +pydantic-core==2.20.1 +markdown2[all] +numpy==1.26.2 +gradio==4.16.0 +gradio_client==0.8.1 +huggingface_hub +requests +httpx==0.24.0 +uvicorn +fastapi==0.111.0 +einops==0.6.1 +einops-exts==0.0.4 +timm==1.0.8 +diffusers===0.27.2 +protobuf +accelerate==0.27.2 \ No newline at end of file