Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,048 Bytes
0f3bf2c d09d79e 9467a9f d09d79e 95052d1 d09d79e fe26b6e d09d79e fe26b6e d09d79e fe26b6e d09d79e fe26b6e d09d79e fe26b6e d09d79e fe26b6e d09d79e fe26b6e ab0c57f fe26b6e d09d79e fe26b6e d09d79e fe26b6e d09d79e fe26b6e d09d79e 0f3bf2c fe26b6e d09d79e 0f3bf2c d09d79e 702d942 d09d79e 702d942 d09d79e fe26b6e d09d79e fe26b6e d09d79e fe26b6e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import spaces
import gradio as gr
import open_clip
import torch
import requests
import numpy as np
from PIL import Image
from io import BytesIO
from items import ecommerce_items
import os
# from dotenv import load_dotenv
# Load environment variables from the .env file
# load_dotenv()
# Sidebar content (Markdown) rendered in the left column of the UI.
sidebar_markdown = """
Note: this demo can classify 200 items. If you don't find what you're looking for, reach out to us on our [Community](https://join.slack.com/t/marqo-community/shared_invite/zt-2iab0260n-QJrZLUSOJYUifVxf964Gdw) and request an item to be added.

## Documentation

- 📝 [Blog Post](https://www.marqo.ai/blog/introducing-marqos-ecommerce-embedding-models)
- 📝 [Classification Use Case Blog Post](https://www.marqo.ai/blog/ecommerce-image-classification-with-huggingface-transformers)
- 📝 [Image Search Use Case Blog Post](https://www.marqo.ai/blog/how-to-build-an-ecommerce-image-search-application)

## Code

- 💻 [GitHub Repo](https://github.com/marqo-ai/marqo-ecommerce-embeddings)
- 🤗 [Google Colab](https://colab.research.google.com/drive/1ctqDrXs_P-RIOPc9xcUF83WLdYQ0wf-8?usp=sharing)
- 🤗 [Hugging Face Collection](https://huggingface.co/collections/Marqo/marqo-ecommerce-embeddings-66f611b9bb9d035a8d164fbb)

## Citation

If you use Marqo-Ecommerce-L or Marqo-Ecommerce-B, please cite us:

```
@software{zhu2024marqoecommembed_2024,
  author = {Tianyu Zhu and Jesse Clark},
  month = oct,
  title = {{Marqo Ecommerce Embeddings - Foundation Model for Product Embeddings}},
  url = {https://github.com/marqo-ai/marqo-ecommerce-embeddings/},
  version = {1.0.0},
  year = {2024}
}
```
"""
# Function to initialize a model, preprocess, and text features
@spaces.GPU
def initialize_model(model_name, progress=gr.Progress()):
    """Load a Marqo ecommerce model and pre-encode every label in ``ecommerce_items``.

    Args:
        model_name: Model repository name under ``Marqo`` on the Hugging Face
            Hub, e.g. ``"marqo-ecommerce-embeddings-L"``.
        progress: Gradio progress tracker used to report each loading stage.

    Returns:
        ``(model, preprocess_val, text_features)``: the model switched to eval
        mode, its validation-time image transform, and the L2-normalised text
        embeddings of ``ecommerce_items`` (row ``i`` corresponds to
        ``ecommerce_items[i]``).
    """
    hub_id = f"hf-hub:Marqo/{model_name}"
    progress(0, f"Initializing model: {model_name}...")
    # The training-time transform is not needed for inference.
    model, _, preprocess_val = open_clip.create_model_and_transforms(hub_id)
    # open_clip hands back models in train mode; inference must run in eval
    # mode so layers such as dropout / stochastic depth (if present) are off.
    model.eval()
    progress(0.5, "Loading tokenizer...")
    tokenizer = open_clip.get_tokenizer(hub_id)
    text = tokenizer(ecommerce_items)
    progress(0.75, "Encoding text features...")
    with torch.no_grad(), torch.amp.autocast('cuda'):
        text_features = model.encode_text(text)
        # Normalise so the dot product in predict() is a cosine similarity.
        text_features /= text_features.norm(dim=-1, keepdim=True)
    progress(1.0, f"Model {model_name} loaded successfully!")
    return model, preprocess_val, text_features
# Eagerly load both models at import time: L first, then B. A single
# progress tracker is shared by the two initialisation calls.
progress_bar = gr.Progress()
model_l, preprocess_val_l, text_features_l = initialize_model(
    "marqo-ecommerce-embeddings-L",
    progress=progress_bar,
)
model_b, preprocess_val_b, text_features_b = initialize_model(
    "marqo-ecommerce-embeddings-B",
    progress=progress_bar,
)
# Prediction function
def _load_image(image, url, timeout=10):
    """Return the image to classify, preferring a non-blank ``url`` over ``image``.

    Args:
        image: Image from the upload widget, or ``None`` if nothing was uploaded.
        url: Text from the URL box; ``None``, empty and whitespace-only values
            are treated as "no URL".
        timeout: Seconds to wait for the HTTP request before giving up.

    Raises:
        gr.Error: if the URL cannot be fetched or decoded as an image, or if
            neither a URL nor an uploaded image was provided.
    """
    url = (url or "").strip()
    if url:
        try:
            # Without a timeout an unresponsive host would block the worker indefinitely.
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
            fetched = Image.open(BytesIO(response.content))
            # Decode now so corrupt/non-image payloads are reported here, not deep in preprocessing.
            fetched.load()
            return fetched
        except (requests.RequestException, OSError) as err:
            # gr.Error shows a readable message in the UI instead of a traceback.
            raise gr.Error(f"Could not load an image from the URL: {err}") from err
    if image is None:
        raise gr.Error("Please upload an image or provide an image URL.")
    return image


@spaces.GPU
def predict(image, url, model_name):
    """Classify an ecommerce item image against the labels in ``ecommerce_items``.

    Args:
        image: Image from the upload widget, or ``None``.
        url: Optional image URL; when non-blank it takes precedence over ``image``.
        model_name: Dropdown value; ``"marqo-ecommerce-embeddings-B"`` selects
            the B model, any other value uses the L model.

    Returns:
        ``(image, confidences)``: the image that was classified (so a
        URL-fetched image is echoed back into the upload widget) and a dict
        of the 10 most probable labels to their softmax probabilities,
        highest first.

    Raises:
        gr.Error: propagated from ``_load_image`` when no usable image is given.
    """
    if model_name == "marqo-ecommerce-embeddings-B":
        model, preprocess_val, text_features = model_b, preprocess_val_b, text_features_b
    else:
        model, preprocess_val, text_features = model_l, preprocess_val_l, text_features_l

    image = _load_image(image, url)
    processed_image = preprocess_val(image).unsqueeze(0)
    with torch.no_grad(), torch.amp.autocast('cuda'):
        image_features = model.encode_image(processed_image)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        # Scaled similarities against every label, turned into a distribution.
        text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)

    # Batch of one image -> row 0 holds one probability per label.
    confidences = dict(zip(ecommerce_items, text_probs[0].tolist()))
    ranked = sorted(confidences.items(), key=lambda pair: pair[1], reverse=True)
    return image, dict(ranked[:10])
# Clear function
def clear_fields():
    """Reset the widgets wired as outputs of the Clear button.

    Returns exactly one value per output of ``clear_button.click``, in order:
    ``input_image`` -> ``None`` (cleared), ``input_url`` -> ``""`` (emptied),
    ``model_selector`` -> the L model, which must match the Dropdown's
    default ``value``.

    No GPU decorator: this does no model work, so it has no reason to wait
    for a GPU slot.
    """
    return None, "", "marqo-ecommerce-embeddings-L"
# Gradio interface: static text shown in the sidebar.
title = "Ecommerce Item Classifier with Marqo-Ecommerce Embedding Models"
description = (
    "Upload an image or provide a URL of an ecommerce item to classify it "
    "using Marqo-Ecommerce Models!"
)

# (file name under images/, display label) for each bundled example.
_EXAMPLE_FILES = (
    ("laptop.png", "Laptop"),
    ("grater.png", "Grater"),
    ("flip-flops.jpg", "Flip Flops"),
    ("bike-helmet.png", "Bike Helmet"),
    ("sleeping-bag.png", "Sleeping Bag"),
    ("cutting-board.png", "Cutting Board"),
    ("iron.png", "Iron"),
    ("coffee.png", "Coffee"),
)
# Each row is a two-element list: [relative image path, label].
examples = [[f"images/{file_name}", label] for file_name, label in _EXAMPLE_FILES]
# Build the UI: a left column with static information and a right column
# with the inputs, action buttons, examples and the result label.
# The CSS enlarges elements with the `remove-btn` class; per its own comments
# this is meant for the "cross" (remove) button — NOTE(review): confirm the
# class name matches what Gradio actually renders.
with gr.Blocks(css="""
.remove-btn {
font-size: 24px !important; /* Increase the font size of the cross button */
line-height: 24px !important;
width: 30px !important; /* Increase the width */
height: 30px !important; /* Increase the height */
}
""") as demo:
    with gr.Row():
        # Left column: title, description and the sidebar links/citation.
        with gr.Column(scale=1):
            gr.Markdown(f"# {title}")
            gr.Markdown(description)
            gr.Markdown(sidebar_markdown)
        # Empty Markdown carrying elem_id="vertical-line", presumably meant as a
        # divider between the columns. No #vertical-line rule is defined in the
        # CSS above — NOTE(review): confirm it is styled somewhere, else remove it.
        gr.Markdown(" ", elem_id="vertical-line")
        # Right column: inputs, actions, examples and the prediction output.
        with gr.Column(scale=2):
            input_image = gr.Image(type="pil", label="Upload Ecommerce Item Image", height=312)
            input_url = gr.Textbox(label="Or provide an image URL")
            model_selector = gr.Dropdown(
                choices=["marqo-ecommerce-embeddings-L", "marqo-ecommerce-embeddings-B"],
                value="marqo-ecommerce-embeddings-L",
                label="Select Model"
            )
            with gr.Row():
                predict_button = gr.Button("Classify")
                clear_button = gr.Button("Clear")
            gr.Markdown("Or click on one of the images below to classify it:")
            # Example rows are [path, label], but only input_image is bound as an
            # input, so only the path fills a component; the label column is
            # presumably ignored by gr.Examples — confirm against the Gradio version in use.
            gr.Examples(examples=examples, inputs=input_image)
            # Shows the top 6 classes, although predict() returns 10 entries.
            output_label = gr.Label(num_top_classes=6)
    # Event wiring. predict returns (image, confidences), so the image widget is
    # refreshed with the classified image (useful when it came from a URL).
    predict_button.click(predict, inputs=[input_image, input_url, model_selector], outputs=[input_image, output_label])
    # clear_fields must return exactly one value per output listed here (3).
    # output_label is not among the outputs, so a previous result stays visible.
    clear_button.click(clear_fields, outputs=[input_image, input_url, model_selector])
# Launch the interface only when this file is executed as a script, so that
# importing the module (e.g. from tests or tooling) does not start a server.
if __name__ == "__main__":
    demo.launch()