use binary for all

- experimental/clip_api_app.py +93 -33
- experimental/clip_app_client.py +21 -12
experimental/clip_api_app.py
CHANGED
@@ -1,9 +1,12 @@
-
+# File name: model.py
+import json
+import os
 import numpy as np
 import torch
+from starlette.requests import Request
+from PIL import Image
 import ray
 from ray import serve
-from PIL import Image
 from clip_retrieval.load_clip import load_clip, get_tokenizer
 # from clip_retrieval.clip_client import ClipClient, Modality
 
@@ -21,14 +24,11 @@ class CLIPTransform:
 
         print ("using device", self.device)
 
-
-
-    def text_to_embeddings(self, prompts: List[str]) -> List[np.ndarray]:
-        text = self.tokenizer(prompts).to(self.device)
+    def text_to_embeddings(self, prompt):
+        text = self.tokenizer([prompt]).to(self.device)
         with torch.no_grad():
             prompt_embededdings = self.model.encode_text(text)
             prompt_embededdings /= prompt_embededdings.norm(dim=-1, keepdim=True)
-            prompt_embededdings = prompt_embededdings.cpu().numpy().tolist()
         return(prompt_embededdings)
 
     def image_to_embeddings(self, input_im):
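Note: text_to_embeddings now takes a single prompt instead of a List[str], and it returns the normalized torch tensor rather than a Python list; the cpu().numpy().tolist() conversion moves into __call__ so every request type is serialized in one place. A minimal usage sketch (the transform handle and prompt string are hypothetical, not from the commit):

    # One prompt in, one normalized tensor of shape [1, embedding_dim] out.
    emb = transform.text_to_embeddings("a photo of a cat")
    emb_list = emb.cpu().numpy().tolist()  # JSON-friendly nested list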
@@ -45,31 +45,91 @@ class CLIPTransform:
         image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
         return(image_embeddings)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    async def __call__(self, http_request: Request) -> str:
+        form_data = await http_request.form()
+
+        embeddings = None
+        if "text" in form_data:
+            prompt = (await form_data["text"].read()).decode()
+            print (type(prompt))
+            print (str(prompt))
+            embeddings = self.text_to_embeddings(prompt)
+        elif "image_url" in form_data:
+            image_url = (await form_data["image_url"].read()).decode()
+            # download image from url
+            import requests
+            from io import BytesIO
+            image_bytes = requests.get(image_url).content
+            input_image = Image.open(BytesIO(image_bytes))
+            input_image = input_image.convert('RGB')
+            input_image = np.array(input_image)
+            embeddings = self.image_to_embeddings(input_image)
+        elif "preprocessed_image" in form_data:
+            tensor_bytes = await form_data["preprocessed_image"].read()
+            shape_bytes = await form_data["shape"].read()
+            dtype_bytes = await form_data["dtype"].read()
+
+            # Convert bytes back to original form
+            dtype_mapping = {
+                "torch.float32": torch.float32,
+                "torch.float64": torch.float64,
+                "torch.float16": torch.float16,
+                "torch.uint8": torch.uint8,
+                "torch.int8": torch.int8,
+                "torch.int16": torch.int16,
+                "torch.int32": torch.int32,
+                "torch.int64": torch.int64,
+                torch.float32: np.float32,
+                torch.float64: np.float64,
+                torch.float16: np.float16,
+                torch.uint8: np.uint8,
+                torch.int8: np.int8,
+                torch.int16: np.int16,
+                torch.int32: np.int32,
+                torch.int64: np.int64,
+                # add more if needed
+            }
+            dtype_str = dtype_bytes.decode()
+            dtype_torch = dtype_mapping[dtype_str]
+            dtype_numpy = dtype_mapping[dtype_torch]
+            # shape = np.frombuffer(shape_bytes, dtype=np.int64)
+            # TODO: fix shape so it is passed nicely
+            shape = tuple([1, 3, 224, 224])
+
+            tensor_numpy = np.frombuffer(tensor_bytes, dtype=dtype_numpy).reshape(shape)
+            tensor = torch.from_numpy(tensor_numpy)
+            prepro = tensor.to(self.device)
+            embeddings = self.preprocessed_image_to_emdeddings(prepro)
+        else:
+            print ("Invalid request")
+            raise Exception("Invalid request")
+        return embeddings.cpu().numpy().tolist()
+
+        request = await http_request.json()
+        # print(type(request))
+        # print(str(request))
+        # switch based if we are using text or image
+        embeddings = None
+        if "text" in request:
+            prompt = request["text"]
+            embeddings = self.text_to_embeddings(prompt)
+        elif "image_url" in request:
+            image_url = request["image_url"]
+            # download image from url
+            import requests
+            from io import BytesIO
+            image_bytes = requests.get(image_url).content
+            input_image = Image.open(BytesIO(image_bytes))
+            input_image = input_image.convert('RGB')
+            input_image = np.array(input_image)
+            embeddings = self.image_to_embeddings(input_image)
+        elif "preprocessed_image" in request:
+            prepro = request["preprocessed_image"]
+            # create torch tensor on the device
+            prepro = torch.tensor(prepro).to(self.device)
+            embeddings = self.preprocessed_image_to_emdeddings(prepro)
+        else:
+            raise Exception("Invalid request")
+        return embeddings.cpu().numpy().tolist()
 
 deployment_graph = CLIPTransform.bind()
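The new __call__ accepts multipart form data instead of JSON. A "preprocessed_image" request carries three binary fields: the raw tensor buffer, its shape, and its dtype string; the server rebuilds the tensor with np.frombuffer. Two caveats are visible in the diff: the shape is currently hard-coded to (1, 3, 224, 224) (the frombuffer line is still a TODO), and everything after the first return (the JSON-based handler) is unreachable dead code. Below is a self-contained sketch of the intended round-trip, not part of the commit, assuming the shape was encoded client-side with np.array(shape).tobytes(), which yields int64 on most 64-bit platforms:

    import numpy as np
    import torch

    def encode(tensor: torch.Tensor):
        # Raw buffer plus the metadata needed to rebuild it (CPU tensor assumed).
        data_bytes = tensor.numpy().tobytes()
        shape_bytes = np.array(tensor.shape).tobytes()  # platform int, usually int64
        dtype_bytes = str(tensor.dtype).encode()        # e.g. b"torch.float32"
        return data_bytes, shape_bytes, dtype_bytes

    def decode(data_bytes, shape_bytes, dtype_bytes):
        # This is the commented-out TODO: recover the shape from its bytes.
        shape = tuple(np.frombuffer(shape_bytes, dtype=np.int64))
        dtype_numpy = {"torch.float32": np.float32,
                       "torch.float16": np.float16}[dtype_bytes.decode()]
        arr = np.frombuffer(data_bytes, dtype=dtype_numpy).reshape(shape)
        # copy(): np.frombuffer returns a read-only view, which torch.from_numpy
        # warns about.
        return torch.from_numpy(arr.copy())

    t = torch.rand(1, 3, 224, 224)
    assert torch.equal(t, decode(*encode(t)))

The dtype_mapping dict in the diff plays both roles at once: string keys map to torch dtypes, and torch-dtype keys map to numpy dtypes, so two successive lookups turn "torch.float32" into np.float32.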
experimental/clip_app_client.py
CHANGED
@@ -38,30 +38,41 @@ def preprocess_image(image_url):
     # convert image to numpy array
     input_image = np.array(input_image)
     input_im = Image.fromarray(input_image)
-    prepro = preprocess(input_im).unsqueeze(0).
+    prepro = preprocess(input_im).unsqueeze(0).cpu()
     return prepro
 
 preprocessed_image = preprocess_image(test_image_url)
 
 def send_text_request(number):
-
+    payload = {
+        "text": ('str', english_text, 'application/octet-stream'),
+    }
     url = os.environ.get("HTTP_ADDRESS", "http://127.0.0.1:8000/")
-    response = requests.post(url,
+    response = requests.post(url, files=payload)
     embeddings = response.text
     return number, embeddings
 
 def send_image_url_request(number):
-
+    payload = {
+        "image_url": ('str', test_image_url, 'application/octet-stream'),
+    }
     url = os.environ.get("HTTP_ADDRESS", "http://127.0.0.1:8000/")
-    response = requests.post(url,
+    response = requests.post(url, files=payload)
     embeddings = response.text
     return number, embeddings
 
 def send_preprocessed_image_request(number):
-
-
+    key = "preprocessed_image"
+    data_bytes = preprocessed_image.numpy().tobytes()
+    shape_bytes = np.array(preprocessed_image.shape).tobytes()
+    dtype_bytes = str(preprocessed_image.dtype).encode()
+    payload = {
+        key: ('tensor', data_bytes, 'application/octet-stream'),
+        'shape': ('shape', shape_bytes, 'application/octet-stream'),
+        'dtype': ('dtype', dtype_bytes, 'application/octet-stream'),
+    }
     url = os.environ.get("HTTP_ADDRESS", "http://127.0.0.1:8000/")
-    response = requests.post(url,
+    response = requests.post(url, files=payload)
     embeddings = response.text
     return number, embeddings
 

@@ -80,7 +91,7 @@ def process(numbers, send_func, max_workers=10):
         # print (f"{n_result} : {len(result[0])}")
 
 if __name__ == "__main__":
-    n_calls =
+    n_calls = 300
 
     # test text
     # n_calls = 1
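The client-side change mirrors the server: every request now goes out as multipart form data via requests' files= argument, whose value tuples are (filename, data, content_type). The filenames here ('str', 'tensor', 'shape', 'dtype') are arbitrary labels; what matters is that sending via files= makes Starlette deliver each field as an upload the server can .read() as bytes. A minimal sketch against the default address from the diff (the prompt string is made up):

    import requests

    payload = {"text": ("str", "a photo of a cat", "application/octet-stream")}
    response = requests.post("http://127.0.0.1:8000/", files=payload)
    print(response.text)  # JSON-encoded list holding one normalized embedding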
@@ -95,8 +106,6 @@ if __name__ == "__main__":
     print(f" Average time taken: {avg_time_ms:.2f} ms")
     print(f" Number of calls per second: {calls_per_sec:.2f}")
 
-    n_calls = 100
-
     # test image url
     # n_calls = 1
     numbers = list(range(n_calls))

@@ -119,6 +128,6 @@ if __name__ == "__main__":
     total_time = end_time - start_time
     avg_time_ms = total_time / n_calls * 1000
     calls_per_sec = n_calls / total_time
-    print(f"
+    print(f"Preprocessed image...")
     print(f" Average time taken: {avg_time_ms:.2f} ms")
     print(f" Number of calls per second: {calls_per_sec:.2f}")
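A hypothetical follow-up for consuming the benchmark responses, assuming Serve JSON-encodes the nested list returned by __call__ (the client treats response.text as opaque text either way): parsing it back recovers the embedding, whose norm should be close to 1.0 after normalization.

    import json
    import numpy as np

    number, embeddings = send_text_request(0)
    vec = np.array(json.loads(embeddings))        # shape (1, embedding_dim)
    print(vec.shape, float(np.linalg.norm(vec)))  # norm ~1.0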