rhfeiyang's picture
Upload folder using huggingface_hub
262b155 verified
raw
history blame
4.02 kB
# Authors: Hui Ren (rhfeiyang.github.io)
import torch
import pandas as pd
import numpy as np
import os
from PIL import Image
class Caption_set(torch.utils.data.Dataset):
    """Caption dataset backed by a delimited prompts CSV.

    Each item is a dict holding the row's positional index under ``"id"``,
    every column of that CSV row, and the ``"caption"`` column wrapped in a
    one-element list. An optional ``transform`` is applied to the dict.
    """

    # Predefined style-caption sets selectable through ``set_name``.
    style_set_names = [
        "andre-derain_subset1",
        "andy_subset1",
        "camille-corot_subset1",
        "gerhard-richter_subset1",
        "henri-matisse_subset1",
        "katsushika-hokusai_subset1",
        "klimt_subset3",
        "monet_subset2",
        "picasso_subset1",
        "van_gogh_subset1",
    ]
    # Maps each predefined set name to its ``style_captions.csv`` path.
    # NOTE(review): absolute, machine-specific paths; they only resolve on that filesystem.
    style_set_map = {
        name: f"/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/Style_captions/{name}/style_captions.csv"
        for name in style_set_names
    }

    def __init__(self, prompts_path=None, set_name=None, transform=None, delimiter=";"):
        """Load the caption table.

        Args:
            prompts_path: Path (or any input accepted by ``pd.read_csv``) of the
                captions CSV. Takes precedence over ``set_name`` when both are given.
            set_name: Key of ``style_set_map``; used only when ``prompts_path`` is None.
            transform: Optional callable applied to every item dict.
            delimiter: Column separator of the CSV; ``";"`` by default.

        Raises:
            ValueError: if neither ``prompts_path`` nor ``set_name`` is provided.
            KeyError: if ``set_name`` is not a key of ``style_set_map``.
        """
        if prompts_path is None:
            if set_name is None:
                # A real exception instead of ``assert``, which is stripped under ``python -O``.
                raise ValueError("Either prompts_path or set_name should be provided")
            try:
                prompts_path = self.style_set_map[set_name]
            except KeyError:
                raise KeyError(
                    f"Unknown set_name {set_name!r}; expected one of {sorted(self.style_set_map)}"
                ) from None
        self.prompts = pd.read_csv(prompts_path, delimiter=delimiter)
        self.transform = transform

    def __len__(self):
        """Return the number of caption rows."""
        return len(self.prompts)

    def __getitem__(self, idx):
        """Return row ``idx`` as a dict.

        The dict starts as ``{"id": idx}`` and is then updated with every column
        of the row, so a CSV column literally named ``id`` overrides it. Numpy
        integer values are cast to Python ``int``, and ``"caption"`` is wrapped
        in a one-element list before ``transform`` (if any) is applied.
        """
        ret = {"id": idx}
        ret.update(self.prompts.iloc[idx])
        # Cast every numpy integer width (not just int64) to a plain Python int.
        ret = {k: int(v) if isinstance(v, np.integer) else v for k, v in ret.items()}
        ret["caption"] = [ret["caption"]]
        if self.transform:
            ret = self.transform(ret)
        return ret

    def with_transform(self, transform):
        """Set ``transform`` in place and return ``self`` for chaining."""
        self.transform = transform
        return self
class HRS_caption(Caption_set):
    """Caption dataset that reads a single prompt column from a CSV.

    Items contain only ``"id"`` (positional index), ``"caption"`` (the value of
    ``caption_key`` wrapped in a one-element list) and ``"seed"`` (equal to the
    positional index). ``__len__`` and ``with_transform`` are inherited.
    """

    def __init__(
        self,
        prompts_path="/vision-nfs/torralba/projects/jomat/hui/stable_diffusion/clip_dissection/Style_captions/andre-derain_subset1/style_captions.csv",
        transform=None,
        delimiter=',',
        caption_key="original_prompts",
    ):
        """Load the prompts table.

        Args:
            prompts_path: Path (or any input accepted by ``pd.read_csv``) of the CSV.
            transform: Optional callable applied to every item dict.
            delimiter: Column separator of the CSV; ``","`` by default.
            caption_key: Name of the column used as the caption.
        """
        # NOTE(review): the default path is a Style_captions file, which Caption_set
        # reads with ';' and a "caption" column, while the defaults here assume ','
        # and "original_prompts" -- confirm this default is intended.
        # Caption_set.__init__ is not called; the attributes it would set are assigned here.
        self.prompts = pd.read_csv(prompts_path, delimiter=delimiter)
        self.transform = transform
        self.caption_key = caption_key

    def __getitem__(self, idx):
        """Return ``{"id": idx, "caption": [prompt], "seed": idx}`` (after ``transform``)."""
        prompt = self.prompts.iloc[idx][self.caption_key]
        ret = {"id": idx, "caption": [prompt], "seed": idx}
        if self.transform:
            ret = self.transform(ret)
        return ret
class Laion_pop(torch.utils.data.Dataset):
    """Image-caption dataset indexed by a ';'-delimited annotation CSV.

    The annotation file must provide a ``key`` column and a ``caption`` column.
    The key is used as the item id, as the seed, and to build the image file
    name ``<image_root>/<key zero-padded to 9 digits>.jpg``.
    """

    def __init__(
        self,
        anno_file="/vision-nfs/torralba/projects/jomat/hui/stable_diffusion/custom_datasets/laion_pop500.csv",
        image_root="/vision-nfs/torralba/scratch/jomat/sam_dataset/laion_pop",
        transform=None,
        get_img=True,
        get_caption=True,
    ):
        """Load the annotation table.

        Args:
            anno_file: Path (or any input accepted by ``pd.read_csv``) of the annotations.
            image_root: Directory containing the ``<key:09>.jpg`` images.
            transform: Optional callable applied to every item dict.
            get_img: Whether items include the loaded ``"image"``.
            get_caption: Whether items include ``"caption"``.
        """
        self.transform = transform
        self.info = pd.read_csv(anno_file, delimiter=";")
        self.caption_key = "caption"
        self.image_root = image_root
        # Plain attributes, so callers can still toggle them after construction.
        self.get_img = get_img
        self.get_caption = get_caption

    def __len__(self):
        """Return the number of annotation rows (after any ``subset``)."""
        return len(self.info)

    def load_image(self, key):
        """Open ``<image_root>/<key:09>.jpg`` and return it as an RGB PIL image."""
        image_path = os.path.join(self.image_root, f"{key:09}.jpg")
        with open(image_path, "rb") as f:
            # PIL opens lazily; convert() reads the pixels here, before f is closed.
            image = Image.open(f).convert("RGB")
        return image

    def __getitem__(self, idx):
        """Return the row at position ``idx`` as a dict.

        Keys: ``"id"`` and ``"seed"`` (both the row's ``key`` as a Python int),
        plus ``"caption"`` (one-element list) when ``get_caption`` and
        ``"image"`` when ``get_img``. ``transform`` is applied last, if set.
        """
        info = self.info.iloc[idx]
        key = info["key"]
        # Plain int for "id" as well, so it has the same type as "seed".
        ret = {"id": int(key)}
        if self.get_caption:
            ret["caption"] = [info[self.caption_key]]
        ret["seed"] = int(key)
        if self.get_img:
            ret["image"] = self.load_image(key)
        if self.transform:
            ret = self.transform(ret)
        return ret

    def with_transform(self, transform):
        """Set ``transform`` in place and return ``self`` for chaining."""
        self.transform = transform
        return self

    def subset(self, ids: list):
        """Keep only the rows whose ``key`` is in ``ids`` (in place) and return ``self``."""
        self.info = self.info[self.info["key"].isin(ids)]
        return self
if __name__ == "__main__":
    # Smoke test: load a captions CSV (path from argv[1] if given, else the
    # default below) and print its first item so the result is visible.
    import sys

    default_path = "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/Style_captions/andre-derain_subset1/style_captions.csv"
    path = sys.argv[1] if len(sys.argv) > 1 else default_path
    dataset = Caption_set(path)
    print(dataset[0])