import gradio as gr
import open_clip
import torch
import requests
import numpy as np
from PIL import Image
from io import BytesIO
from items import ecommerce_items
import os
from dotenv import load_dotenv
from huggingface_hub import login

load_dotenv()
sidebar_markdown = """

Note: this demo can classify 200 items. If you can't find what you're looking for, reach out to us on our [Community](https://join.slack.com/t/marqo-community/shared_invite/zt-2iab0260n-QJrZLUSOJYUifVxf964Gdw) and request an item to be added.

## Documentation
📝 [Blog Post]()

📝 [Use Case Blog Post]()

## Code
💻 [GitHub Repo]()

🚀 [Google Colab]()

🚀 [Hugging Face Collection]()

## Citation
If you use Marqo-Ecommerce-L or Marqo-Ecommerce-B, please cite us:
```

```
"""
api_key = os.getenv("HF_API_TOKEN")
if api_key is None:
    raise ValueError("Hugging Face API key not found. Please set the 'HF_API_TOKEN' environment variable.")

login(token=api_key)
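# Model setup: load_model builds the Marqo-Ecommerce-B CLIP-style checkpoint
# with open_clip, tokenizes the item names from items.py, and pre-computes
# their L2-normalized text embeddings once at startup, so each request only
# has to encode the query image.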
def load_model(progress=gr.Progress()):
    progress(0, "Initializing model...")
    model_name = 'hf-hub:Marqo/marqo-ecommerce-embeddings-B'
    model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(model_name)

    progress(0.5, "Loading tokenizer...")
    tokenizer = open_clip.get_tokenizer(model_name)
    text = tokenizer(ecommerce_items)

    progress(0.75, "Encoding text features...")
    with torch.no_grad(), torch.amp.autocast('cuda'):
        text_features = model.encode_text(text)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    progress(1.0, "Model loaded successfully!")

    return model, preprocess_val, text_features


model, preprocess_val, text_features = load_model()
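# Zero-shot classification: predict embeds the query image (uploaded or
# fetched from a URL), takes cosine similarities against the pre-computed
# text embeddings, scales them by 100 (the usual CLIP logit scale), and
# softmaxes to get per-item probabilities; the top 10 are returned for the
# gr.Label output.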
def predict(image, url):
    if url:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))

    if image is None:
        raise gr.Error("Please upload an image or provide an image URL.")

    processed_image = preprocess_val(image).unsqueeze(0)

    with torch.no_grad(), torch.amp.autocast('cuda'):
        image_features = model.encode_image(processed_image)
        image_features /= image_features.norm(dim=-1, keepdim=True)

    text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)

    sorted_confidences = sorted(
        {ecommerce_items[i]: float(text_probs[0, i]) for i in range(len(ecommerce_items))}.items(),
        key=lambda x: x[1],
        reverse=True,
    )
    top_10_confidences = dict(sorted_confidences[:10])

    return image, top_10_confidences
def clear_fields():
    return None, ""


title = "Ecommerce Item Classifier with Marqo-Ecommerce Embedding Models"
description = "Upload an image of an ecommerce item, or provide an image URL, and classify it with the Marqo-Ecommerce embedding models!"
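# Bundled example images: each entry pairs an image path shipped with the
# demo with the item it depicts.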
examples = [
    ["images/laptop.png", "Laptop"],
    ["images/grater.png", "Grater"],
    ["images/flip-flops.jpg", "Flip Flops"],
    ["images/bike-helmet.png", "Bike Helmet"],
    ["images/sleeping-bag.png", "Sleeping Bag"],
    ["images/cutting-board.png", "Cutting Board"],
    ["images/iron.png", "Iron"],
    ["images/coffee.png", "Coffee"],
]
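# UI layout: a two-column gr.Blocks page. The left column holds the title,
# description, and documentation sidebar; the right column holds the image
# upload, URL box, Classify/Clear buttons, clickable examples, and the
# probability label. The custom CSS only enlarges the image component's
# clear (cross) button.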
with gr.Blocks(css="""
    .remove-btn {
        font-size: 24px !important;  /* Increase the font size of the cross button */
        line-height: 24px !important;
        width: 30px !important;      /* Increase the width */
        height: 30px !important;     /* Increase the height */
    }
""") as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(f"# {title}")
            gr.Markdown(description)
            gr.Markdown(sidebar_markdown)
            gr.Markdown(" ", elem_id="vertical-line")
        with gr.Column(scale=2):
            input_image = gr.Image(type="pil", label="Upload Ecommerce Item Image", height=312)
            input_url = gr.Textbox(label="Or provide an image URL")
            with gr.Row():
                predict_button = gr.Button("Classify")
                clear_button = gr.Button("Clear")
            gr.Markdown("Or click on one of the images below to classify it:")
            gr.Examples(examples=examples, inputs=input_image)
            output_label = gr.Label(num_top_classes=6)

    predict_button.click(predict, inputs=[input_image, input_url], outputs=[input_image, output_label])
    clear_button.click(clear_fields, outputs=[input_image, input_url])


demo.launch()
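# Note: demo.launch() serves the app locally; if you need a temporary public
# link while testing, Gradio also supports demo.launch(share=True).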