grantpitt commited on
Commit
e511b99
1 Parent(s): 760d125
Files changed (5) hide show
  1. .gitignore +2 -0
  2. artwork_urls.npy +3 -0
  3. embeddings.npy +3 -0
  4. handler.py +31 -0
  5. requirements.txt +2 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .ipynb_checkpoints
2
+ compute.ipynb
artwork_urls.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba9f605c3852001ad53e4f7324e6f56ae88a2786ee40e19d10bec950d6192cd3
3
+ size 2152944
embeddings.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e07a3d10e239135feaa71e868c447e6e2bed382d37128da9d525a2cf0855f7c
3
+ size 89392256
handler.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import CLIPTokenizer, CLIPModel
3
+ import numpy as np
4
+ import os
5
+
6
+
7
+ class EndpointHandler:
8
+ def __init__(self, path=""):
9
+ self.model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
10
+ self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
11
+
12
+ self.artwork_urls = np.load(os.path.join(path, "artwork_urls.npy"), allow_pickle=True)
13
+ self.embeddings = np.load(os.path.join(path, "embeddings.npy"), allow_pickle=True)
14
+
15
+ def __call__(self, data: Dict[str, Any]) -> List[float]:
16
+ """
17
+ data args:
18
+ inputs (:obj: `str` | `PIL.Image` | `np.array`)
19
+ kwargs
20
+ Return:
21
+ A :obj:`list` | `dict`: will be serialized and returned
22
+ """
23
+ inputs = self.tokenizer(data["inputs"], padding=True, return_tensors="pt")
24
+ text_features = self.model.get_text_features(**inputs)
25
+ input_embedding = text_features[0]
26
+ input_embedding = input_embedding / np.linalg.norm(input_embedding)
27
+
28
+ cos_score = self.embeddings @ input_embedding
29
+ top_10 = cos_score.argsort()[-100:][::-1]
30
+
31
+ return self.artwork_urls[top_10].tolist()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers==4.21.1
2
+ numpy==1.23.4