grantpitt commited on
Commit
4ddac9a
1 Parent(s): 0849e3d

param should be dict

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. handler.py +12 -13
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ test
handler.py CHANGED
@@ -1,31 +1,30 @@
1
- from typing import List
2
  from transformers import CLIPTokenizer, CLIPModel
3
 
4
 
5
- class EndpointHandler():
6
  def __init__(self, path=""):
7
  # self.model= load_model(path)
8
  hf_model_path = "openai/clip-vit-large-patch14"
9
  self.model = CLIPModel.from_pretrained(hf_model_path)
10
  self.tokenizer = CLIPTokenizer.from_pretrained(hf_model_path)
11
 
12
- def __call__(self, inputs: str) -> List[float]:
13
  """
14
- data args:
15
- inputs (:obj: `str` | `PIL.Image` | `np.array`)
16
- kwargs
17
- Return:
18
- A :obj:`list` | `dict`: will be serialized and returned
19
  """
20
 
21
  # pseudo
22
- # self.model(input)
23
- token_inputs = self.tokenizer([inputs], padding=True, return_tensors="pt")
24
  query_embed = self.model.get_text_features(**token_inputs)
25
  np_query_embed = query_embed.detach().cpu().numpy()[0].tolist()
26
  return np_query_embed
27
 
28
 
29
- # if __name__ == "__main__":
30
- # handler = EndpointHandler()
31
- # print(handler("a dog"))
 
1
+ from typing import Dict, List, Any
2
  from transformers import CLIPTokenizer, CLIPModel
3
 
4
 
5
+ class EndpointHandler:
6
  def __init__(self, path=""):
7
  # self.model= load_model(path)
8
  hf_model_path = "openai/clip-vit-large-patch14"
9
  self.model = CLIPModel.from_pretrained(hf_model_path)
10
  self.tokenizer = CLIPTokenizer.from_pretrained(hf_model_path)
11
 
12
+ def __call__(self, data: Dict[str, Any]) -> List[float]:
13
  """
14
+ data args:
15
+ inputs (:obj: `str` | `PIL.Image` | `np.array`)
16
+ kwargs
17
+ Return:
18
+ A :obj:`list` | `dict`: will be serialized and returned
19
  """
20
 
21
  # pseudo
22
+ token_inputs = self.tokenizer(data["inputs"], padding=True, return_tensors="pt")
 
23
  query_embed = self.model.get_text_features(**token_inputs)
24
  np_query_embed = query_embed.detach().cpu().numpy()[0].tolist()
25
  return np_query_embed
26
 
27
 
28
+ if __name__ == "__main__":
29
+ eh = EndpointHandler()
30
+ print(eh({"inputs": "a dog"}))