elliesleightholm's picture
updating text
702d942
raw
history blame
4.7 kB
import gradio as gr
import open_clip
import torch
import requests
import numpy as np
from PIL import Image
from io import BytesIO
from items import ecommerce_items
import os
from dotenv import load_dotenv
# Pull settings such as HF_API_TOKEN from a local .env file into os.environ.
# Variables already set in the real environment keep their values.
load_dotenv(override=False)
# Sidebar content: Markdown shown in the left-hand column of the UI.
# The emoji were previously mis-encoded (UTF-8 bytes decoded with the wrong
# codec); they are restored here. Blank lines separate the link lines so
# Markdown renders each on its own line instead of one run-on paragraph.
# NOTE(review): the link targets `()` and the citation code block are still
# empty placeholders - fill them in before publishing.
sidebar_markdown = """
Note, this demo can classify 200 items. If you didn't find what you're looking for, reach out to us on our [Community](https://join.slack.com/t/marqo-community/shared_invite/zt-2iab0260n-QJrZLUSOJYUifVxf964Gdw) and request an item to be added.

## Documentation

📚 [Blog Post]()

📝 [Use Case Blog Post]()

## Code

💻 [GitHub Repo]()

🤝 [Google Colab]()

🤗 [Hugging Face Collection]()

## Citation

If you use Marqo-Ecommerce-L or Marqo-Ecommerce-B, please cite us:

```
```
"""
from huggingface_hub import login

# Authenticate with the Hugging Face Hub before the model is downloaded.
# The token is read from the HF_API_TOKEN environment variable.
api_key = os.getenv("HF_API_TOKEN")
if not api_key:
    # An empty string (e.g. `HF_API_TOKEN=` in .env) is not a usable token;
    # reject it here with the same message as an unset variable instead of
    # passing it on to login().
    raise ValueError("Hugging Face API key not found. Please set the 'HF_API_TOKEN' environment variable.")
# Log in with the token.
login(token=api_key)
# Initialize the model and tokenizer
def load_model(progress=gr.Progress()):
    """Load Marqo-Ecommerce-B and pre-encode every label in ``ecommerce_items``.

    Args:
        progress: Gradio progress tracker used to report loading stages.

    Returns:
        ``(model, preprocess_val, text_features)``: the open_clip model in
        eval mode, its evaluation image transform, and the L2-normalised text
        embeddings of ``ecommerce_items``. Rows follow the order of
        ``ecommerce_items``; the shape is presumably
        ``(len(ecommerce_items), embed_dim)``.
    """
    progress(0, "Initializing model...")
    model_name = 'hf-hub:Marqo/marqo-ecommerce-embeddings-B'
    # The training transform is not needed for inference.
    model, _, preprocess_val = open_clip.create_model_and_transforms(model_name)
    # open_clip creates models in training mode; switch to eval so train-only
    # behaviour (e.g. dropout) is disabled and embeddings are deterministic.
    model.eval()

    progress(0.5, "Loading tokenizer...")
    tokenizer = open_clip.get_tokenizer(model_name)
    text = tokenizer(ecommerce_items)

    progress(0.75, "Encoding text features...")
    # The model stays on the CPU here. CUDA autocast therefore only matters
    # when a GPU exists, and enabling it without CUDA just emits a warning.
    with torch.no_grad(), torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
        text_features = model.encode_text(text)
        # Normalise each label embedding so the dot products in predict()
        # are cosine similarities.
        text_features /= text_features.norm(dim=-1, keepdim=True)

    progress(1.0, "Model loaded successfully!")
    return model, preprocess_val, text_features


# Load model and prepare interface
model, preprocess_val, text_features = load_model()
# Prediction function
def _fetch_image_from_url(url, timeout=10):
    """Download *url* and open the payload as a PIL image.

    Raises:
        gr.Error: if the request fails or times out, the server returns an
            HTTP error status, or the payload is not a recognisable image.
    """
    try:
        response = requests.get(url, timeout=timeout)
        # Without this, an HTML error page would reach Image.open and fail
        # with a confusing "cannot identify image" error.
        response.raise_for_status()
    except requests.RequestException as err:
        raise gr.Error(f"Could not download image from URL: {err}") from err
    try:
        return Image.open(BytesIO(response.content))
    except OSError as err:  # PIL's UnidentifiedImageError subclasses OSError
        raise gr.Error("The URL did not point to a readable image.") from err


def predict(image, url):
    """Classify an ecommerce item image against the pre-encoded item labels.

    Args:
        image: PIL image from the upload component, or ``None``.
        url: Optional image URL. When it is non-blank it takes precedence
            over *image*.

    Returns:
        ``(image, confidences)``: the image that was classified (so a
        URL-fetched image is shown in the UI) and a dict mapping up to the
        10 highest-scoring item names to their softmax probabilities, in
        descending order.

    Raises:
        gr.Error: if no image is supplied or the URL cannot be loaded.
    """
    url = (url or "").strip()
    if url:
        image = _fetch_image_from_url(url)
    if image is None:
        raise gr.Error("Please upload an image or provide an image URL.")

    processed_image = preprocess_val(image).unsqueeze(0)
    # The model stays on the CPU; CUDA autocast is enabled only when CUDA
    # exists, which avoids the "CUDA is not available" warning on every call.
    with torch.no_grad(), torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
        image_features = model.encode_image(processed_image)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        # Cosine similarities are scaled by a fixed factor of 100 (not read
        # from the model) before the softmax over all labels.
        text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)

    # topk returns values sorted in descending order, replacing the full
    # dict-then-sort over every label.
    top_k = min(10, len(ecommerce_items))
    top_probs, top_indices = text_probs[0].topk(top_k)
    top_10_confidences = {
        ecommerce_items[index]: float(prob)
        for prob, index in zip(top_probs.tolist(), top_indices.tolist())
    }
    return image, top_10_confidences
# Clear function
def clear_fields():
    """Reset the inputs: no uploaded image and an empty URL box.

    Returns:
        tuple: ``(None, "")``, the new values for the image and URL inputs.
    """
    blank_image, blank_url = None, ""
    return blank_image, blank_url
# Gradio interface text shown at the top of the sidebar.
title = "Ecommerce Item Classifier with Marqo-Ecommerce Embedding Models"
description = "Upload an image or provide a URL of an ecommerce item to classify it using Marqo-Ecommerce Models!"

# (file name inside images/, human-readable label) for each example tile.
# NOTE(review): only the image input is bound to gr.Examples, so the label
# column is presumably informational - confirm it is meant to be ignored.
_EXAMPLE_IMAGES = (
    ("laptop.png", "Laptop"),
    ("grater.png", "Grater"),
    ("flip-flops.jpg", "Flip Flops"),
    ("bike-helmet.png", "Bike Helmet"),
    ("sleeping-bag.png", "Sleeping Bag"),
    ("cutting-board.png", "Cutting Board"),
    ("iron.png", "Iron"),
    ("coffee.png", "Coffee"),
)
examples = [[f"images/{file_name}", label] for file_name, label in _EXAMPLE_IMAGES]
# Extra page CSS: enlarges elements that carry the "remove-btn" class.
custom_css = """
.remove-btn {
font-size: 24px !important; /* Increase the font size of the cross button */
line-height: 24px !important;
width: 30px !important; /* Increase the width */
height: 30px !important; /* Increase the height */
}
"""

with gr.Blocks(css=custom_css) as demo:
    with gr.Row():
        # Left column: title, description and the sidebar links.
        with gr.Column(scale=1):
            gr.Markdown(f"# {title}")
            gr.Markdown(description)
            gr.Markdown(sidebar_markdown)

        # Empty spacer between the columns, addressable via its elem_id.
        gr.Markdown(" ", elem_id="vertical-line")

        # Right column: inputs, action buttons, examples and the result.
        with gr.Column(scale=2):
            input_image = gr.Image(type="pil", label="Upload Ecommerce Item Image", height=312)
            input_url = gr.Textbox(label="Or provide an image URL")

            with gr.Row():
                predict_button = gr.Button("Classify")
                clear_button = gr.Button("Clear")

            gr.Markdown("Or click on one of the images below to classify it:")
            gr.Examples(examples=examples, inputs=input_image)

            # NOTE(review): predict() returns up to 10 classes but only 6 are
            # displayed - confirm which count is intended.
            output_label = gr.Label(num_top_classes=6)

    # Event wiring: Classify runs predict() on both inputs and writes back
    # the (possibly URL-fetched) image plus the label scores.
    predict_button.click(
        fn=predict,
        inputs=[input_image, input_url],
        outputs=[input_image, output_label],
    )
    # Clear resets the image and URL inputs.
    clear_button.click(fn=clear_fields, outputs=[input_image, input_url])

# Launch the interface
demo.launch()