from transformers import CLIPProcessor


class ClipTransform(object):
    """Image transform that wraps the Hugging Face CLIPProcessor for StreetCLIP."""

    def __init__(self, split):
        # `split` is unused here; kept for interface compatibility with split-dependent transforms.
        self.transform = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")

    def __call__(self, x):
        # Preprocess a single PIL image and return its pixel values as a (3, H, W) tensor.
        return self.transform(images=[x], return_tensors="pt")["pixel_values"].squeeze(0)
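
# Minimal usage sketch (the image path below is illustrative; assumes a PIL image as input):
#
#   from PIL import Image
#   transform = ClipTransform(split="train")
#   pixel_values = transform(Image.open("example.jpg"))  # tensor of shape (3, H, W)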


if __name__ == "__main__":
    # Sanity check: compare the reference CLIPProcessor preprocessing against the
    # fast re-implementation defined in the Hydra config.
    import glob

    import torch
    from hydra.utils import instantiate
    from omegaconf import OmegaConf
    from PIL import Image

    # Build both transforms: the fast one from its config, and the reference one.
    fast_clip_config = OmegaConf.load(
        "./configs/dataset/train_transform/fast_clip.yaml"
    )
    fast_clip_transform = instantiate(fast_clip_config)
    clip_transform = ClipTransform(split=None)

    img_paths = glob.glob("./datasets/osv5m/test/images/*.jpg")
    original_imgs, re_implemented_imgs, diffs = [], [], []

    for i in range(16):
        img = Image.open(img_paths[i])
        clip_img = clip_transform(img)
        fast_clip_img = fast_clip_transform(img)
        original_imgs.append(clip_img)
        re_implemented_imgs.append(fast_clip_img)
        # Element-wise absolute difference between the two preprocessing pipelines.
        abs_diff = (clip_img - fast_clip_img).abs()
        diffs.append(abs_diff)
        if abs_diff.max() > 1e-5:
            print(abs_diff.max())
    original_imgs = torch.stack(original_imgs)
    re_implemented_imgs = torch.stack(re_implemented_imgs)
    diffs = torch.stack(diffs)
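
    # Summary: report the worst-case deviation between the two pipelines.
    print("max abs diff over the sampled images:", diffs.max().item())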