# Authors: Hui Ren (rhfeiyang.github.io)
import os
import sys
import numpy as np
from PIL import Image
import pickle
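# Make repo-root packages (custom_datasets, utils) importable when run as a script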
sys.path.append(os.path.join(os.path.dirname(__file__), "../../"))
from custom_datasets.sam import SamDataset
from utils.art_filter import Art_filter
import torch
from matplotlib import pyplot as plt
import math
import argparse
import socket
import time
from tqdm import tqdm

def parse_args():
    parser = argparse.ArgumentParser(description="Filter the sam dataset")
    parser.add_argument("--check", action="store_true", help="Check the complete")
    parser.add_argument("--mode", default="clip_logit",  choices=["clip_logit_update","clip_logit", "clip_filt", "caption_filt", "gather_result","caption_flit_append"])
    parser.add_argument("--start_idx", default=0, type=int, help="Start index")
    parser.add_argument("--end_idx", default=9e10, type=int, help="Start index")
    args = parser.parse_args()
    return args
@torch.no_grad()
def main(args):
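    # Pipeline (selected via --mode): compute CLIP features/logits per shard,
    # threshold them (clip_filt), filter by caption (caption_filt), then
    # intersect the surviving ids across both filters (gather_result).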
    filter = Art_filter()
    if args.mode == "caption_filt" or args.mode == "gather_result":
        filter.clip_filter = None
        torch.cuda.empty_cache()

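    # Cluster-specific hardcoded paths for captions, images, per-shard id dicts,
    # and filter results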
    caption_folder_path = "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/SAM/subset/captions"
    image_folder_path = "/vision-nfs/torralba/scratch/jomat/sam_dataset/nfs-data/sam/images"
    id_dict_dir = "/vision-nfs/torralba/scratch/jomat/sam_dataset/sam_ids/8.16/id_dict"
    filt_dir = "/vision-nfs/torralba/scratch/jomat/sam_dataset/filt_result"
    def collate_fn(examples):
        # {"image": image, "id":id}
        ret = {}
        if "image" in examples[0]:
            pixel_values = [example["image"] for example in examples]
            ret["images"] = pixel_values
        if "text" in examples[0]:
            prompts = [example["text"] for example in examples]
            ret["text"] = prompts
        id = [example["id"] for example in examples]
        ret["ids"] = id
        return ret
    error_files=[]
    val_set = ["sa_000000"]
    result_check_set = ["sa_000020"]
    all_remain_ids=[]
    all_remain_ids_train=[]
    all_remain_ids_val=[]
    all_filtered_id_num = 0
    remain_feat_num = 0
    remain_caption_num = 0
    filter_feat_num = 0
    filter_caption_num = 0
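    # Each *_id_dict.pickle in id_dict_dir holds the image ids of one SAM shard
    # (e.g. sa_000000); only shards in [start_idx, end_idx) are processed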
    for idx,file in tqdm(enumerate(sorted(os.listdir(id_dict_dir)))):
        if idx < args.start_idx or idx >= args.end_idx:
            continue
        if file.endswith(".pickle") and not file.startswith("all"):
            print("=====================================")
            print(file,flush=True)
            save_dir = os.path.join(filt_dir, file.replace("_id_dict.pickle", ""))
            os.makedirs(save_dir, exist_ok=True)
            id_dict_file = os.path.join(id_dict_dir, file)
            with open(id_dict_file, 'rb') as f:
                id_dict = pickle.load(f)
            ids = list(id_dict.keys())
            dataset = SamDataset(image_folder_path, caption_folder_path, id_file=ids, id_dict_file=id_dict_file)
            # dataset = SamDataset(image_folder_path, caption_folder_path, id_file=[10061410, 10076945, 10310013,1042012, 4487809, 4541052], id_dict_file="/vision-nfs/torralba/scratch/jomat/sam_dataset/images/id_dict/all_id_dict.pickle")
            dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8, collate_fn=collate_fn)
            clip_logits = None
            clip_logits_file = os.path.join(save_dir, "clip_logits_result.pickle")
            clip_filt_file = os.path.join(save_dir, "clip_filt_result.pickle")
            caption_filt_file = os.path.join(save_dir, "caption_filt_result.pickle")

            if args.mode == "clip_feat":
                compute_new = False
                clip_logits = {}
                if os.path.exists(clip_logits_file):
                    with open(clip_logits_file, 'rb') as f:
                        clip_logits = pickle.load(f)
                    if "image_features" not in clip_logits:
                        compute_new = True
                else:
                    compute_new=True
                if compute_new:
                    if clip_logits == "":  # overwrite an empty placeholder
                        clip_logits = {}
                    print(f"compute clip_feat {file}",flush=True)
                    clip_feature_ret = filter.clip_feature(dataloader)
                    clip_logits["image_features"] = clip_feature_ret["clip_features"]
                    if "ids" in clip_logits:
                        assert clip_feature_ret["ids"] == clip_logits["ids"]
                    else:
                        clip_logits["ids"] = clip_feature_ret["ids"]

                    with open(clip_logits_file, 'wb') as f:
                        pickle.dump(clip_logits, f)
                    print(f"clip_feat_result saved to {clip_logits_file}",flush=True)
                else:
                    print(f"skip {clip_logits_file}",flush=True)

            if args.mode == "clip_logit":
            # if clip_logit:
                if os.path.exists(clip_logits_file):
                    try:
                        with open(clip_logits_file, 'rb') as f:
                            clip_logits = pickle.load(f)
                    except Exception:
                        # Corrupted or partially written pickle; skip this shard
                        continue
                    skip = True
                    if args.check and clip_logits == "":
                        skip = False

                else:
                    skip = False
                if not skip:
                    # Write an empty-string placeholder first so that --check and
                    # concurrent runs can tell this shard has already been claimed
                    with open(clip_logits_file, 'wb') as f:
                        pickle.dump("", f)
                    try:
                        clip_logits = filter.clip_logit(dataloader)
                    except Exception:
                        print(f"Error in clip_logit {file}",flush=True)
                        continue
                    with open(clip_logits_file, 'wb') as f:
                        pickle.dump(clip_logits, f)
                    print(f"clip_logits_result saved to {clip_logits_file}",flush=True)
                else:
                    print(f"skip {clip_logits_file}",flush=True)

            if args.mode == "clip_logit_update":
                if os.path.exists(clip_logits_file):
                    with open(clip_logits_file, 'rb') as f:
                        clip_logits = pickle.load(f)
                else:
                    print(f"{clip_logits_file} not exist",flush=True)
                    continue
                if clip_logits == "":
                    print(f"skip {clip_logits_file}",flush=True)
                    continue
                ret = filter.clip_logit_by_feat(clip_logits["clip_features"])
                # assert (clip_logits["clip_logits"] - ret["clip_logits"]).abs().max() < 0.01
                clip_logits["clip_logits"] = ret["clip_logits"]
                clip_logits["text"] = ret["text"]
                with open(clip_logits_file, 'wb') as f:
                    pickle.dump(clip_logits, f)


            if args.mode == "clip_filt":
                if clip_logits is None:
                    try:
                        with open(clip_logits_file, 'rb') as f:
                            clip_logits = pickle.load(f)
                    except Exception:
                        print(f"Error in loading {clip_logits_file}",flush=True)
                        error_files.append(clip_logits_file)
                        continue
                    if clip_logits == "":
                        print(f"skip {clip_logits_file}",flush=True)
                        error_files.append(clip_logits_file)
                        continue
                clip_filt_result = filter.clip_filt(clip_logits)
                with open(clip_filt_file, 'wb') as f:
                    pickle.dump(clip_filt_result, f)
                print(f"clip_filt_result saved to {clip_filt_file}",flush=True)

            if args.mode == "caption_filt":
                if os.path.exists(caption_filt_file):
                    try:
                        with open(caption_filt_file, 'rb') as f:
                            ret = pickle.load(f)
                    except Exception:
                        # Corrupted or partially written pickle; skip this shard
                        continue
                    skip = True
                    if args.check and ret == "":
                        skip = False
                        print(f"empty {caption_filt_file}", flush=True)
                else:
                    skip = False
                if not skip:
                    # Write an empty-string placeholder first so that --check and
                    # concurrent runs can tell this shard has already been claimed
                    with open(caption_filt_file, 'wb') as f:
                        pickle.dump("", f)
                    ret = filter.caption_filt(dataloader)
                    with open(caption_filt_file, 'wb') as f:
                        pickle.dump(ret, f)
                    print(f"caption_filt_result saved to {caption_filt_file}",flush=True)
                else:
                    print(f"skip {caption_filt_file}",flush=True)

            if args.mode == "caption_flit_append":
                if not os.path.exists(caption_filt_file):
                    print(f"{caption_filt_file} not exist",flush=True)
                    continue
                with open(caption_filt_file, 'rb') as f:
                    old_caption_filt_result = pickle.load(f)
                # Skip unless the current filter-prompt list contains prompts that
                # were not applied in the saved result
                skip = True
                for i in filter.caption_filter.filter_prompts:
                    if i not in old_caption_filt_result["filter_prompts"]:
                        skip = False
                        break
                if skip:
                    print(f"skip {caption_filt_file}",flush=True)
                    continue
                old_remain_ids = old_caption_filt_result["remain_ids"]
                new_dataset = SamDataset(image_folder_path, caption_folder_path, id_file=old_remain_ids, id_dict_file=id_dict_file)
                new_dataloader = torch.utils.data.DataLoader(new_dataset, batch_size=64, shuffle=False, num_workers=8, collate_fn=collate_fn)
                ret = filter.caption_filt(new_dataloader)
                old_caption_filt_result["remain_ids"] = ret["remain_ids"]
                old_caption_filt_result["filtered_ids"].extend(ret["filtered_ids"])
                new_filter_count = ret["filter_count"].copy()
                for i in range(len(old_caption_filt_result["filter_count"])):
                    new_filter_count[i] += old_caption_filt_result["filter_count"][i]

                old_caption_filt_result["filter_count"] = new_filter_count
                old_caption_filt_result["filter_prompts"] = ret["filter_prompts"]
                with open(caption_filt_file, 'wb') as f:
                    pickle.dump(old_caption_filt_result, f)



            if args.mode == "gather_result":
                with open(clip_filt_file, 'rb') as f:
                    clip_filt_result = pickle.load(f)
                with open(caption_filt_file, 'rb') as f:
                    caption_filt_result = pickle.load(f)
                # filtered_ids entries are tuples whose first element is the image id
                caption_filtered_ids = [i[0] for i in caption_filt_result["filtered_ids"]]
                all_filtered_id_num += len(set(clip_filt_result["filtered_ids"]) | set(caption_filtered_ids) )

                remain_feat_num += len(clip_filt_result["remain_ids"])
                remain_caption_num += len(caption_filt_result["remain_ids"])
                filter_feat_num += len(clip_filt_result["filtered_ids"])
                filter_caption_num += len(caption_filtered_ids)

                remain_ids = set(clip_filt_result["remain_ids"]) & set(caption_filt_result["remain_ids"])
                remain_ids = list(remain_ids)
                remain_ids.sort()
                # with open(os.path.join(save_dir, "remain_ids.pickle"), 'wb') as f:
                #     pickle.dump(remain_ids, f)
                # print(f"remain_ids saved to {save_dir}/remain_ids.pickle",flush=True)
                all_remain_ids.extend(remain_ids)
                if file.replace("_id_dict.pickle","") in val_set:
                    all_remain_ids_val.extend(remain_ids)
                else:
                    all_remain_ids_train.extend(remain_ids)
    if args.mode == "gather_result":
        print(f"filtered ids: {all_filtered_id_num}",flush=True)
        print(f"remain feat num: {remain_feat_num}",flush=True)
        print(f"remain caption num: {remain_caption_num}",flush=True)
        print(f"filter feat num: {filter_feat_num}",flush=True)
        print(f"filter caption num: {filter_caption_num}",flush=True)
        all_remain_ids.sort()
        with open(os.path.join(filt_dir, "all_remain_ids.pickle"), 'wb') as f:
            pickle.dump(all_remain_ids, f)
        with open(os.path.join(filt_dir, "all_remain_ids_train.pickle"), 'wb') as f:
            pickle.dump(all_remain_ids_train, f)
        with open(os.path.join(filt_dir, "all_remain_ids_val.pickle"), 'wb') as f:
            pickle.dump(all_remain_ids_val, f)

        print(f"all_remain_ids saved to {filt_dir}/all_remain_ids.pickle",flush=True)
        print(f"all_remain_ids_train saved to {filt_dir}/all_remain_ids_train.pickle",flush=True)
        print(f"all_remain_ids_val saved to {filt_dir}/all_remain_ids_val.pickle",flush=True)

    print("finished",flush=True)
    # Report shards whose result files failed to load, for manual inspection
    for file in error_files:
        print(file, flush=True)

if __name__ == "__main__":
    args = parse_args()

    log_file = "sam_filt"
    idx=0
    hostname = socket.gethostname()
    now_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    while os.path.exists(f"{log_file}_{hostname}_check{args.check}_{now_time}_{idx}.log"):
        idx+=1

    main(args)
    # clip_logits_analysis()
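
# Example usage (the script name is a placeholder; paths above are hardcoded):
#   python sam_filt.py --mode clip_logit --start_idx 0 --end_idx 100
#   python sam_filt.py --mode clip_filt
#   python sam_filt.py --mode caption_filt
#   python sam_filt.py --mode gather_result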