Amitai Getzler committed
Commit
4e6f8d5
1 Parent(s): 483a214

:art: Update

Files changed (4)
  1. __pycache__/handler.cpython-312.pyc +0 -0
  2. handler.py +186 -0
  3. requirements.txt +4 -0
  4. test.py +48 -0
__pycache__/handler.cpython-312.pyc ADDED
Binary file (10.2 kB).
 
handler.py ADDED
@@ -0,0 +1,186 @@
+ from base64 import b64decode
+ from io import BytesIO
+ from typing import Any, Dict
+
+ import numpy as np
+ import open_clip
+ import requests
+ import torch
+ from PIL import Image
+
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         self.model, self.preprocess_train, self.preprocess_val = (
+             open_clip.create_model_and_transforms(path)
+         )
+         self.tokenizer = open_clip.get_tokenizer(path)
+
+     def classify_image(self, candidate_labels, image):
+         def get_top_prediction(text_probs, labels):
+             max_index = text_probs[0].argmax().item()
+             return {
+                 "label": labels[max_index],
+                 "score": text_probs[0][max_index].item(),
+             }
+
+         top_prediction = None
+         # Score the labels in batches of 10. Note that softmax is applied
+         # per batch, so scores are only comparable within a batch; the final
+         # winner is the label with the highest per-batch probability.
+         for i in range(0, len(candidate_labels), 10):
+             batch_labels = candidate_labels[i : i + 10]
+             # Preprocess the image and tokenize the candidate labels
+             image_tensor = self.preprocess_val(image).unsqueeze(0)
+             text = self.tokenizer(batch_labels)
+
+             with torch.no_grad(), torch.cuda.amp.autocast():
+                 image_features = self.model.encode_image(image_tensor)
+                 text_features = self.model.encode_text(text)
+                 image_features /= image_features.norm(dim=-1, keepdim=True)
+                 text_features /= text_features.norm(dim=-1, keepdim=True)
+
+                 text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
+
+             current_top = get_top_prediction(text_probs, batch_labels)
+             if top_prediction is None or current_top["score"] > top_prediction["score"]:
+                 top_prediction = current_top
+
+         return {"label": top_prediction["label"]}
+
+     def combine_embeddings(
+         self, text_embeddings, image_embeddings, text_weight=0.5, image_weight=0.5
+     ):
+         """Combine text and image embeddings with the specified weights."""
+         # Average the text embeddings; fall back to zeros if there are none
+         if text_embeddings is not None:
+             avg_text_embedding = np.mean(np.vstack(text_embeddings), axis=0)
+         else:
+             avg_text_embedding = np.zeros_like(image_embeddings[0])
+
+         # Average the image embeddings; fall back to zeros if there are none
+         if image_embeddings is not None:
+             avg_image_embeddings = np.mean(np.vstack(image_embeddings), axis=0)
+         else:
+             avg_image_embeddings = np.zeros_like(text_embeddings[0])
+
+         # Weighted average of the two modality embeddings
+         combined_embedding = np.average(
+             np.vstack((avg_text_embedding, avg_image_embeddings)),
+             axis=0,
+             weights=[text_weight, image_weight],
+         )
+         return combined_embedding
+
+     def average_text(self, doc):
+         # Split the document into 40-word chunks and embed each chunk
+         words = doc.split(" ")
+         text_chunks = [
+             " ".join(words[i : i + 40]) for i in range(0, len(words), 40)
+         ]
+         text_embeddings = []
+         for chunk in text_chunks:
+             inputs = self.tokenizer(chunk)
+             with torch.no_grad():
+                 text_features = self.model.encode_text(inputs)
+             text_features /= text_features.norm(dim=-1, keepdim=True)
+             text_embeddings.append(text_features.detach().squeeze().numpy())
+         combined = self.combine_embeddings(
+             text_embeddings, None, text_weight=1, image_weight=0
+         )
+         return combined
+
+     def embedd_image(self, doc) -> list:
+         if isinstance(doc, str):
+             # Plain-text input: embed the averaged text chunks
+             return self.average_text(doc).tolist()
+
+         image = doc.get("image")
+         if "https://" in image:
+             # The image field may hold several "|"-separated URLs;
+             # only the first one is downloaded and embedded.
+             first_url = image.split("|")[0]
+             image = Image.open(BytesIO(requests.get(first_url).content))
+         # Generate the image embedding
+         image_tensor = self.preprocess_val(image).unsqueeze(0)
+         with torch.no_grad():
+             image_features = self.model.encode_image(image_tensor)
+         image_features /= image_features.norm(dim=-1, keepdim=True)
+         image_embedding = image_features.detach().squeeze().numpy()
+         if doc.get("description", "") == "":
+             print("Empty description; using the image embedding alone.")
+             return image_embedding.tolist()
+         # Combine the image embedding with the description embedding
+         average_texts = self.average_text(doc.get("description"))
+         combined = self.combine_embeddings(
+             [average_texts],
+             [image_embedding],
+             text_weight=0.5,
+             image_weight=0.5,
+         )
+         return combined.tolist()
+
+     def process_batch(self, batch) -> object:
+         try:
+             batch = batch.get("batch")
+             # Validate the batch input
+             if not isinstance(batch, list):
+                 return "Invalid input: batch must be an array of strings.", 400
+             embeddings = [self.embedd_image(item) for item in batch]
+             # Return one embedding per batch item
+             return embeddings
+         except Exception as e:
+             print("Error processing request", e)
+             return "An error occurred while processing the request.", 500
+
+     def base64_image_to_pil(self, base64_str) -> Image.Image:
+         image_data = b64decode(base64_str)
+         image_buffer = BytesIO(image_data)
+         image = Image.open(image_buffer)
+         return image
+
+     def __call__(self, data: Any) -> Dict[str, Any]:
+         """
+         Process the input data for either classification or embedding generation.
+
+         Args:
+             data (:obj:`dict`): A dictionary containing the input data and parameters for inference.
+                 For classification:
+                     {
+                         "type": "classify",
+                         "inputs": {
+                             "candidates": :obj:`list[str]`,
+                             "image": :obj:`str`  # URL or base64-encoded image
+                         }
+                     }
+                 For embedding:
+                     {
+                         "type": "embedd",
+                         "inputs": {
+                             "batch": :obj:`list[str | dict[str, str]]`  # Text or image+description
+                         }
+                     }
+
+         Returns:
+             :obj:`dict`: The result of the operation.
+                 For classification:
+                     {
+                         "label": :obj:`str`  # The predicted label
+                     }
+                 For embedding:
+                     {
+                         "embeddings": :obj:`list[list[float]]`  # List of embeddings
+                     }
+
+         Raises:
+             :obj:`Exception`: If an error occurs during processing.
+         """
+         inputs = data.pop("inputs", data)
+         task = data.pop("type", "embedd")  # Either "embedd" or "classify"
+         print("type is", task)
+         print("input is", inputs)
+         if task == "classify":
+             candidate_labels = inputs["candidates"]
+             # Accept either an image URL or a base64-encoded image
+             image = (
+                 Image.open(BytesIO(requests.get(inputs["image"]).content))
+                 if "https://" in inputs["image"]
+                 else self.base64_image_to_pil(inputs["image"])
+             )
+             return self.classify_image(candidate_labels, image)
+         elif task == "embedd":
+             try:
+                 embeddings = self.process_batch(inputs)
+                 return {"embeddings": embeddings}
+             except Exception as e:
+                 return {"error": str(e)}
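Once deployed as a Hugging Face Inference Endpoint, the handler above receives the payloads described in the __call__ docstring as JSON over HTTP. A minimal client sketch; the endpoint URL, token, and image URL below are placeholders, not values from this repo:

import requests

API_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder endpoint URL
headers = {
    "Authorization": "Bearer <hf_token>",  # placeholder access token
    "Content-Type": "application/json",
}

payload = {
    "type": "classify",
    "inputs": {
        "candidates": ["Vintage", "Streetwear"],
        "image": "https://example.com/image.jpg",  # placeholder image URL
    },
}

# POST the payload; the handler's __call__ returns {"label": "..."}
response = requests.post(API_URL, headers=headers, json=payload)
print(response.json())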
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ open_clip_torch
+ numpy
+ pillow
+ requests
test.py ADDED
@@ -0,0 +1,48 @@
+ from handler import EndpointHandler
+
+ # init handler
+ my_handler = EndpointHandler(path="hf-hub:Styld/marqo-fashionSigLIP")
+
+ # prepare sample payloads
+ embedding_input = {
+     "inputs": {
+         "batch": [
+             {
+                 "image": "https://lp2.hm.com/hmgoepprod?set=source[/23/ab/23ab27480dd2dfc7007745402d3b8caaf756ee70.jpg],origin[dam],category[],type[DESCRIPTIVESTILLLIFE],res[m],hmver[2]&call=url[file:/product/style]",
+                 "description": "test",
+             }
+         ]
+     },
+     "type": "embedd",
+ }
+ classify_input = {
+     "inputs": {
+         "candidates": [
+             "Bohemian",
+             "Vintage",
+             "Streetwear",
+             "Preppy",
+             "Minimalist",
+             "Glamorous",
+             "Punk",
+             "Romantic",
+             "Classic",
+             "Avant-garde",
+             "Grunge",
+             "Retro",
+             "Gothic",
+             "Hippie",
+             "Eco-friendly",
+         ],
+         "image": "https://static.zara.net/assets/public/bb1f/0983/a29f44e18ec9/3b7dd5791c67/05575420427-e1/05575420427-e1.jpg?ts=1708614949903&w=1126",
+     },
+     "type": "classify",
+ }
+
+ # test the handler
+ embedd_pred = my_handler(embedding_input)
+ classify_pred = my_handler(classify_input)
+
+ # show results
+ print("embedd_pred", embedd_pred)
+ print("classify_pred", classify_pred)
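test.py only exercises URL images, but the classify path also accepts a base64-encoded image via base64_image_to_pil. A minimal sketch that could be appended to test.py, assuming a local JPEG at the hypothetical path sample.jpg:

from base64 import b64encode

# encode a local image as base64 (hypothetical file path)
with open("sample.jpg", "rb") as f:
    image_b64 = b64encode(f.read()).decode("utf-8")

base64_classify_input = {
    "inputs": {
        "candidates": ["Vintage", "Streetwear"],
        "image": image_b64,  # no "https://" substring, so the handler takes the base64 branch
    },
    "type": "classify",
}
print("base64 classify_pred", my_handler(base64_classify_input))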