sohojoe committed on
Commit
2afa949
1 Parent(s): ed1e314

use binary for all

Browse files
experimental/clip_api_app.py CHANGED
@@ -1,9 +1,12 @@
1
- from typing import List
 
 
2
  import numpy as np
3
  import torch
 
 
4
  import ray
5
  from ray import serve
6
- from PIL import Image
7
  from clip_retrieval.load_clip import load_clip, get_tokenizer
8
  # from clip_retrieval.clip_client import ClipClient, Modality
9
 
@@ -21,14 +24,11 @@ class CLIPTransform:
21
 
22
  print ("using device", self.device)
23
 
24
- @serve.batch(max_batch_size=32)
25
- # def text_to_embeddings(self, prompts: List[str]) -> torch.Tensor:
26
- def text_to_embeddings(self, prompts: List[str]) -> List[np.ndarray]:
27
- text = self.tokenizer(prompts).to(self.device)
28
  with torch.no_grad():
29
  prompt_embededdings = self.model.encode_text(text)
30
  prompt_embededdings /= prompt_embededdings.norm(dim=-1, keepdim=True)
31
- prompt_embededdings = prompt_embededdings.cpu().numpy().tolist()
32
  return(prompt_embededdings)
33
 
34
  def image_to_embeddings(self, input_im):
@@ -45,31 +45,91 @@ class CLIPTransform:
45
  image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
46
  return(image_embeddings)
47
 
48
- # async def __call__(self, http_request: Request) -> str:
49
- # request = await http_request.json()
50
- # # print(type(request))
51
- # # print(str(request))
52
- # # switch based if we are using text or image
53
- # embeddings = None
54
- # if "text" in request:
55
- # prompt = request["text"]
56
- # embeddings = self.text_to_embeddings(prompt)
57
- # elif "image" in request:
58
- # image_url = request["image_url"]
59
- # # download image from url
60
- # import requests
61
- # from io import BytesIO
62
- # input_image = Image.open(BytesIO(image_url))
63
- # input_image = input_image.convert('RGB')
64
- # input_image = np.array(input_image)
65
- # embeddings = self.image_to_embeddings(input_image)
66
- # elif "preprocessed_image" in request:
67
- # prepro = request["preprocessed_image"]
68
- # # create torch tensor on the device
69
- # prepro = torch.tensor(prepro).to(self.device)
70
- # embeddings = self.preprocessed_image_to_emdeddings(prepro)
71
- # else:
72
- # raise Exception("Invalid request")
73
- # return embeddings.cpu().numpy().tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  deployment_graph = CLIPTransform.bind()
 
1
+ # File name: model.py
2
+ import json
3
+ import os
4
  import numpy as np
5
  import torch
6
+ from starlette.requests import Request
7
+ from PIL import Image
8
  import ray
9
  from ray import serve
 
10
  from clip_retrieval.load_clip import load_clip, get_tokenizer
11
  # from clip_retrieval.clip_client import ClipClient, Modality
12
 
 
24
 
25
  print ("using device", self.device)
26
 
27
def text_to_embeddings(self, prompt):
    """Encode a single text prompt into a unit-normalised CLIP embedding.

    Wraps the prompt in a one-element batch for the tokenizer, runs the
    text encoder with gradients disabled, and L2-normalises the result.
    Returns the embedding as a torch tensor still on ``self.device``.
    """
    # The tokenizer expects a batch, so wrap the single prompt in a list.
    tokens = self.tokenizer([prompt]).to(self.device)
    with torch.no_grad():
        embeddings = self.model.encode_text(tokens)
        # Normalise to unit length so downstream similarity is a plain dot product.
        embeddings /= embeddings.norm(dim=-1, keepdim=True)
    return embeddings
33
 
34
  def image_to_embeddings(self, input_im):
 
45
  image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
46
  return(image_embeddings)
47
 
48
async def __call__(self, http_request: Request) -> str:
    """Serve one embedding request posted as multipart form data.

    Exactly one of these form fields selects the input modality:
      * ``text``               -- raw UTF-8 text to embed.
      * ``image_url``          -- URL of an image to download and embed.
      * ``preprocessed_image`` -- raw tensor bytes, accompanied by ``shape``
                                  and ``dtype`` fields describing the buffer.

    Returns the embedding as a JSON-serialisable nested list.
    Raises Exception when no recognised field is present.

    Note: the previous version carried an unreachable duplicate of the old
    JSON-body handler after the return statement; it has been removed.
    """
    form_data = await http_request.form()

    if "text" in form_data:
        prompt = (await form_data["text"].read()).decode()
        embeddings = self.text_to_embeddings(prompt)
    elif "image_url" in form_data:
        image_url = (await form_data["image_url"].read()).decode()
        # Download the image and hand the RGB pixel array to the encoder.
        import requests
        from io import BytesIO
        image_bytes = requests.get(image_url).content
        input_image = Image.open(BytesIO(image_bytes))
        input_image = input_image.convert('RGB')
        embeddings = self.image_to_embeddings(np.array(input_image))
    elif "preprocessed_image" in form_data:
        tensor_bytes = await form_data["preprocessed_image"].read()
        shape_bytes = await form_data["shape"].read()
        dtype_bytes = await form_data["dtype"].read()

        # dtype arrives as e.g. "torch.float32"; stripping the "torch."
        # prefix yields the matching numpy dtype name directly, which
        # replaces the previous hand-maintained two-hop mapping table.
        dtype_str = dtype_bytes.decode()
        dtype_numpy = np.dtype(dtype_str.replace("torch.", "", 1))

        # The client serialises the shape with np.array(shape).tobytes(),
        # whose default integer dtype is int64 on 64-bit platforms; this
        # replaces the previous hardcoded (1, 3, 224, 224).
        # NOTE(review): confirm int64 holds for all deployed clients.
        shape = tuple(int(dim) for dim in np.frombuffer(shape_bytes, dtype=np.int64))

        tensor_numpy = np.frombuffer(tensor_bytes, dtype=dtype_numpy).reshape(shape)
        prepro = torch.from_numpy(tensor_numpy).to(self.device)
        embeddings = self.preprocessed_image_to_emdeddings(prepro)
    else:
        print("Invalid request")
        raise Exception("Invalid request")

    return embeddings.cpu().numpy().tolist()
134
 
135
  deployment_graph = CLIPTransform.bind()
experimental/clip_app_client.py CHANGED
@@ -38,30 +38,41 @@ def preprocess_image(image_url):
38
  # convert image to numpy array
39
  input_image = np.array(input_image)
40
  input_im = Image.fromarray(input_image)
41
- prepro = preprocess(input_im).unsqueeze(0).to(device)
42
  return prepro
43
 
44
  preprocessed_image = preprocess_image(test_image_url)
45
 
46
  def send_text_request(number):
47
- data = {"text": english_text}
 
 
48
  url = os.environ.get("HTTP_ADDRESS", "http://127.0.0.1:8000/")
49
- response = requests.post(url, json=data)
50
  embeddings = response.text
51
  return number, embeddings
52
 
53
  def send_image_url_request(number):
54
- data = {"image_url": test_image_url}
 
 
55
  url = os.environ.get("HTTP_ADDRESS", "http://127.0.0.1:8000/")
56
- response = requests.post(url, json=data)
57
  embeddings = response.text
58
  return number, embeddings
59
 
60
  def send_preprocessed_image_request(number):
61
- nested_list = preprocessed_image.tolist()
62
- data = {"preprocessed_image": nested_list}
 
 
 
 
 
 
 
63
  url = os.environ.get("HTTP_ADDRESS", "http://127.0.0.1:8000/")
64
- response = requests.post(url, json=data)
65
  embeddings = response.text
66
  return number, embeddings
67
 
@@ -80,7 +91,7 @@ def process(numbers, send_func, max_workers=10):
80
  # print (f"{n_result} : {len(result[0])}")
81
 
82
  if __name__ == "__main__":
83
- n_calls = 10000
84
 
85
  # test text
86
  # n_calls = 1
@@ -95,8 +106,6 @@ if __name__ == "__main__":
95
  print(f" Average time taken: {avg_time_ms:.2f} ms")
96
  print(f" Number of calls per second: {calls_per_sec:.2f}")
97
 
98
- n_calls = 100
99
-
100
  # test image url
101
  # n_calls = 1
102
  numbers = list(range(n_calls))
@@ -119,6 +128,6 @@ if __name__ == "__main__":
119
  total_time = end_time - start_time
120
  avg_time_ms = total_time / n_calls * 1000
121
  calls_per_sec = n_calls / total_time
122
- print(f"Text...")
123
  print(f" Average time taken: {avg_time_ms:.2f} ms")
124
  print(f" Number of calls per second: {calls_per_sec:.2f}")
 
38
  # convert image to numpy array
39
  input_image = np.array(input_image)
40
  input_im = Image.fromarray(input_image)
41
+ prepro = preprocess(input_im).unsqueeze(0).cpu()
42
  return prepro
43
 
44
  preprocessed_image = preprocess_image(test_image_url)
45
 
46
def send_text_request(number):
    """POST the test sentence as a multipart form field.

    Returns a ``(number, raw_response_body)`` tuple so concurrent callers
    can match responses back to their request index.
    """
    url = os.environ.get("HTTP_ADDRESS", "http://127.0.0.1:8000/")
    files = {
        "text": ('str', english_text, 'application/octet-stream'),
    }
    response = requests.post(url, files=files)
    return number, response.text
54
 
55
  def send_image_url_request(number):
56
+ payload = {
57
+ "image_url": ('str', test_image_url, 'application/octet-stream'),
58
+ }
59
  url = os.environ.get("HTTP_ADDRESS", "http://127.0.0.1:8000/")
60
+ response = requests.post(url, files=payload)
61
  embeddings = response.text
62
  return number, embeddings
63
 
64
  def send_preprocessed_image_request(number):
65
+ key = "preprocessed_image"
66
+ data_bytes = preprocessed_image.numpy().tobytes()
67
+ shape_bytes = np.array(preprocessed_image.shape).tobytes()
68
+ dtype_bytes = str(preprocessed_image.dtype).encode()
69
+ payload = {
70
+ key: ('tensor', data_bytes, 'application/octet-stream'),
71
+ 'shape': ('shape', shape_bytes, 'application/octet-stream'),
72
+ 'dtype': ('dtype', dtype_bytes, 'application/octet-stream'),
73
+ }
74
  url = os.environ.get("HTTP_ADDRESS", "http://127.0.0.1:8000/")
75
+ response = requests.post(url, files=payload)
76
  embeddings = response.text
77
  return number, embeddings
78
 
 
91
  # print (f"{n_result} : {len(result[0])}")
92
 
93
  if __name__ == "__main__":
94
+ n_calls = 300
95
 
96
  # test text
97
  # n_calls = 1
 
106
  print(f" Average time taken: {avg_time_ms:.2f} ms")
107
  print(f" Number of calls per second: {calls_per_sec:.2f}")
108
 
 
 
109
  # test image url
110
  # n_calls = 1
111
  numbers = list(range(n_calls))
 
128
  total_time = end_time - start_time
129
  avg_time_ms = total_time / n_calls * 1000
130
  calls_per_sec = n_calls / total_time
131
+ print(f"Preprocessed image...")
132
  print(f" Average time taken: {avg_time_ms:.2f} ms")
133
  print(f" Number of calls per second: {calls_per_sec:.2f}")