# Spaces: Running  (web-page header captured along with the source; not part of the program)
import html
import os
from io import BytesIO
from urllib.parse import urljoin

import numpy as np
import pandas as pd
import requests
import streamlit as st
import tensorflow as tf
from bs4 import BeautifulSoup
from PIL import Image


def download_model(model_url, model_path, timeout=60):
    """Download the model file to ``model_path`` unless it already exists.

    Args:
        model_url: URL the model file is fetched from.
        model_path: Local path the file is stored at.
        timeout: Seconds passed to ``requests.get`` as the request timeout.

    Raises:
        requests.RequestException: If the request fails or the server
            answers with an HTTP error status.
    """
    if os.path.exists(model_path):
        return

    response = requests.get(model_url, timeout=timeout)
    # Fail loudly instead of saving an HTTP error page as the model file.
    response.raise_for_status()

    # Write to a sibling file and rename it afterwards, so an interrupted
    # write never leaves a truncated file that the existence check above
    # would accept on the next run.
    partial_path = f"{model_path}.part"
    with open(partial_path, 'wb') as f:
        f.write(response.content)
    os.replace(partial_path, model_path)
def load_model(model_path):
    """Build a TFLite interpreter for ``model_path`` with tensors allocated.

    Args:
        model_path: Path of the ``.tflite`` file handed to
            ``tf.lite.Interpreter``.

    Returns:
        The interpreter, after ``allocate_tensors()`` has been called on it.
    """
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    return interpreter
def preprocess_image(image, input_size):
    """Turn an image into a float32 batch of one, scaled by 1/255.

    Args:
        image: Object exposing ``convert`` and ``resize``; the caller in this
            file passes the result of ``PIL.Image.open``.
        input_size: Edge length; the image is resized to
            ``(input_size, input_size)``.

    Returns:
        A float32 ``numpy.ndarray`` with a leading axis of length 1 added in
        front of the image array.
    """
    image = image.convert('RGB')
    image = image.resize((input_size, input_size))
    image_np = np.array(image, dtype=np.float32)
    image_np = np.expand_dims(image_np, axis=0)  # add the batch axis
    return image_np / 255.0  # Normalize to [0, 1]
def run_inference(interpreter, input_data):
    """Run the interpreter once and return its first output tensor.

    Args:
        interpreter: Object with the ``get_input_details`` /
            ``get_output_details`` / ``set_tensor`` / ``invoke`` /
            ``get_tensor`` methods (the value returned by ``load_model``
            in this file).
        input_data: Value written to the first input tensor.

    Returns:
        Whatever ``get_tensor`` yields for the first output tensor.
    """
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Only the first input and the first output tensor are used.
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    return interpreter.get_tensor(output_details[0]['index'])
def fetch_images_from_url(url, timeout=30):
    """Return absolute URLs for every ``<img src=...>`` found on a page.

    Args:
        url: Address of the HTML page to download.
        timeout: Seconds passed to ``requests.get`` as the request timeout.

    Returns:
        A list of image URLs in document order. Relative ``src`` values are
        resolved against the URL the response was finally served from;
        ``<img>`` tags with a missing or empty ``src`` are skipped.

    Raises:
        requests.RequestException: If the page cannot be requested.
    """
    response = requests.get(url, timeout=timeout)
    soup = BeautifulSoup(response.content, 'html.parser')

    # Use the post-redirect URL as the base so relative paths resolve the
    # same way a browser would resolve them.
    base_url = response.url

    img_urls = []
    for img in soup.find_all('img'):
        src = (img.get('src') or '').strip()
        if src:
            img_urls.append(urljoin(base_url, src))
    return img_urls
def render_intent_bars(labels, percentages):
    """Render one horizontal HTML bar per (label, percentage) pair.

    Args:
        labels: Iterable of label texts shown to the left of each bar.
        percentages: Iterable of numbers shown as ``NN.NN%`` and used as the
            bar width. Pairs are formed with ``zip``, so extra items in the
            longer iterable are ignored.
    """
    for label, percentage in zip(labels, percentages):
        # The markup is emitted with unsafe_allow_html=True, so the label has
        # to be escaped to be shown as text rather than parsed as HTML.
        safe_label = html.escape(str(label))
        # Keep the CSS width inside 0-100 %; the printed value stays as-is.
        bar_width = min(max(percentage, 0.0), 100.0)
        bar_html = (
            "\n"
            "<div style='display: flex; align-items: center;'>\n"
            f"<div style='width: 30%; text-align: right; padding-right: 10px;'>{safe_label}</div>\n"
            "<div style='width: 70%; display: flex; align-items: center;'>\n"
            f"<div style='background-color: #007BFF; height: 10px; width: {bar_width:.2f}%;'></div>\n"
            f"<div style='padding-left: 10px;'>{percentage:.2f}%</div>\n"
            "</div>\n"
            "</div>\n"
        )
        st.markdown(bar_html, unsafe_allow_html=True)
_IMAGE_REQUEST_TIMEOUT = 30  # seconds allowed for each image download


def _classify_image(img_url, interpreter, input_size, categories):
    """Download one image, classify it and render its result row."""
    response = requests.get(img_url, timeout=_IMAGE_REQUEST_TIMEOUT)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))

    input_data = preprocess_image(image, input_size)
    output_data_shopping_intent = run_inference(interpreter, input_data)
    shopping_intent_percentages = (output_data_shopping_intent.flatten() * 100).tolist()

    col1, col2 = st.columns([1, 3])
    with col1:
        st.image(image.resize((224, 224)), width=224)
    with col2:
        st.write(f"[URL]({img_url})")
        render_intent_bars(categories, shopping_intent_percentages)
    # NOTE(review): divider placed after both columns; the original nesting
    # level is not recoverable from the flattened source - confirm.
    st.write("---")


def _classify_page_images(url, model_path):
    """Fetch the image URLs of a page and classify each image in turn."""
    try:
        img_urls = fetch_images_from_url(url)
    except requests.RequestException as e:
        # A malformed or unreachable URL is user input, not a program error.
        st.error(f"Could not fetch {url}: {e}")
        return

    if not img_urls:
        st.warning("No images found on the page.")
        return

    st.write(f"Found {len(img_urls)} images")

    interpreter = load_model(model_path)
    input_details = interpreter.get_input_details()
    input_shape = input_details[0]['shape']
    input_size = int(input_shape[1])  # assuming square input
    categories = [
        "No Shopping Intent",
        "Apparel",
        "Home Decor",
        "Other",
    ]

    for img_url in img_urls:
        try:
            _classify_image(img_url, interpreter, input_size, categories)
        except Exception as e:
            # Best effort: one unreadable image must not stop the whole page.
            st.write(f"Could not process image {img_url}: {e}")


def main():
    """Streamlit entry point: ask for a page URL and classify its images.

    Downloads the model file if it is not present, reads a URL from a text
    input and, when one is given, shows the classification bars for every
    image found on that page.
    """
    st.set_page_config(layout="wide")
    st.title("Shopping Intent Classification - SEO by DEJAN")
    st.markdown(
        "\n"
        "Multi-label image classification model [extracted from Chrome]"
        "(https://dejanmarketing.com/product-image-optimisation-with-chromes-convolutional-neural-network/). "
        "The model can be deployed in an automated pipeline capable of classifying product images in bulk. "
        "Javascript-based website scraping currently unsupported.\n"
    )
    st.write("Enter a URL to fetch and classify all images on the page:")

    model_url = "https://huggingface.co/dejanseo/shopping-intent/resolve/main/model.tflite"
    model_path = "model.tflite"
    download_model(model_url, model_path)

    url = st.text_input("Enter URL")
    if url:
        _classify_page_images(url, model_path)

    # NOTE(review): footer rendered unconditionally; the original nesting
    # level is not recoverable from the flattened source - confirm.
    st.markdown(
        "\n"
        "Interested in using this in an automated pipeline for bulk image classification?\n"
        "Please [book an appointment](https://dejanmarketing.com/conference/) to discuss your needs.\n"
    )
# Run the app only when this file is executed as a script.
if __name__ == "__main__":
    main()