# Source: Hugging Face Space file view, author elliesleightholm,
# commit 86b24b1 ("updating for zero gpu"), 3.16 kB.
import spaces
import marqo
import requests
import io
from PIL import Image
import gradio as gr
import os
from dotenv import load_dotenv
# Load environment variables from a .env file (python-dotenv) before reading them.
load_dotenv()

# Marqo endpoint. Defaults to Marqo Cloud; set MARQO_URL to target another
# deployment instead, e.g. a local one at "http://localhost:8882".
MARQO_URL = os.getenv("MARQO_URL", "https://api.marqo.ai")
# API key from the environment; None when MARQO_API_KEY is unset.
api_key = os.getenv("MARQO_API_KEY")
mq = marqo.Client(MARQO_URL, api_key=api_key)
def _fetch_image(url, timeout):
    """Download and decode the image at ``url``; return None if that fails."""
    if not url:
        return None
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content))
        # Image.open is lazy; force decoding here so corrupt data fails now
        # rather than later inside Gradio.
        image.load()
    except (requests.RequestException, OSError):
        return None
    return image


def search_marqo(query, themes, negatives, limit=10, image_timeout=10) -> list:
    """Run a weighted search on the "marqo-ecommerce-b" index and fetch images.

    Args:
        query: Main search text (weight 1.0). If it is empty or blank, no
            search is run and an empty list is returned.
        themes: Optional "more of" text (weight 0.75); ignored when blank.
        negatives: Optional "less of" text (weight -1.1); ignored when blank.
        limit: Maximum number of hits requested from Marqo. The default of 10
            matches the gallery's "Top 10 Results" label.
        image_timeout: Seconds to wait for each product-image download.

    Returns:
        A list of ``(image, caption)`` tuples for ``gr.Gallery``. The caption
        holds the title, description, price and score. Hits whose image
        cannot be downloaded or decoded are skipped, so one bad URL does not
        abort the whole search.
    """
    # Gradio textboxes may hand over None; blank input counts as "not given".
    query = (query or "").strip()
    themes = (themes or "").strip()
    negatives = (negatives or "").strip()
    if not query:
        return []

    # Build query weights. A term that duplicates an earlier one is skipped so
    # it cannot overwrite that term's weight (e.g. the query's 1.0).
    query_weights = {query: 1.0}
    if themes and themes not in query_weights:
        query_weights[themes] = 0.75
    if negatives and negatives not in query_weights:
        query_weights[negatives] = -1.1

    res = mq.index("marqo-ecommerce-b").search(query_weights, limit=limit)

    products = []
    for hit in res['hits']:
        image = _fetch_image(hit.get('image_url'), image_timeout)
        if image is None:
            continue
        title = hit.get('title', 'No Title')
        description = hit.get('description', 'No Description')
        price = hit.get('price', 'N/A')
        score = hit['_score']
        product_info = f'{title}\n{description}\nPrice: {price}\nScore: {score:.4f}'
        products.append((image, product_info))
    return products
# Function to clear inputs and results
def clear_inputs():
    """Return blank values for the three search textboxes and the results gallery.

    Presumably meant as the handler for a "clear" action whose outputs are, in
    order: query, themes ("More of..."), negatives ("Less of...") and the
    gallery. Confirm the ordering against the caller.
    """
    # One empty string per Textbox, then an empty list for the Gallery.
    return "", "", "", []
# Gradio Blocks UI: a centered title and a short "what this demo uses" list,
# three query textboxes, a submit button and a results gallery.
with gr.Blocks(css=".orange-button { background-color: orange; color: black; }") as interface:
    header_markdown = (
        "<h1 style='text-align: center;'>Multimodal Ecommerce Search with Marqo's SOTA Embedding Models</h1>",
        "### This ecommerce search demo uses:",
        "### 1. [Marqo Cloud](https://www.marqo.ai/cloud) for the Search Engine.",
        "### 2. [Marqo-Ecommerce-Embeddings](https://huggingface.co/collections/Marqo/marqo-ecommerce-embeddings-66f611b9bb9d035a8d164fbb) for the multimodal embedding model.",
        "### 3. 100k products from the [Marqo-GS-10M](https://huggingface.co/datasets/Marqo/marqo-GS-10M) dataset.",
        "",  # empty Markdown block, used as a spacer
    )
    for markdown_text in header_markdown:
        gr.Markdown(markdown_text)

    with gr.Row():
        query_input = gr.Textbox(placeholder="Coffee machine", label="Search Query")
        themes_input = gr.Textbox(placeholder="Silver", label="More of...")
        negatives_input = gr.Textbox(placeholder="Buttons", label="Less of...")

    with gr.Row():
        search_button = gr.Button("Submit", elem_classes="orange-button")

    results_gallery = gr.Gallery(label="Top 10 Results", columns=4)

    # Clicking Submit or pressing Enter in any textbox runs the same search.
    search_inputs = [query_input, themes_input, negatives_input]
    event_triggers = (
        search_button.click,
        query_input.submit,
        themes_input.submit,
        negatives_input.submit,
    )
    for register in event_triggers:
        register(fn=search_marqo, inputs=search_inputs, outputs=results_gallery)

interface.launch()