"""Probe a VLM's layer-wise visual embeddings against teacher features on COCO.

Depending on --mode, the teacher is the CLIP image encoder of a Stable-unCLIP
pipeline ("gen"), Depth-Anything-V2 ("depth"), or SAM/OneFormer ("seg"); the
per-layer cosine-embedding losses are written to plots/probe_scores/ as JSON.
"""

import argparse
import glob
import json
import math
import os
import random
import warnings

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from icecream import ic
from transformers import OneFormerProcessor
from diffusers import (
    DPMSolverMultistepScheduler,
    StableUnCLIPImg2ImgPipeline,
)

from ola_vlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from ola_vlm.conversation import conv_templates
from ola_vlm.model.builder import load_pretrained_model
from ola_vlm.utils import disable_torch_init
from ola_vlm.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
from ola_vlm.model.aux_heads.sam_utils.build_sam import sam_model_registry
from ola_vlm.model.aux_heads.sam_utils.automatic_mask_generator import SamAutomaticMaskGenerator
from ola_vlm.model.aux_heads.oneformer_head import OneFormerHead, OneFormerSegHead, OneFormerTaskTokenSegHead
from ola_vlm.model.aux_heads.depth_anything_v2.dpt import DepthAnythingV2
from analyze.analyze_utils import prepare_coco

warnings.filterwarnings("ignore")


def split_list(lst, n):
    """Split a list into n (roughly) equal-sized chunks."""
    chunk_size = math.ceil(len(lst) / n)  # ceiling division so every item is covered
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    """Return the k-th of n chunks of lst."""
    chunks = split_list(lst, n)
    return chunks[k]


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def load_image(image_file):
    image = Image.open(image_file).convert('RGB')
    return image


def list_image_files(directory):
    image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.gif', '*.bmp', '*.tiff']
    image_files = []
    for extension in image_extensions:
        image_files.extend(glob.glob(os.path.join(directory, extension)))
    return image_files


def get_gen_feats(pipe, image):
    """CLIP image embeddings from the unCLIP pipeline's image encoder (teacher for "gen")."""
    with torch.no_grad():
        clip_ims = pipe.feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
        feat = pipe.image_encoder(clip_ims).image_embeds
    return feat


def get_dav2_feats(dav2, image):
    """Depth-Anything-V2 backbone features (teacher for "depth")."""
    image = image.resize((336, 336))
    image = np.array(image)
    with torch.no_grad():
        feat = dav2.infer_image(image, is_dsg=True)
    return feat[-1][0]


def get_seg_feats(mask_generator, oneformer, oneformer_processor, seg_teacher, image):
    """Segmentation teacher features from either OneFormer or the SAM image encoder."""
    if seg_teacher == "oneformer":
        img = image.resize((768, 768))
        inputs = oneformer_processor(img, ["panoptic"], return_tensors="pt")
        inputs["pixel_values"] = inputs["pixel_values"].to("cuda")
        with torch.no_grad():
            feats = oneformer.forward_features(**inputs)
    else:
        img = np.array(image)
        with torch.no_grad():
            mask_generator.predictor.set_image(img)
            feats = mask_generator.predictor.features
            mask_generator.predictor.reset_image()
    return feats


def predict(args):
    mode = args.mode
    name = args.model_path.split("/")[-1]
    os.makedirs(f"plots/probe_scores/{name}/", exist_ok=True)

    if "cambrian" in name:
        from ola_vlm.cambrian.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
        from ola_vlm.cambrian.conversation import conv_templates, SeparatorStyle
        from ola_vlm.cambrian.model.builder import load_pretrained_model
        from ola_vlm.cambrian.utils import disable_torch_init
        from ola_vlm.cambrian.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

        disable_torch_init()
        model_name = get_model_name_from_path(args.model_path)
        tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name,
                                                                               args.load_8bit, args.load_4bit, device=args.device)
        if 'llama-2' in model_name.lower():
            conv_mode = "cambrian_llama_2"
        elif "v1" in model_name.lower():
            conv_mode = "cambrian_v1"
        elif "mpt" in model_name.lower():
            conv_mode = "mpt"
        else:
            conv_mode = "cambrian_v0"
    else:
        from ola_vlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
        from ola_vlm.conversation import conv_templates
        from ola_vlm.model.builder import load_pretrained_model
        from ola_vlm.utils import disable_torch_init
        from ola_vlm.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path

        disable_torch_init()
        model_name = get_model_name_from_path(args.model_path)
        tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name,
                                                                               args.load_8bit, args.load_4bit, device=args.device)

        if "mistral" in model_name.lower():
            conv_mode = "mistral_instruct"
        elif "v1.6-34b" in model_name.lower():
            conv_mode = "chatml_direct"
        elif "llama3" in model_name.lower():
            conv_mode = "llava_llama_3"
        elif "qwen" in model_name.lower():
            conv_mode = "llava_qwen"
        elif "v1" in model_name.lower():
            conv_mode = "llava_v1"
        elif "phi" in model_name.lower():
            conv_mode = "llava_phi_3"

    # Prepare the COCO samples and keep only this worker's shard.
    images, prompts, answers = prepare_coco(args.json_file)
    images = get_chunk(images, args.num_chunks, args.chunk_idx)
    prompts = get_chunk(prompts, args.num_chunks, args.chunk_idx)
    answers = get_chunk(answers, args.num_chunks, args.chunk_idx)

    # Load the teacher model for the chosen probing mode.
    if mode == "gen":
        pipe = StableUnCLIPImg2ImgPipeline.from_pretrained("playground/jiteshjain_sherlock/stable-diffusion-2-1-unclip",
                                                           torch_dtype=torch.float16, variant="fp16")
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        pipe = pipe.to("cuda")
    elif mode == "seg":
        oneformer_processor, oneformer, mask_generator = None, None, None
        seg_teacher = model.config.image_seg.get("seg_teacher", "sam")
        if seg_teacher == "sam":
            sam = sam_model_registry["vit_l"](checkpoint="/mnt/projects4jw/jiteshjain_sherlock/oneformer_coco_swin_large")
            sam = sam.to("cuda")
            mask_generator = SamAutomaticMaskGenerator(sam.float())
        else:
            oneformer_processor = OneFormerProcessor.from_pretrained("/mnt/projects4jw/jiteshjain_sherlock/oneformer_coco_swin_large")
            oneformer = OneFormerHead.from_pretrained("/mnt/projects4jw/jiteshjain_sherlock/oneformer_coco_swin_large")
            oneformer = oneformer.to("cuda")
    elif mode == "depth":
        dav2_cfg = {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}
        dav2_backbone = DepthAnythingV2(**dav2_cfg)
        dav2_backbone.load_state_dict(torch.load("/mnt/projects4jw/jiteshjain_sherlock/depth_anything_v2_vitl.pth", map_location='cpu'))
        dav2_backbone = dav2_backbone.to("cuda")

    set_seed(42)

    # Layer indices to probe; fall back to layers 1..32 if not configured.
    if mode == "gen":
        try:
            layers = model.config.image_gen["layer_indices"]
        except Exception:
            layers = [i + 1 for i in range(32)]
    elif mode == "depth":
        try:
            layers = model.config.image_depth["layer_indices"]
        except Exception:
            layers = [i + 1 for i in range(32)]
    elif mode == "seg":
        try:
            layers = model.config.image_seg["layer_indices"]
        except Exception:
            layers = [i + 1 for i in range(32)]

    os.makedirs(f"plots/probe_scores/{name}/{mode}/", exist_ok=True)
    score_file = f"plots/probe_scores/{name}/{mode}/{args.num_chunks}_{args.chunk_idx}.json"

    # Resume from a previous run if a score file already exists.
    if os.path.exists(score_file):
        with open(score_file, 'r') as f:
            diff_dict = json.load(f)
    else:
        diff_dict = {}

    count = 0
    for fname, prompt, answer in tqdm(zip(images, prompts, answers), total=len(prompts)):
        # if fname.split("/")[-1] in diff_dict.keys():
        #     continue
        conv = conv_templates[conv_mode].copy()

        image = load_image(fname)
        image = image.resize((640, 640))
        image_size = image.size
        image_tensor = process_images([image], image_processor, model.config)
        if isinstance(image_tensor, list):
            image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
        else:
            image_tensor = image_tensor.to(model.device, dtype=torch.float16)

        # Build the conversation prompt with the image token(s) prepended.
        inp = prompt
        if image is not None:
            if model.config.mm_use_im_start_end:
                inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
            else:
                inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)

        with torch.inference_mode():
            out = model.get_visual_interpretations(
                input_ids,
                images=image_tensor,
                image_sizes=[image_size],
            )

        # Collect the probed embeddings and the matching teacher features.
        if mode == "gen":
            embeds = out.image_embs
            feats = get_gen_feats(pipe, image)
        elif mode == "depth":
            embeds = out.depth_embs
            embeds = [emb[0][0] for emb in embeds]
            feats = get_dav2_feats(dav2_backbone, image)
        elif mode == "seg":
            embeds = out.seg_embs
            feats = get_seg_feats(mask_generator, oneformer, oneformer_processor, seg_teacher, image)

        # Per-layer cosine-embedding loss between the probed embedding and the teacher features.
        layer_diff = {}
        for idx, emb in enumerate(embeds):
            emb = emb.to("cuda")
            layer_diff[layers[idx]] = torch.nn.CosineEmbeddingLoss(reduction="mean")(
                emb.reshape(1, -1).float(),
                feats.reshape(1, -1).float(),
                torch.ones(len(emb)).to(feats.device)
            ).cpu().item()
            ic(layer_diff[layers[idx]])

        diff_dict[fname.split("/")[-1]] = layer_diff

        if count % 200 == 0:  # save progress intermittently
            with open(score_file, 'w') as f:
                json.dump(diff_dict, f, indent=2)
        count += 1

    with open(score_file, 'w') as f:
        json.dump(diff_dict, f, indent=2)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="/mnt/projects4jw/jiteshjain_sherlock/llava-v1.5-7b")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--json-file", type=str, default="/mnt/projects4jw/jiteshjain_sherlock/datasets/coco/annotations/captions_val2017.json")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--max-new-tokens", type=int, default=10)
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--mode", type=str, default="gen")
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
    args = parser.parse_args()

    predict(args)
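
# Example invocation (a sketch; the script name, model path, and annotation path
# below are placeholders, not values taken from the repository):
#
#   python probe_coco.py \
#       --model-path <path-or-hf-id-of-model> \
#       --json-file <path-to-captions_val2017.json> \
#       --mode depth --num-chunks 8 --chunk-idx 0
#
# Each of the --num-chunks workers processes its own shard and writes
# plots/probe_scores/<model-name>/<mode>/<num_chunks>_<chunk_idx>.json.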