OwenElliott
commited on
Commit
•
b6c64a0
1
Parent(s):
f50de00
Upload 18 files
Browse files- amazon.json +0 -0
- app.py +276 -0
- cache_taxonomy_vectors.py +61 -0
- images/bike-helmet.png +0 -0
- images/coffee.png +0 -0
- images/cooking-book.jpg +0 -0
- images/cutting-board.png +0 -0
- images/flip-flops.jpg +0 -0
- images/grater.png +0 -0
- images/green-shirt.webp +0 -0
- images/hoop-earring.jpg +0 -0
- images/iron.png +0 -0
- images/laptop.png +0 -0
- images/notebook.png +0 -0
- images/red-dress.webp +0 -0
- images/runners.png +0 -0
- images/sleeping-bag.png +0 -0
- requirements.txt +6 -0
amazon.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
app.py
ADDED
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import open_clip
|
3 |
+
from PIL import Image
|
4 |
+
import requests
|
5 |
+
import json
|
6 |
+
import gradio as gr
|
7 |
+
import pandas as pd
|
8 |
+
from io import BytesIO
|
9 |
+
import os
|
10 |
+
|
11 |
+
# Load the Amazon taxonomy from a JSON file
|
12 |
+
with open("amazon.json", "r") as f:
|
13 |
+
AMAZON_TAXONOMY = json.load(f)
|
14 |
+
|
15 |
+
|
16 |
+
base_model_name = "ViT-B-16"
|
17 |
+
model_base, _, preprocess_base = open_clip.create_model_and_transforms(base_model_name)
|
18 |
+
tokenizer_base = open_clip.get_tokenizer(base_model_name)
|
19 |
+
model_name_B = "hf-hub:Marqo/marqo-ecommerce-embeddings-B"
|
20 |
+
model_B, _, preprocess_B = open_clip.create_model_and_transforms(model_name_B)
|
21 |
+
tokenizer_B = open_clip.get_tokenizer(model_name_B)
|
22 |
+
model_name_L = "hf-hub:Marqo/marqo-ecommerce-embeddings-L"
|
23 |
+
model_L, _, preprocess_L = open_clip.create_model_and_transforms(model_name_L)
|
24 |
+
tokenizer_L = open_clip.get_tokenizer(model_name_L)
|
25 |
+
|
26 |
+
models = [base_model_name, model_name_B, model_name_L]
|
27 |
+
|
28 |
+
taxonomy_cache = {}
|
29 |
+
for model in models:
|
30 |
+
with open(f'{model.split("/")[-1]}.json', "r") as f:
|
31 |
+
taxonomy_cache[model] = json.load(f)
|
32 |
+
|
33 |
+
|
34 |
+
def cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
35 |
+
numerator = (a * b).sum(dim=-1)
|
36 |
+
denominator = torch.linalg.norm(a, ord=2, dim=-1) * torch.linalg.norm(
|
37 |
+
b, ord=2, dim=-1
|
38 |
+
)
|
39 |
+
return 0.5 * (numerator / denominator + 1.0)
|
40 |
+
|
41 |
+
|
42 |
+
class BeamPath:
|
43 |
+
def __init__(self, path: list, cumulative_score: float, current_layer: dict | list):
|
44 |
+
self.path = path
|
45 |
+
self.cumulative_score = cumulative_score
|
46 |
+
self.current_layer = current_layer
|
47 |
+
|
48 |
+
def __repr__(self):
|
49 |
+
return f"BeamPath(path={self.path}, cumulative_score={self.cumulative_score})"
|
50 |
+
|
51 |
+
|
52 |
+
def _compute_similarities(classes: list, base_embedding: torch.Tensor, cache_key: str):
|
53 |
+
text_features = torch.tensor(
|
54 |
+
[taxonomy_cache[cache_key][class_name] for class_name in classes]
|
55 |
+
)
|
56 |
+
|
57 |
+
similarities = cosine_similarity(base_embedding, text_features)
|
58 |
+
return similarities.cpu().numpy()
|
59 |
+
|
60 |
+
|
61 |
+
def map_taxonomy(
|
62 |
+
base_image: Image.Image,
|
63 |
+
taxonomy: dict,
|
64 |
+
model,
|
65 |
+
tokenizer,
|
66 |
+
preprocess_val,
|
67 |
+
cache_key,
|
68 |
+
beam_width: int = 3,
|
69 |
+
) -> tuple[list[tuple[str, float]], float]:
|
70 |
+
image_tensor = preprocess_val(base_image).unsqueeze(0)
|
71 |
+
with torch.no_grad(), torch.cuda.amp.autocast():
|
72 |
+
base_embedding = model.encode_image(image_tensor, normalize=True)
|
73 |
+
|
74 |
+
initial_path = BeamPath(path=[], cumulative_score=0.0, current_layer=taxonomy)
|
75 |
+
beam = [initial_path]
|
76 |
+
|
77 |
+
final_paths = []
|
78 |
+
is_first = True
|
79 |
+
while beam:
|
80 |
+
candidates = []
|
81 |
+
candidate_entries = []
|
82 |
+
|
83 |
+
for beam_path in beam:
|
84 |
+
layer = beam_path.current_layer
|
85 |
+
|
86 |
+
if isinstance(layer, dict):
|
87 |
+
classes = list(layer.keys())
|
88 |
+
elif isinstance(layer, list):
|
89 |
+
classes = layer
|
90 |
+
if classes == []:
|
91 |
+
final_paths.append(beam_path)
|
92 |
+
continue
|
93 |
+
else:
|
94 |
+
final_paths.append(beam_path)
|
95 |
+
continue
|
96 |
+
|
97 |
+
# current_path_class_names = [class_name for class_name, _ in beam_path.path]
|
98 |
+
|
99 |
+
for class_name in classes:
|
100 |
+
candidate_string = class_name
|
101 |
+
if isinstance(layer, dict):
|
102 |
+
next_layer = layer[class_name]
|
103 |
+
else:
|
104 |
+
next_layer = None
|
105 |
+
candidate_entries.append(
|
106 |
+
(candidate_string, class_name, beam_path, next_layer)
|
107 |
+
)
|
108 |
+
|
109 |
+
if not candidate_entries:
|
110 |
+
break
|
111 |
+
|
112 |
+
candidate_strings = [
|
113 |
+
candidate_string for candidate_string, _, _, _ in candidate_entries
|
114 |
+
]
|
115 |
+
|
116 |
+
similarities = _compute_similarities(
|
117 |
+
candidate_strings, base_embedding, cache_key
|
118 |
+
)
|
119 |
+
|
120 |
+
for (candidate_string, class_name, beam_path, next_layer), similarity in zip(
|
121 |
+
candidate_entries, similarities
|
122 |
+
):
|
123 |
+
new_path = beam_path.path + [(class_name, float(similarity))]
|
124 |
+
new_cumulative_score = beam_path.cumulative_score + similarity
|
125 |
+
candidate = BeamPath(
|
126 |
+
path=new_path,
|
127 |
+
cumulative_score=new_cumulative_score,
|
128 |
+
current_layer=next_layer,
|
129 |
+
)
|
130 |
+
candidates.append(candidate)
|
131 |
+
|
132 |
+
from collections import defaultdict
|
133 |
+
|
134 |
+
by_parents = defaultdict(list)
|
135 |
+
|
136 |
+
for candidate in candidates:
|
137 |
+
by_parents[candidate.path[0][0]].append(candidate)
|
138 |
+
|
139 |
+
beam = []
|
140 |
+
for parent in by_parents:
|
141 |
+
children = by_parents[parent]
|
142 |
+
children.sort(
|
143 |
+
key=lambda x: x.cumulative_score / len(x.path) + x.path[-1][1],
|
144 |
+
reverse=True,
|
145 |
+
)
|
146 |
+
if is_first:
|
147 |
+
beam.extend(children)
|
148 |
+
else:
|
149 |
+
beam.extend(children[:beam_width])
|
150 |
+
|
151 |
+
is_first = False
|
152 |
+
|
153 |
+
all_paths = beam + final_paths
|
154 |
+
|
155 |
+
if all_paths:
|
156 |
+
all_paths.sort(key=lambda x: x.cumulative_score / len(x.path), reverse=True)
|
157 |
+
best_path = all_paths[0]
|
158 |
+
return best_path.path, float(best_path.cumulative_score)
|
159 |
+
else:
|
160 |
+
return [], 0.0
|
161 |
+
|
162 |
+
|
163 |
+
# Function to classify image and map taxonomy
|
164 |
+
def classify_image(
|
165 |
+
image_input: Image.Image | None,
|
166 |
+
image_url: str | None,
|
167 |
+
model_size: str,
|
168 |
+
beam_width: int,
|
169 |
+
):
|
170 |
+
if image_input is not None:
|
171 |
+
image = image_input
|
172 |
+
elif image_url:
|
173 |
+
# Try to get image from URL
|
174 |
+
try:
|
175 |
+
response = requests.get(image_url)
|
176 |
+
image = Image.open(BytesIO(response.content)).convert("RGB")
|
177 |
+
except Exception as e:
|
178 |
+
return pd.DataFrame({"Error": [str(e)]})
|
179 |
+
else:
|
180 |
+
return pd.DataFrame(
|
181 |
+
{
|
182 |
+
"Error": [
|
183 |
+
"Please provide an image, an image URL, or select an example image"
|
184 |
+
]
|
185 |
+
}
|
186 |
+
)
|
187 |
+
|
188 |
+
# Select the model, tokenizer, and preprocess
|
189 |
+
if model_size == "marqo-ecommerce-embeddings-L":
|
190 |
+
key = "hf-hub:Marqo/marqo-ecommerce-embeddings-L"
|
191 |
+
model = model_L
|
192 |
+
preprocess_val = preprocess_L
|
193 |
+
tokenizer = tokenizer_L
|
194 |
+
elif model_size == "marqo-ecommerce-embeddings-B":
|
195 |
+
key = "hf-hub:Marqo/marqo-ecommerce-embeddings-B"
|
196 |
+
model = model_B
|
197 |
+
preprocess_val = preprocess_B
|
198 |
+
tokenizer = tokenizer_B
|
199 |
+
elif model_size == "openai-ViT-B-16":
|
200 |
+
key = "ViT-B-16"
|
201 |
+
model = model_base
|
202 |
+
preprocess_val = preprocess_base
|
203 |
+
tokenizer = tokenizer_base
|
204 |
+
else:
|
205 |
+
return pd.DataFrame({"Error": ["Invalid model size"]})
|
206 |
+
|
207 |
+
path, cumulative_score = map_taxonomy(
|
208 |
+
base_image=image,
|
209 |
+
taxonomy=AMAZON_TAXONOMY,
|
210 |
+
model=model,
|
211 |
+
tokenizer=tokenizer,
|
212 |
+
preprocess_val=preprocess_val,
|
213 |
+
cache_key=key,
|
214 |
+
beam_width=beam_width,
|
215 |
+
)
|
216 |
+
|
217 |
+
output = []
|
218 |
+
for idx, (category, score) in enumerate(path):
|
219 |
+
level = idx + 1
|
220 |
+
output.append({"Level": level, "Category": category, "Score": score})
|
221 |
+
|
222 |
+
df = pd.DataFrame(output)
|
223 |
+
return df
|
224 |
+
|
225 |
+
|
226 |
+
with gr.Blocks() as demo:
|
227 |
+
gr.Markdown("# Image Classification with Taxonomy Mapping")
|
228 |
+
gr.Markdown(
|
229 |
+
"## How to use this app\n\nThis app compares Marqo's E-commerce embeddings to OpenAI's ViT-B-16 CLIP model for E-commerce taxonomy mapping. A beam search is used to find the correct classification in the taxonomy. The original OpenAI CLIP models perform very poorly on E-commerce data."
|
230 |
+
)
|
231 |
+
gr.Markdown(
|
232 |
+
"Upload an image, provide an image URL, or select an example image, select the model size, and get the taxonomy mapping. The taxonomy is based on the Amazon product taxonomy."
|
233 |
+
)
|
234 |
+
|
235 |
+
with gr.Row():
|
236 |
+
with gr.Column():
|
237 |
+
image_input = gr.Image(type="pil", label="Upload Image", height=300)
|
238 |
+
image_url_input = gr.Textbox(
|
239 |
+
lines=1, placeholder="Image URL", label="Image URL"
|
240 |
+
)
|
241 |
+
gr.Markdown("### Or select an example image:")
|
242 |
+
# Get example images from 'images' folder
|
243 |
+
example_images_folder = "images"
|
244 |
+
example_image_paths = [
|
245 |
+
os.path.join(example_images_folder, img)
|
246 |
+
for img in os.listdir(example_images_folder)
|
247 |
+
]
|
248 |
+
gr.Examples(
|
249 |
+
examples=[[img_path] for img_path in example_image_paths],
|
250 |
+
inputs=image_input,
|
251 |
+
label="Example Images",
|
252 |
+
examples_per_page=100,
|
253 |
+
)
|
254 |
+
with gr.Column():
|
255 |
+
model_size_input = gr.Radio(
|
256 |
+
choices=[
|
257 |
+
"marqo-ecommerce-embeddings-L",
|
258 |
+
"marqo-ecommerce-embeddings-B",
|
259 |
+
"openai-ViT-B-16",
|
260 |
+
],
|
261 |
+
label="Model",
|
262 |
+
value="marqo-ecommerce-embeddings-L",
|
263 |
+
)
|
264 |
+
beam_width_input = gr.Number(
|
265 |
+
label="Beam Width", value=5, minimum=1, step=1
|
266 |
+
)
|
267 |
+
classify_button = gr.Button("Classify")
|
268 |
+
output_table = gr.Dataframe(headers=["Level", "Category", "Score"])
|
269 |
+
|
270 |
+
classify_button.click(
|
271 |
+
fn=classify_image,
|
272 |
+
inputs=[image_input, image_url_input, model_size_input, beam_width_input],
|
273 |
+
outputs=output_table,
|
274 |
+
)
|
275 |
+
|
276 |
+
demo.launch()
|
cache_taxonomy_vectors.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import torch
|
3 |
+
import open_clip
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
7 |
+
if device == "cpu":
|
8 |
+
device = "mps" if torch.backends.mps.is_available() else "cpu"
|
9 |
+
|
10 |
+
|
11 |
+
def generate_cache(texts: list[str], model_name: str, batch_size: int = 16) -> dict:
|
12 |
+
model, _, _ = open_clip.create_model_and_transforms(model_name, device=device)
|
13 |
+
tokenizer = open_clip.get_tokenizer(model_name)
|
14 |
+
|
15 |
+
cache = {}
|
16 |
+
|
17 |
+
for i in tqdm(range(0, len(texts), batch_size)):
|
18 |
+
batch = texts[i : i + batch_size]
|
19 |
+
tokens = tokenizer(batch).to(device)
|
20 |
+
with torch.no_grad(), torch.cuda.amp.autocast():
|
21 |
+
embeddings = model.encode_text(tokens, normalize=True).cpu().numpy()
|
22 |
+
for text, embedding in zip(batch, embeddings):
|
23 |
+
cache[text] = embedding.tolist()
|
24 |
+
|
25 |
+
return cache
|
26 |
+
|
27 |
+
|
28 |
+
def flatten_taxonomy(taxonomy: dict) -> list[str]:
|
29 |
+
classes = []
|
30 |
+
for key, value in taxonomy.items():
|
31 |
+
classes.append(key)
|
32 |
+
if isinstance(value, dict):
|
33 |
+
classes.extend(flatten_taxonomy(value))
|
34 |
+
if isinstance(value, list):
|
35 |
+
classes.extend(value)
|
36 |
+
return classes
|
37 |
+
|
38 |
+
|
39 |
+
def main():
|
40 |
+
models = [
|
41 |
+
"hf-hub:Marqo/marqo-ecommerce-embeddings-B",
|
42 |
+
"hf-hub:Marqo/marqo-ecommerce-embeddings-L",
|
43 |
+
"ViT-B-16"
|
44 |
+
]
|
45 |
+
|
46 |
+
with open("amazon.json") as f:
|
47 |
+
taxonomy = json.load(f)
|
48 |
+
print("Loaded taxonomy")
|
49 |
+
|
50 |
+
print("Flattening taxonomy")
|
51 |
+
texts = flatten_taxonomy(taxonomy)
|
52 |
+
|
53 |
+
print("Generating cache")
|
54 |
+
for model in models:
|
55 |
+
cache = generate_cache(texts, model)
|
56 |
+
with open(f'{model.split("/")[-1]}.json', "w+") as f:
|
57 |
+
json.dump(cache, f)
|
58 |
+
|
59 |
+
|
60 |
+
if __name__ == "__main__":
|
61 |
+
main()
|
images/bike-helmet.png
ADDED
images/coffee.png
ADDED
images/cooking-book.jpg
ADDED
images/cutting-board.png
ADDED
images/flip-flops.jpg
ADDED
images/grater.png
ADDED
images/green-shirt.webp
ADDED
images/hoop-earring.jpg
ADDED
images/iron.png
ADDED
images/laptop.png
ADDED
images/notebook.png
ADDED
images/red-dress.webp
ADDED
images/runners.png
ADDED
images/sleeping-bag.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
transformers
|
3 |
+
Pillow
|
4 |
+
gradio
|
5 |
+
ftfy
|
6 |
+
open_clip_torch
|