Spaces:
Running
on
Zero
Running
on
Zero
import spaces | |
import gradio as gr | |
import open_clip | |
import torch | |
import requests | |
import numpy as np | |
from PIL import Image | |
from io import BytesIO | |
from items import ecommerce_items | |
import os | |
# from dotenv import load_dotenv | |
# Load environment variables from the .env file | |
# load_dotenv() | |
# Sidebar content | |
# Markdown rendered in the left-hand column of the demo: usage note, links to
# documentation and code, and the BibTeX entry for citing the models.
sidebar_markdown = """
Note, this demo can classify 200 items. If you didn't find what you're looking for, reach out to us on our [Community](https://join.slack.com/t/marqo-community/shared_invite/zt-2iab0260n-QJrZLUSOJYUifVxf964Gdw) and request an item to be added.
## Documentation
📚 [Blog Post](https://www.marqo.ai/blog/introducing-marqos-ecommerce-embedding-models)
📚 [Classification Use Case Blog Post](https://www.marqo.ai/blog/ecommerce-image-classification-with-huggingface-transformers)
📚 [Image Search Use Case Blog Post](https://www.marqo.ai/blog/how-to-build-an-ecommerce-image-search-application)
## Code
💻 [GitHub Repo](https://github.com/marqo-ai/marqo-ecommerce-embeddings)
🤗 [Google Colab](https://colab.research.google.com/drive/1ctqDrXs_P-RIOPc9xcUF83WLdYQ0wf-8?usp=sharing)
🤗 [Hugging Face Collection](https://huggingface.co/collections/Marqo/marqo-ecommerce-embeddings-66f611b9bb9d035a8d164fbb)
## Citation
If you use Marqo-Ecommerce-L or Marqo-Ecommerce-B, please cite us:
```
@software{zhu2024marqoecommembed_2024,
author = {Tianyu Zhu and Jesse Clark},
month = oct,
title = {{Marqo Ecommerce Embeddings - Foundation Model for Product Embeddings}},
url = {https://github.com/marqo-ai/marqo-ecommerce-embeddings/},
version = {1.0.0},
year = {2024}
}
```
"""
# Function to initialize a model, preprocess, and text features | |
def initialize_model(model_name, progress=gr.Progress()):
    """Load a Marqo ecommerce model and pre-compute its label text features.

    Args:
        model_name: Repository name under the ``Marqo`` organisation on the
            Hugging Face hub, e.g. ``"marqo-ecommerce-embeddings-L"``.
        progress: Gradio progress tracker that is called with the loading
            stage. The ``gr.Progress()`` default follows Gradio's convention
            for declaring a progress parameter and is not mutated here.

    Returns:
        A ``(model, preprocess_val, text_features)`` tuple: the loaded model,
        its validation-time image transform, and the encoded
        ``ecommerce_items`` labels normalised to unit length along the last
        dimension.
    """
    hub_id = f"hf-hub:Marqo/{model_name}"

    progress(0, f"Initializing model: {model_name}...")
    # Only the validation transform is used for inference; the training
    # transform returned in the middle slot is discarded.
    model, _, preprocess_val = open_clip.create_model_and_transforms(hub_id)
    # The model is only ever used for inference, so put it in eval mode to
    # disable any train-time behaviour (dropout, stochastic depth, BatchNorm
    # statistics updates).
    model.eval()

    progress(0.5, "Loading tokenizer...")
    tokenizer = open_clip.get_tokenizer(hub_id)
    text = tokenizer(ecommerce_items)

    progress(0.75, "Encoding text features...")
    with torch.no_grad(), torch.amp.autocast('cuda'):
        text_features = model.encode_text(text)
        # Normalise so that a dot product with a normalised image embedding
        # is a cosine similarity.
        text_features /= text_features.norm(dim=-1, keepdim=True)

    progress(1.0, f"Model {model_name} loaded successfully!")
    return model, preprocess_val, text_features
# Load L model first, followed by B model | |
# Both checkpoints are loaded eagerly at import time, sharing one progress
# tracker, so that predict() can switch between them without reloading.
progress_bar = gr.Progress()
model_l, preprocess_val_l, text_features_l = initialize_model(
    "marqo-ecommerce-embeddings-L",
    progress=progress_bar,
)
model_b, preprocess_val_b, text_features_b = initialize_model(
    "marqo-ecommerce-embeddings-B",
    progress=progress_bar,
)
# Prediction function | |
def predict(image, url, model_name):
    """Classify an ecommerce image against the labels in ``ecommerce_items``.

    Args:
        image: PIL image from the upload widget, or ``None`` when nothing was
            uploaded. Ignored when ``url`` is non-empty.
        url: Optional image URL. When given, the image is downloaded and used
            instead of ``image``.
        model_name: ``"marqo-ecommerce-embeddings-B"`` selects the B model;
            any other value selects the L model.

    Returns:
        A ``(image, top_10_confidences)`` tuple: the image that was actually
        classified and a dict mapping the ten highest-scoring labels to their
        softmax probabilities, in descending order.

    Raises:
        gr.Error: If no image and no URL were supplied, or if the URL could
            not be downloaded or decoded as an image.
    """
    if model_name == "marqo-ecommerce-embeddings-B":
        model, preprocess_val, text_features = model_b, preprocess_val_b, text_features_b
    else:
        model, preprocess_val, text_features = model_l, preprocess_val_l, text_features_l

    url = (url or "").strip()
    if url:
        # NOTE(review): the URL is user-controlled and fetched server-side;
        # consider restricting schemes/hosts if this runs on a private network.
        try:
            # A timeout prevents a stalled remote host from hanging the worker.
            response = requests.get(url, timeout=10)
            response.raise_for_status()
        except requests.RequestException as err:
            raise gr.Error(f"Could not download the image: {err}") from err
        try:
            image = Image.open(BytesIO(response.content))
            # Image.open is lazy; force decoding so corrupt data fails here.
            image.load()
        except OSError as err:
            raise gr.Error("The URL did not return a readable image.") from err

    if image is None:
        raise gr.Error("Please upload an image or provide an image URL.")

    processed_image = preprocess_val(image).unsqueeze(0)
    with torch.no_grad(), torch.amp.autocast('cuda'):
        image_features = model.encode_image(processed_image)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        # Scaled cosine similarities turned into a distribution over labels.
        text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)

    confidences = {
        label: float(prob) for label, prob in zip(ecommerce_items, text_probs[0])
    }
    ranked = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
    top_10_confidences = dict(ranked[:10])
    return image, top_10_confidences
# Clear function | |
def clear_fields():
    """Return the empty values used to reset the inputs: no image, blank URL."""
    cleared_image = None
    cleared_url = ""
    return cleared_image, cleared_url
# Gradio interface | |
# Heading and tagline shown at the top of the page.
title = "Ecommerce Item Classifier with Marqo-Ecommerce Embedding Models"
description = (
    "Upload an image or provide a URL of an ecommerce item to classify it "
    "using Marqo-Ecommerce Models!"
)

# (file name inside images/, human-readable label) for each clickable example.
_EXAMPLE_FILES = (
    ("laptop.png", "Laptop"),
    ("grater.png", "Grater"),
    ("flip-flops.jpg", "Flip Flops"),
    ("bike-helmet.png", "Bike Helmet"),
    ("sleeping-bag.png", "Sleeping Bag"),
    ("cutting-board.png", "Cutting Board"),
    ("iron.png", "Iron"),
    ("coffee.png", "Coffee"),
)
examples = [[f"images/{file_name}", label] for file_name, label in _EXAMPLE_FILES]
# Two-column layout: project information on the left, classifier controls
# and results on the right.
with gr.Blocks(css="""
.remove-btn {
    font-size: 24px !important; /* Increase the font size of the cross button */
    line-height: 24px !important;
    width: 30px !important; /* Increase the width */
    height: 30px !important; /* Increase the height */
}
""") as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(f"# {title}")
            gr.Markdown(description)
            gr.Markdown(sidebar_markdown)
            gr.Markdown(" ", elem_id="vertical-line")  # Add an empty Markdown with a custom ID
        with gr.Column(scale=2):
            input_image = gr.Image(type="pil", label="Upload Ecommerce Item Image", height=312)
            input_url = gr.Textbox(label="Or provide an image URL")
            model_selector = gr.Dropdown(
                choices=["marqo-ecommerce-embeddings-L", "marqo-ecommerce-embeddings-B"],
                value="marqo-ecommerce-embeddings-L",
                label="Select Model"
            )
            with gr.Row():
                predict_button = gr.Button("Classify")
                clear_button = gr.Button("Clear")
            gr.Markdown("Or click on one of the images below to classify it:")
            gr.Examples(examples=examples, inputs=input_image)
            output_label = gr.Label(num_top_classes=6)
    # predict returns the classified image as well, so an image fetched from
    # the URL is shown in the image widget.
    predict_button.click(predict, inputs=[input_image, input_url, model_selector], outputs=[input_image, output_label])
    # clear_fields returns exactly two values (image, URL text); the outputs
    # list must have the same length or Gradio rejects the handler result.
    clear_button.click(clear_fields, outputs=[input_image, input_url])
# Launch the interface
demo.launch()