Amitai Getzler committed
Commit: 4e6f8d5
1 Parent(s): 483a214
:art: Update
Files changed:
- __pycache__/handler.cpython-312.pyc +0 -0
- handler.py +186 -0
- requirements.txt +4 -0
- test.py +48 -0
__pycache__/handler.cpython-312.pyc
ADDED
Binary file (10.2 kB)
handler.py
ADDED
@@ -0,0 +1,186 @@
from base64 import b64decode
from io import BytesIO
import open_clip
import requests
import torch
import numpy as np
from PIL import Image
from typing import Dict, Any


class EndpointHandler:
    def __init__(self, path=""):
        self.model, self.preprocess_train, self.preprocess_val = (
            open_clip.create_model_and_transforms(path)
        )
        self.tokenizer = open_clip.get_tokenizer(path)

    def classify_image(self, candidate_labels, image):
        def get_top_prediction(text_probs, labels):
            max_index = text_probs[0].argmax().item()
            return {
                "label": labels[max_index],
                "score": text_probs[0][max_index].item(),
            }

        top_prediction = None
        for i in range(0, len(candidate_labels), 10):
            batch_labels = candidate_labels[i : i + 10]
            # Preprocess the image
            image_tensor = self.preprocess_val(image).unsqueeze(0)
            text = self.tokenizer(batch_labels)

            with torch.no_grad(), torch.cuda.amp.autocast():
                image_features = self.model.encode_image(image_tensor)
                text_features = self.model.encode_text(text)
                image_features /= image_features.norm(dim=-1, keepdim=True)
                text_features /= text_features.norm(dim=-1, keepdim=True)

                text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

            current_top = get_top_prediction(text_probs, batch_labels)
            if top_prediction is None or current_top["score"] > top_prediction["score"]:
                top_prediction = current_top

        return {"label": top_prediction["label"]}

    def combine_embeddings(
        self, text_embeddings, image_embeddings, text_weight=0.5, image_weight=0.5
    ):
        """Combine text and image embeddings with specified weights."""
        # Average text embeddings
        if text_embeddings is not None:
            avg_text_embedding = np.mean(np.vstack(text_embeddings), axis=0)
        else:
            avg_text_embedding = np.zeros_like(image_embeddings[0])

        if image_embeddings is not None:
            avg_image_embeddings = np.mean(np.vstack(image_embeddings), axis=0)
        else:
            avg_image_embeddings = np.zeros_like(text_embeddings[0])

        # Combine text and image embeddings with specified weights
        combined_embedding = np.average(
            np.vstack((avg_text_embedding, avg_image_embeddings)),
            axis=0,
            weights=[text_weight, image_weight],
        )
        return combined_embedding

    def average_text(self, doc):
        text_chunks = [
            " ".join(doc.split(" ")[i : i + 40])
            for i in range(0, len(doc.split(" ")), 40)
        ]
        text_embeddings = []
        for chunk in text_chunks:
            inputs = self.tokenizer(chunk)
            text_features = self.model.encode_text(inputs)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            text_embeddings.append(text_features.detach().squeeze().numpy())
        combined = self.combine_embeddings(
            text_embeddings, None, text_weight=1, image_weight=0
        )
        return combined

    def embedd_image(self, doc) -> list:
        if not isinstance(doc, str):
            image = doc.get("image")
            if "https://" in image:
                image = image.split("|")
                # response = requests.get(image)
                image = [
                    Image.open(BytesIO(response.content))
                    for response in [requests.get(image) for image in image]
                ][0]
            # Simulate generating embeddings
            image = self.preprocess_val(image).unsqueeze(0)
            image_features = self.model.encode_image(image)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            image_embedding = image_features.detach().squeeze().numpy()
            if doc.get("description", "") == "":
                print("empty description. Going with image alone")
                return image_embedding.tolist()
            else:
                average_texts = self.average_text(doc.get("description"))
                combined = self.combine_embeddings(
                    [average_texts],
                    [image_embedding],
                    text_weight=0.5,
                    image_weight=0.5,
                )
                return combined.tolist()
        elif isinstance(doc, str):
            return self.average_text(doc).tolist()

    def process_batch(self, batch) -> object:
        try:
            batch = batch.get("batch")
            # Validate the batch input
            if not isinstance(batch, list):
                return "Invalid input: batch must be an array of strings.", 400
            embeddings = [self.embedd_image(item) for item in batch]
            # Send the response with the embeddings array
            return embeddings
        except Exception as e:
            print("Error processing request", e)
            return "An error occurred while processing the request.", 500

    def base64_image_to_pil(self, base64_str) -> Image:
        image_data = b64decode(base64_str)
        image_buffer = BytesIO(image_data)
        image = Image.open(image_buffer)
        return image

    def __call__(self, data: Any) -> Dict[str, Any]:
        """
        Process the input data for either classification or embedding generation.

        Args:
            data (:obj:`dict`): A dictionary containing the input data and parameters for inference.
                For classification:
                    {
                        "type": "classify",
                        "inputs": {
                            "candidates": :obj:`list[str]`,
                            "image": :obj:`str`  # URL or base64 encoded image
                        }
                    }
                For embedding:
                    {
                        "type": "embedd",
                        "batch": :obj:`list[str | dict[str, str]]`  # Text or image+description
                    }

        Returns:
            :obj:`dict`: The result of the operation.
                For classification:
                    {
                        "label": :obj:`str`  # The predicted label
                    }
                For embedding:
                    {
                        "embeddings": :obj:`list[list[float]]`  # List of embeddings
                    }

        Raises:
            :obj:`Exception`: If an error occurs during processing.
        """
        inputs = data.pop("inputs", data)
        type = data.pop("type", "embedd")  # Or classify
        print("type is", type)
        print("input is", inputs)
        if type == "classify":
            candidate_labels = inputs["candidates"]
            image = (
                Image.open(BytesIO(requests.get(inputs["image"]).content))
                if "https://" in inputs["image"]
                else self.base64_image_to_pil(inputs["image"])
            )
            response = self.classify_image(candidate_labels, image)
            return response
        elif type == "embedd":
            try:
                embeddings = self.process_batch(inputs)
                return {"embeddings": embeddings}
            except Exception as e:
                return e
requirements.txt
ADDED
@@ -0,0 +1,4 @@
open_clip_torch
numpy
pillow
requests
test.py
ADDED
@@ -0,0 +1,48 @@
from handler import EndpointHandler

# init handler
my_handler = EndpointHandler(path="hf-hub:Styld/marqo-fashionSigLIP")

# prepare sample payload
embedding_input = {
    "inputs": {
        "batch": [
            {
                "image": "https://lp2.hm.com/hmgoepprod?set=source[/23/ab/23ab27480dd2dfc7007745402d3b8caaf756ee70.jpg],origin[dam],category[],type[DESCRIPTIVESTILLLIFE],res[m],hmver[2]&call=url[file:/product/style]",
                "description": "test",
            }
        ]
    },
    "type": "embedd",
}
classify_input = {
    "inputs": {
        "candidates": [
            "Bohemian",
            "Vintage",
            "Streetwear",
            "Preppy",
            "Minimalist",
            "Glamorous",
            "Punk",
            "Romantic",
            "Classic",
            "Avant-garde",
            "Grunge",
            "Retro",
            "Gothic",
            "Hippie",
            "Eco-friendly",
        ],
        "image": "https://static.zara.net/assets/public/bb1f/0983/a29f44e18ec9/3b7dd5791c67/05575420427-e1/05575420427-e1.jpg?ts=1708614949903&w=1126",
    },
    "type": "classify",
}

# test the handler
embedd_pred = my_handler(embedding_input)
classify_pred = my_handler(classify_input)

# show results
print("embedd_pred", embedd_pred)
print("classify_pred", classify_pred)
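
The committed test.py above only exercises the URL-plus-description path of the handler. Below is a minimal sketch of the two remaining branches (base64 classification via base64_image_to_pil and text-only embedding via average_text); the model id is the one used in test.py, while the local file name, candidate labels, and sample text are hypothetical placeholders.

from base64 import b64encode

from handler import EndpointHandler

handler = EndpointHandler(path="hf-hub:Styld/marqo-fashionSigLIP")

# Classification from a base64-encoded image: the string contains no "https://",
# so __call__ falls through to base64_image_to_pil().
with open("jacket.jpg", "rb") as f:  # placeholder local image
    image_b64 = b64encode(f.read()).decode("utf-8")

classify_b64 = {
    "type": "classify",
    "inputs": {
        "candidates": ["Minimalist", "Streetwear", "Vintage"],  # placeholder labels
        "image": image_b64,
    },
}
print(handler(classify_b64))  # -> {"label": ...}

# Text-only embedding: plain strings in the batch take the average_text() branch.
embed_text = {
    "type": "embedd",
    "inputs": {"batch": ["red floral summer dress"]},  # placeholder text
}
print(handler(embed_text))  # -> {"embeddings": [[...]]}

Both calls return plain dicts in the shapes documented in the __call__ docstring.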