# Authors: Hui Ren (rhfeiyang.github.io)
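"""Dataset wrapper that pairs SAM images (``sa_{id}.jpg``) with text captions
(``sa_{id}.txt``), loading the sample ids (and an optional id -> subfolder
mapping) from pickle files. Images can be cached at a fixed resolution."""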
import os
import pickle
from typing import Callable, Optional

import tqdm
from PIL import Image

from torch.utils.data import Dataset
from torchvision import transforms


class SamDataset(Dataset):
    def __init__(self, image_folder_path: str, caption_folder_path: str,
                 id_file="data/sam/clip_filtered_ids.pickle",
                 id_dict_file: Optional[str] = None,
                 transforms: Optional[Callable] = None,
                 resolution: Optional[int] = None,
                 get_img: bool = True,
                 get_cap: bool = True):
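        """
        Args:
            image_folder_path: root folder with ``sa_{id}.jpg`` images.
            caption_folder_path: folder with ``sa_{id}.txt`` caption files.
            id_file: pickle file holding the list of sample ids, or the list itself.
            id_dict_file: optional pickle mapping each id to its image subfolder.
            transforms: optional callable applied to the ``{"id", "image", "text"}`` sample dict.
            resolution: if set, images are resized to this size and cached in ``<image_folder_path>_<resolution>``.
            get_img: whether to load the image for each sample.
            get_cap: whether to load the caption for each sample.
        """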
        if id_dict_file is not None:
            with open(id_dict_file, 'rb') as f:
                print(f"Loading id_dict from {id_dict_file}", flush=True)
                self.id_dict = pickle.load(f)
                print(f"Loaded id_dict from {id_dict_file}", flush=True)
        else:
            self.id_dict = None
        if isinstance(id_file, list):
            self.ids = id_file
        elif isinstance(id_file, str):
            with open(id_file, 'rb') as f:
                print(f"Loading ids from {id_file}", flush=True)
                self.ids = pickle.load(f)
                print(f"Loaded ids from {id_file}", flush=True)
        else:
            raise TypeError(f"id_file must be a list of ids or a pickle file path, got {type(id_file)}")
        self.resolution = resolution
        self.ori_image_folder_path = image_folder_path
        if self.resolution is not None:
            # cache resized copies in a sibling folder, e.g. "<image_folder_path>_<resolution>"
            self.image_folder_path = f"{image_folder_path}_{resolution}"
            os.makedirs(self.image_folder_path, exist_ok=True)
        else:
            self.image_folder_path = image_folder_path
        self.caption_folder_path = caption_folder_path
        self.transforms = transforms
        self.column_names = ["image", "text"]
        self.get_img = get_img
        self.get_cap = get_cap

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index: int):
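        """Return a dict with ``id`` plus ``image`` and/or ``text`` depending on the flags."""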
        id = self.ids[index]
        ret = {"id": id}
        try:
            if self.get_img:
                ret["image"] = self._load_image(id)
            if self.get_cap:
                ret["text"] = [self._load_caption(id)]
            if self.transforms is not None:
                ret = self.transforms(ret)
            return ret
        except Exception as e:
            # fall back to the first sample so one corrupt file does not abort iteration
            print(f"Error loading image and caption for id {id}, error: {e}, redirecting to index 0", flush=True)
            return self[0]

    def define_resolution(self, resolution: int):
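        """Repoint the image folder to the cache for ``resolution``; images missing from the cache are resized on demand in ``_load_image``."""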
        self.resolution = resolution
        if os.path.exists("/var/jomat/datasets/"):
            self.image_folder_path = f"/var/jomat/datasets/SAM_{resolution}"
        else:
            self.image_folder_path = f"{self.ori_image_folder_path}_{resolution}"
        print(f"SamDataset resolution defined to {resolution}, new image folder path: {self.image_folder_path}")

    def _load_image(self, id: int) -> Image.Image:
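        """Load ``sa_{id}.jpg`` from the (resized) cache folder, falling back to the original folder and caching a resized copy when needed."""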
        if self.id_dict is not None:
            subfolder = self.id_dict[id]
            image_path = f"{self.image_folder_path}/{subfolder}/sa_{id}.jpg"
        else:
            image_path = f"{self.image_folder_path}/sa_{id}.jpg"

        try:
            with open(image_path, 'rb') as f:
                img = Image.open(f).convert("RGB")
        except (FileNotFoundError, OSError):
            # cached copy missing or unreadable: fall back to the original image
            if self.id_dict is not None:
                subfolder = self.id_dict[id]
                ori_image_path = f"{self.ori_image_folder_path}/{subfolder}/sa_{id}.jpg"
            else:
                ori_image_path = f"{self.ori_image_folder_path}/sa_{id}.jpg"
            assert os.path.exists(ori_image_path), f"Original image not found: {ori_image_path}"
            with open(ori_image_path, 'rb') as f:
                img = Image.open(f).convert("RGB")
            # resize (keeping aspect ratio) and cache the result for later runs
            if self.resolution is not None:
                img = transforms.Resize(self.resolution, interpolation=transforms.InterpolationMode.BICUBIC)(img)
            os.makedirs(os.path.dirname(image_path), exist_ok=True)
            img.save(image_path)

        return img

    
    def _load_caption(self, id: int):
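        """Load caption ``sa_{id}.txt``, dropping empty sentences and sentences containing "black and white" (a frequent false prediction)."""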
        caption_path = f"{self.caption_folder_path}/sa_{id}.txt"
        if not os.path.exists(caption_path):
            return None
        try:
            with open(caption_path, 'r', encoding="utf-8") as f:
                content = f.read()
        except Exception as e:
            print(f"Error reading caption file {caption_path}, error: {e}")
            return None
        sentences = content.split('.')
        # drop empty sentences and sentences containing "black and white" (too many false predictions)
        sentences = [sentence.strip() for sentence in sentences if sentence.strip() and "black and white" not in sentence]
        # rejoin into one caption string and make sure it ends with a period
        sentences = ". ".join(sentences)
        if len(sentences) > 0 and sentences[-1] != '.':
            sentences += '.'

        return sentences
    
    def with_transform(self, transform):
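        """Set the transform applied to each sample dict in ``__getitem__`` and return self (chainable)."""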
        self.transforms = transform
        return self

    def subsample(self, n: int = 10000):
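        """Subsample ``n`` ids at (roughly) equal intervals; ``None`` or ``-1`` keeps the full set."""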
        if n is None or n == -1:
            return self
        ori_len = len(self)
        assert n <= ori_len
        # equal interval subsample
        ids = self.ids[::ori_len // n][:n]
        self.ids = ids
        print(f"SAM dataset subsampled from {ori_len} to {len(self)}")
        return self


if __name__ == "__main__":
    from custom_datasets.sam_caption.mypath import MyPath
    dataset = SamDataset(image_folder_path=MyPath.db_root_dir("sam_images"), caption_folder_path=MyPath.db_root_dir("sam_captions"), id_file=MyPath.db_root_dir("sam_whole_filtered_ids_train"), id_dict_file=MyPath.db_root_dir("sam_id_dict"))
    dataset.get_img = False  # iterate captions only
    for i in tqdm.tqdm(dataset):
        a = i['text']
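
    # A minimal sketch of a dict-in/dict-out transform usable with `with_transform`;
    # the ToTensor conversion (and re-enabling image loading) is an illustrative
    # assumption, not something this script requires:
    # dataset.get_img = True
    # dataset.with_transform(lambda sample: {**sample, "image": transforms.ToTensor()(sample["image"])})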