import webdataset as wds
from pathlib import Path
import io
import pandas as pd
from PIL import Image
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from utils.image_processing import CenterCrop
from tqdm import tqdm

tqdm.pandas()
print("Loading dinov2") | |
augmentation_dinov2 = transforms.Compose( | |
[ | |
CenterCrop(ratio="1:1"), | |
transforms.Resize(336, interpolation=transforms.InterpolationMode.BICUBIC), | |
transforms.ToTensor(), | |
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), | |
] | |
) | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg") | |
model.eval() | |
model.to(device) | |
print(f"Model loaded on {device}") | |

class YFCCDataset(Dataset):
    """One YFCC CSV shard: yields the preprocessed image tensor, raw JPEG bytes, and row metadata."""

    def __init__(self, csv_path, images_root):
        self.df = pd.read_csv(csv_path, sep="\t")
        # Keep only rows that have GPS coordinates.
        self.df = self.df[self.df["latitude"].notna() & self.df["longitude"].notna()]
        self.images_root = Path(images_root)
        # Build image paths from the hash-based layout: <root>/<hash[:3]>/<hash[3:6]>/<hash>.jpg
        print("Building image paths...")
        self.df["image_path"] = self.df["hash"].progress_apply(
            lambda x: self.images_root / x[:3] / x[3:6] / f"{x}.jpg"
        )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_path = row["image_path"]
        if not image_path.exists():
            print(f"Image {image_path} does not exist")
            return None
        # Read the JPEG once as raw bytes and decode the PIL image from them.
        with open(image_path, "rb") as f:
            jpg_data = f.read()
        image = Image.open(io.BytesIO(jpg_data)).convert("RGB")
        image = augmentation_dinov2(image)
        # Convert metadata to a plain dict; drop the Path object, which is not JSON serializable.
        metadata = row.to_dict()
        del metadata["image_path"]
        return {
            "image": image,
            "jpg_data": jpg_data,
            "photo_id": str(row["photo_id"]),
            "metadata": metadata,
        }

def custom_collate(batch):
    """
    Custom collate function that handles dict items from the dataset and drops
    missing images (items returned as None).
    """
    batch = [item for item in batch if item is not None]
    if not batch:
        # Every image in this batch was missing on disk.
        return None
    return {
        "image": torch.stack([item["image"] for item in batch]),
        "jpg_data": [item["jpg_data"] for item in batch],
        "photo_id": [item["photo_id"] for item in batch],
        "metadata": [item["metadata"] for item in batch],
    }

def process_batch(batch, model, device):
    images = batch["image"].to(device)  # No need to stack, already stacked in collate
    with torch.no_grad():
        embeddings = model(images).cpu().numpy()
    samples = []
    for i in range(len(batch["photo_id"])):
        sample = {
            "__key__": batch["photo_id"][i],
            "jpg": batch["jpg_data"][i],
            "dinov2_vitl14_registers.npy": embeddings[i],
            "json": batch["metadata"][i],
        }
        samples.append(sample)
    return samples

def main(
    src_csv,
    src_images,
    dest_folder,
    num_samples_per_tar=10000,
    job_offset=0,
    batch_size=32,
):
    print("Loading dataset")
    dataset = YFCCDataset(src_csv, src_images)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
        collate_fn=custom_collate,  # Drops samples whose image file is missing.
    )
    print(f"Processing job {job_offset} with {len(dataset)} samples")
    with wds.ShardWriter(
        str(Path(dest_folder) / "%04d.tar"),
        maxcount=num_samples_per_tar,
        start_shard=10 * job_offset,  # Reserve a block of 10 shard indices per job.
    ) as sink:
        for batch in tqdm(dataloader):
            if batch is None:  # Whole batch was missing images; nothing to write.
                continue
            samples = process_batch(batch, model, device)
            for sample in samples:
                sink.write(sample)
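
# A sketch of how the resulting shards can be read back with webdataset's default
# decoders; the shard pattern below is a placeholder, not taken from the original setup.
def iter_shards(shard_pattern="/path/to/webdataset/{0000..0009}.tar"):
    ds = (
        wds.WebDataset(shard_pattern)
        .decode("pil")  # "jpg" -> PIL.Image, "*.npy" -> numpy array, "json" -> dict
        .to_tuple("jpg", "dinov2_vitl14_registers.npy", "json")
    )
    for image, embedding, metadata in ds:
        yield image, embedding, metadata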
if __name__ == "__main__": | |
import argparse | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--src_csv_dir", help="pixel_input_folder") | |
parser.add_argument("--src_images_dir", help="path to source images") | |
parser.add_argument("--dest", help="path to destination web") | |
parser.add_argument( | |
"--num_samples_per_tar", | |
help="number of samples per tar", | |
type=int, | |
default=10000, | |
) | |
parser.add_argument("--job_offset", help="job offset", type=int, default=0) | |
parser.add_argument("--batch_size", help="batch size", type=int, default=256) | |
args = parser.parse_args() | |
dest = Path(args.dest) | |
dest.mkdir(exist_ok=True, parents=True) | |
main( | |
Path(args.src_csv_dir) / f"{str(args.job_offset).zfill(3)}.csv", | |
args.src_images_dir, | |
args.dest, | |
args.num_samples_per_tar, | |
args.job_offset, | |
args.batch_size, | |
) | |
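
# Example invocation (a sketch: the script name and all paths are placeholders,
# not taken from the original setup):
#
#   python encode_yfcc_dinov2.py \
#       --src_csv_dir /data/yfcc/csv_shards \
#       --src_images_dir /data/yfcc/images \
#       --dest /data/yfcc/webdataset \
#       --num_samples_per_tar 10000 \
#       --job_offset 0 \
#       --batch_size 256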