File size: 6,048 Bytes
0f3bf2c
d09d79e
 
 
 
 
 
 
 
 
9467a9f
d09d79e
95052d1
 
 
d09d79e
 
 
 
 
 
fe26b6e
d09d79e
fe26b6e
 
 
d09d79e
 
fe26b6e
d09d79e
fe26b6e
d09d79e
fe26b6e
d09d79e
 
 
 
fe26b6e
 
 
 
 
 
 
 
d09d79e
 
 
fe26b6e
ab0c57f
fe26b6e
 
 
d09d79e
 
fe26b6e
d09d79e
 
 
 
 
 
 
 
fe26b6e
d09d79e
 
 
fe26b6e
 
 
 
d09d79e
 
0f3bf2c
fe26b6e
 
 
 
 
 
d09d79e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f3bf2c
d09d79e
 
 
 
 
702d942
d09d79e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
702d942
d09d79e
fe26b6e
 
 
 
 
d09d79e
 
 
 
 
 
fe26b6e
 
d09d79e
 
fe26b6e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import spaces
import gradio as gr
import open_clip
import torch
import requests
import numpy as np
from PIL import Image
from io import BytesIO
from items import ecommerce_items
import os
# from dotenv import load_dotenv

# Load environment variables from the .env file
# load_dotenv()

# Sidebar content: Markdown rendered in the left-hand column of the UI.
# The emoji are stored as real Unicode characters (the file must stay UTF-8);
# they were previously mis-decoded into sequences such as "πŸ“š".
sidebar_markdown = """

Note, this demo can classify 200 items. If you didn't find what you're looking for, reach out to us on our [Community](https://join.slack.com/t/marqo-community/shared_invite/zt-2iab0260n-QJrZLUSOJYUifVxf964Gdw) and request an item to be added.

## Documentation
📚 [Blog Post](https://www.marqo.ai/blog/introducing-marqos-ecommerce-embedding-models)

📝 [Classification Use Case Blog Post](https://www.marqo.ai/blog/ecommerce-image-classification-with-huggingface-transformers)

🔎 [Image Search Use Case Blog Post](https://www.marqo.ai/blog/how-to-build-an-ecommerce-image-search-application)

## Code
💻 [GitHub Repo](https://github.com/marqo-ai/marqo-ecommerce-embeddings)

🤝 [Google Colab](https://colab.research.google.com/drive/1ctqDrXs_P-RIOPc9xcUF83WLdYQ0wf-8?usp=sharing)

🤗 [Hugging Face Collection](https://huggingface.co/collections/Marqo/marqo-ecommerce-embeddings-66f611b9bb9d035a8d164fbb)

## Citation
If you use Marqo-Ecommerce-L or Marqo-Ecommerce-B, please cite us:
```
@software{zhu2024marqoecommembed_2024,
    author = {Tianyu Zhu and Jesse Clark},
    month = oct,
    title = {{Marqo Ecommerce Embeddings - Foundation Model for Product Embeddings}},
    url = {https://github.com/marqo-ai/marqo-ecommerce-embeddings/},
    version = {1.0.0},
    year = {2024}
    }
```
"""

# Function to initialize a model, preprocess, and text features
@spaces.GPU
def initialize_model(model_name, progress=gr.Progress()):
    """Load a Marqo e-commerce checkpoint and pre-compute the label embeddings.

    Args:
        model_name: Repository name under the ``Marqo`` Hugging Face
            organisation, e.g. ``"marqo-ecommerce-embeddings-L"``.
        progress: Gradio progress tracker used to report the loading stages.

    Returns:
        A ``(model, preprocess_val, text_features)`` tuple: the model switched
        to eval mode, the validation image transform returned by open_clip,
        and the encoded ``ecommerce_items`` labels normalised along their
        last dimension.
    """
    hub_id = f"hf-hub:Marqo/{model_name}"

    progress(0, f"Initializing model: {model_name}...")
    # The training transform is not needed for inference, so it is discarded.
    model, _, preprocess_val = open_clip.create_model_and_transforms(hub_id)
    # Modules come back in training mode; switch to eval so layers such as
    # dropout / stochastic depth behave deterministically during inference.
    model.eval()

    progress(0.5, "Loading tokenizer...")
    tokenizer = open_clip.get_tokenizer(hub_id)
    text = tokenizer(ecommerce_items)

    progress(0.75, "Encoding text features...")
    with torch.no_grad(), torch.amp.autocast('cuda'):
        text_features = model.encode_text(text)
        # Normalise in place so a later dot product acts as cosine similarity.
        text_features /= text_features.norm(dim=-1, keepdim=True)

    progress(1.0, f"Model {model_name} loaded successfully!")

    return model, preprocess_val, text_features

# Both checkpoints are loaded eagerly at import time so that requests only
# pay for inference. The L model is loaded first, then the B model; a single
# progress tracker is shared by the two calls.
progress_bar = gr.Progress()
(model_l, preprocess_val_l, text_features_l) = initialize_model(
    "marqo-ecommerce-embeddings-L",
    progress=progress_bar,
)
(model_b, preprocess_val_b, text_features_b) = initialize_model(
    "marqo-ecommerce-embeddings-B",
    progress=progress_bar,
)

# Upper bound (seconds) for downloading an image from a user-supplied URL.
_URL_FETCH_TIMEOUT_S = 10


# Prediction function
@spaces.GPU
def predict(image, url, model_name):
    """Classify an image against the ``ecommerce_items`` labels.

    Args:
        image: Value of the image input (a PIL image, or ``None`` when nothing
            was uploaded).
        url: Optional image URL. When non-empty it takes precedence over
            ``image`` and the picture is downloaded from it.
        model_name: ``"marqo-ecommerce-embeddings-B"`` selects the B model;
            any other value falls back to the L model.

    Returns:
        ``(image, confidences)`` where ``image`` is the picture that was
        classified and ``confidences`` maps the ten most likely labels to
        their softmax probability, highest first.

    Raises:
        gr.Error: If the URL cannot be fetched or decoded, or if neither an
            image nor a URL was provided.
    """
    if model_name == "marqo-ecommerce-embeddings-B":
        model, preprocess_val, text_features = model_b, preprocess_val_b, text_features_b
    else:
        model, preprocess_val, text_features = model_l, preprocess_val_l, text_features_l

    url = (url or "").strip()
    if url:
        # NOTE(review): the URL is user-supplied and fetched server-side;
        # consider restricting schemes/hosts if this runs in a private network.
        try:
            # A timeout prevents a stalled remote host from hanging the
            # worker; raise_for_status stops HTML error pages from being
            # handed to the image decoder.
            response = requests.get(url, timeout=_URL_FETCH_TIMEOUT_S)
            response.raise_for_status()
        except requests.RequestException as exc:
            raise gr.Error(f"Could not download the image: {exc}") from exc
        try:
            image = Image.open(BytesIO(response.content))
        except OSError as exc:  # includes PIL.UnidentifiedImageError
            raise gr.Error("The URL did not return a readable image.") from exc

    if image is None:
        raise gr.Error("Please upload an image or provide an image URL.")

    processed_image = preprocess_val(image).unsqueeze(0)

    with torch.no_grad(), torch.amp.autocast('cuda'):
        image_features = model.encode_image(processed_image)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # Scaled similarities between the image and every label embedding,
        # turned into a probability distribution over the labels.
        text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)

    # One bulk conversion to Python floats instead of one call per label.
    confidences = dict(zip(ecommerce_items, text_probs[0].tolist()))
    ranked = sorted(confidences.items(), key=lambda pair: pair[1], reverse=True)
    top_10_confidences = dict(ranked[:10])

    return image, top_10_confidences

# Clear function
# Not wrapped in @spaces.GPU: resetting widgets does no GPU work, so there is
# no reason to request a GPU slot for it.
def clear_fields():
    """Return reset values for the widgets wired to the "Clear" button.

    The button is connected to three outputs (image input, URL textbox and
    model selector), so exactly three values must be returned.

    Returns:
        ``(None, "", "marqo-ecommerce-embeddings-L")``: no image, an empty
        URL, and the model selector's default choice.
    """
    return None, "", "marqo-ecommerce-embeddings-L"

# Gradio interface: static text and the clickable example gallery.
title = (
    "Ecommerce Item Classifier "
    "with Marqo-Ecommerce Embedding Models"
)
description = (
    "Upload an image or provide a URL of an ecommerce item "
    "to classify it using Marqo-Ecommerce Models!"
)

# (file name inside the images/ folder, human-readable label)
_example_assets = (
    ("laptop.png", "Laptop"),
    ("grater.png", "Grater"),
    ("flip-flops.jpg", "Flip Flops"),
    ("bike-helmet.png", "Bike Helmet"),
    ("sleeping-bag.png", "Sleeping Bag"),
    ("cutting-board.png", "Cutting Board"),
    ("iron.png", "Iron"),
    ("coffee.png", "Coffee"),
)
# Each example row is [image path, label].
examples = [[f"images/{file_name}", label] for file_name, label in _example_assets]

# Stylesheet passed to gr.Blocks; it enlarges elements carrying the
# .remove-btn class (the "cross button", per the CSS comments).
_custom_css = """
    .remove-btn {
        font-size: 24px !important; /* Increase the font size of the cross button */
        line-height: 24px !important;
        width: 30px !important; /* Increase the width */
        height: 30px !important; /* Increase the height */
    }
"""

# Selectable checkpoints; the first entry is the dropdown's initial value.
_model_choices = ["marqo-ecommerce-embeddings-L", "marqo-ecommerce-embeddings-B"]

with gr.Blocks(css=_custom_css) as demo:
    with gr.Row():
        # Narrow left column: heading, description and the sidebar links.
        with gr.Column(scale=1):
            gr.Markdown(f"# {title}")
            gr.Markdown(description)
            gr.Markdown(sidebar_markdown)
            gr.Markdown(" ", elem_id="vertical-line")  # Add an empty Markdown with a custom ID

        # Wide right column: inputs, action buttons, examples and the result.
        with gr.Column(scale=2):
            input_image = gr.Image(
                type="pil",
                label="Upload Ecommerce Item Image",
                height=312,
            )
            input_url = gr.Textbox(label="Or provide an image URL")
            model_selector = gr.Dropdown(
                choices=_model_choices,
                value=_model_choices[0],
                label="Select Model",
            )
            with gr.Row():
                predict_button = gr.Button("Classify")
                clear_button = gr.Button("Clear")
            gr.Markdown("Or click on one of the images below to classify it:")
            gr.Examples(examples=examples, inputs=input_image)
            output_label = gr.Label(num_top_classes=6)

    # Event wiring: "Classify" runs the model and writes the classified image
    # back into the image input; "Clear" resets the input widgets.
    predict_button.click(
        predict,
        inputs=[input_image, input_url, model_selector],
        outputs=[input_image, output_label],
    )
    clear_button.click(
        clear_fields,
        outputs=[input_image, input_url, model_selector],
    )

# Launch the interface
demo.launch()