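"""Extract DINOv2 image embeddings for a dataset of images.

Each image is preprocessed, embedded with DINOv2 ViT-L/14 (with registers), and
saved as an individual .npy file at the output path supplied by
ImageWithPathDataset. The work can be sharded across several jobs with
--number_of_splits and --split_index.

Example invocation (illustrative file name and paths):
    python extract_embeddings.py --input_path /path/to/images \
        --output_path /path/to/embeddings --number_of_splits 4 --split_index 0
"""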
import os, sys

# Add the repository root to the Python path so project-local modules resolve
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
sys.path.append(root_dir)

import argparse
from pathlib import Path

import numpy as np
import torch
from torchvision import transforms
from tqdm import tqdm

from utils.image_processing import CenterCrop
from data.extract_embeddings.dataset_with_path import ImageWithPathDataset
parser = argparse.ArgumentParser()
parser.add_argument(
    "--number_of_splits",
    type=int,
    help="Number of splits to process",
    default=1,
)
parser.add_argument(
    "--split_index",
    type=int,
    help="Index of the split to process",
    default=0,
)
parser.add_argument(
    "--input_path",
    type=str,
    help="Path to the input dataset",
)
parser.add_argument(
    "--output_path",
    type=str,
    help="Path to the output dataset",
)
args = parser.parse_args()
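
# Load the DINOv2 ViT-L/14 backbone (with registers) from torch.hub, compile it
# for faster inference, and run it in eval mode on the GPU when one is available.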
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg")
model = torch.compile(model, mode="max-autotune")
model.eval()
model.to(device)
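
# Resolve the input/output locations and make sure the output directory exists.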
input_path = Path(args.input_path)
output_path = Path(args.output_path)
output_path.mkdir(exist_ok=True, parents=True)
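
# Preprocessing: center-crop to a 1:1 aspect ratio, resize to 336 px (a multiple
# of the ViT's 14-pixel patch size), convert to a tensor, and normalize with
# standard ImageNet mean/std statistics.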
augmentation = transforms.Compose(
    [
        CenterCrop(ratio="1:1"),
        transforms.Resize(336, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)
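
# Build the dataset and keep only this job's shard: the index range
# [split_index * N // number_of_splits, (split_index + 1) * N // number_of_splits),
# with the last split extended to the end so no item is dropped.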
dataset = ImageWithPathDataset(input_path, output_path, transform=augmentation)
dataset = torch.utils.data.Subset(
    dataset,
    range(
        args.split_index * len(dataset) // args.number_of_splits,
        (
            (args.split_index + 1) * len(dataset) // args.number_of_splits
            if args.split_index != args.number_of_splits - 1
            else len(dataset)
        ),
    ),
)
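
# Batched loading: the collate function transposes each batch of
# (image, output_path) samples into a tuple of images and a tuple of paths;
# tensor stacking is done explicitly inside the loop below.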
batch_size = 128
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, num_workers=16, collate_fn=lambda x: zip(*x)
)
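
# Extract embeddings batch by batch and save one .npy file per input image.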
for images, output_emb_paths in tqdm(dataloader):
    images = torch.stack(images, dim=0).to(device)
    with torch.no_grad():
        embeddings = model(images)
    numpy_embeddings = embeddings.cpu().numpy()
    for emb, output_emb_path in zip(numpy_embeddings, output_emb_paths):
        np.save(f"{output_emb_path}.npy", emb)