Plonk / data /to_webdataset /osv_to_wds.py
nicolas-dufour's picture
squash: merge all unpushed commits
c4c7cee
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import argparse
import json
from collections import UserDict
from pathlib import Path
import numpy as np
import torch
import webdataset as wds
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from webdataset.autodecode import ImageHandler
from utils.image_processing import CenterCrop
print("Loading dinov2")
augmentation_dinov2 = transforms.Compose(
[
CenterCrop(ratio="1:1"),
transforms.Resize(336, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
]
)
device = "cuda" if torch.cuda.is_available() else "cpu"
dinov2_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg")
dinov2_model.eval()
dinov2_model.to(device)
print(f"Model loaded on {device}")
def dict_collate(batch):
output_dict = {}
if isinstance(batch[0], dict):
for key in batch[0].keys():
list_key = [d[key] for d in batch]
if key != "json":
output_dict[key] = dict_collate(list_key)
else:
output_dict[key] = list_key
return output_dict
elif isinstance(batch[0], Image.Image):
return [img for img in batch]
else:
return torch.utils.data.dataloader.default_collate(batch)
def log_and_continue(exn):
"""Call in an exception handler to ignore any exception, issue a warning, and continue."""
# logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
return True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def add_clip_scores_and_embeddings(src, dest, batch_size=512):
dataset = wds.DataPipeline(
wds.SimpleShardList(str(src)),
wds.split_by_worker,
wds.tarfile_to_samples(),
wds.rename(
__key__="__key__",
dino_image="jpg",
image="jpg",
street_clip="street_clip.npy",
json="json",
),
wds.decode(
ImageHandler("pilrgb", ["dino_image"])
), # avoid encoding decoding jpeg for true
wds.map_dict(
dino_image=augmentation_dinov2,
image=lambda x: x,
street_clip=lambda x: x,
json=lambda x: x,
),
wds.to_tuple(
"__key__",
"dino_image",
"street_clip",
"image",
"json",
),
wds.batched(batch_size),
)
loader = wds.WebLoader(dataset, num_workers=8, batch_size=None)
with wds.TarWriter(str(dest)) as sink:
for batch in tqdm(loader, total=10000 // batch_size):
(
keys,
dino_image,
street_clip,
image,
json,
) = batch
dino_image = dino_image.to(device)
with torch.no_grad():
dino_embedding = dinov2_model(dino_image).cpu().numpy()
for i in range(len(keys)):
sample = {
"__key__": keys[i],
"jpg": image[i],
"street_clip.npy": street_clip[i],
"json": json[i],
"dinov2_vitl14_registers.npy": dino_embedding[i],
}
sink.write(sample)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--src", help="path to source files")
parser.add_argument("--dest", help="path to destination files")
parser.add_argument("--shard_id", help="shard id")
args = parser.parse_args()
src = Path(args.src)
list_of_shards = list(src.glob("*.tar"))
list_of_shards.sort()
shard = str(list_of_shards[int(args.shard_id)]).split("/")[-1]
dest = Path(args.dest)
dest.mkdir(exist_ok=True, parents=True)
batch_size = 256
print(f"Loading {shard}")
tar_name = shard.split(".")[0]
src_shard = src / shard # f"{{{tar_name}...{tar_name}}}.tar"
print(f"Processing {src_shard} to {dest / shard}")
add_clip_scores_and_embeddings(src_shard, dest / shard, batch_size)