# image_recommender_2 / handler.py
# Last commit: c704ba4 "Update handler.py" by clfegg (verified)
import os
# Use the current working directory as the Hugging Face cache root (HF_HOME).
# This runs before the transformers / sentence_transformers imports below,
# presumably because those libraries read HF_HOME when imported — confirm.
current_dir = os.getcwd()
# NOTE(review): os.path.join with a single argument just returns current_dir.
os.environ['HF_HOME'] = os.path.join(current_dir)
import base64
import contextlib
import tempfile
from typing import Dict, Any, List

from sentence_transformers import SentenceTransformer, util
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from PIL import Image
from serpapi import GoogleSearch
from keybert import KeyBERT
import torch
# Models are loaded once at import time and shared by every request.
# moondream2 vision-language model: ProductSearcher uses it to encode the
# uploaded image and answer a captioning prompt about it. The revision is
# pinned, presumably so the remote code enabled by trust_remote_code=True
# cannot change underneath the handler — confirm.
model_id = "vikhyatk/moondream2"
revision = "2024-08-26"
model = AutoModelForCausalLM.from_pretrained(
model_id, trust_remote_code=True, revision=revision
)
# Moves the model to the GPU; there is no CPU fallback, so a CUDA device is required.
model.to('cuda')
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
# Multilingual sentence-embedding model; ProductSearcher uses it to match the
# user's message against its predefined (Vietnamese) intent questions.
model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
sentence_model = SentenceTransformer(model_name, device='cuda')
class ProductSearcher:
    """Find shopping results for the product shown in an image.

    Pipeline (see ``run``):
      1. pick a captioning prompt by matching the user's message against
         ``PREDEFINED_QUESTIONS``;
      2. ask the module-level moondream ``model`` to describe the image with
         that prompt;
      3. extract keyphrases from the description with KeyBERT;
      4. query Google Shopping through SerpApi with those keyphrases.

    Relies on the module-level ``model``, ``tokenizer`` and ``sentence_model``.
    """

    # Canonical user intents (Vietnamese): "I want to buy this product",
    # "I want information about the product", "I want to know the price of this".
    PREDEFINED_QUESTIONS = (
        "tôi muốn mua sản phẩm này",
        "tôi muốn thông tin về sản phẩm",
        "tôi muốn biết giá cái này",
    )
    # Captioning prompt for each intent: PROMPTS[i] is used when the message is
    # closest to PREDEFINED_QUESTIONS[i], so both tuples must have equal length.
    # NOTE(review): the "Descibe"/"it color" typos are part of the live prompt
    # and were left unchanged.
    PROMPTS = (
        "Descibe product in image with it color. Only answer in one sentence",
        "Describe the product in detail and provide information about the product. If you don't know the product, you can describe the image",
        "Estimate the price of the product and provide a detailed description of the product",
    )

    # KeyBERT loads an embedding model when constructed; share a single
    # instance instead of reloading it for every request.
    _shared_kw_model = None

    def __init__(self, user_input, image_path):
        """Prepare a search for one request.

        Args:
            user_input: the user's free-text message, used to pick a prompt.
            image_path: path of the image file to describe.
        """
        self.user_input = user_input
        self.image_path = image_path
        # Per-instance list copies so callers may still mutate these attributes.
        self.predefined_questions = list(self.PREDEFINED_QUESTIONS)
        self.prompts = list(self.PROMPTS)
        self.description = ''
        self.keyphrases = []
        if ProductSearcher._shared_kw_model is None:
            ProductSearcher._shared_kw_model = KeyBERT()
        self.kw_model = ProductSearcher._shared_kw_model

    def get_most_similar_sentence(self):
        """Return the prompt whose predefined question best matches ``user_input``.

        Similarity is the cosine similarity of ``sentence_model`` embeddings.
        """
        user_input_embedding = sentence_model.encode(self.user_input)
        predefined_embeddings = sentence_model.encode(self.predefined_questions)
        similarity_scores = util.pytorch_cos_sim(user_input_embedding, predefined_embeddings)
        most_similar_index = similarity_scores.argmax().item()
        return self.prompts[most_similar_index]

    def generate_description(self):
        """Describe the image at ``image_path``; store the answer in ``self.description``."""
        prompt = self.get_most_similar_sentence()
        # Close the image file once it has been encoded instead of leaking the handle.
        with Image.open(self.image_path) as image:
            enc_image = model.encode_image(image)
        self.description = model.answer_question(enc_image, prompt, tokenizer)
        # Drop the encoded image as soon as it is no longer needed.
        del enc_image

    def extract_keyphrases(self):
        """Extract keyphrases from ``self.description`` into ``self.keyphrases``."""
        self.keyphrases = self.kw_model.extract_keywords(self.description)

    def search_products(self, k=3):
        """Query Google Shopping (via SerpApi) with the extracted keyphrases.

        Args:
            k: maximum number of shopping results to return.

        Returns:
            Up to ``k`` entries of SerpApi's ``shopping_results`` list; an empty
            list when no usable keyphrase is available.

        Raises:
            KeyError: if the ``API_KEY`` environment variable is not set.
        """
        # Each keyphrase entry's first element is the phrase text. "image" is
        # skipped, presumably because captions mention the image itself.
        terms = [keyword[0] for keyword in self.keyphrases if keyword[0] != 'image']
        if not terms:
            # Nothing to search for: don't spend a SerpApi call on an empty query.
            return []
        search = GoogleSearch({
            "engine": "google",
            "q": " ".join(terms),
            "tbm": "shop",
            "api_key": os.environ["API_KEY"],
        })
        results = search.get_dict()
        return results.get('shopping_results', [])[:k]

    def run(self, k=3):
        """Describe the image, extract keyphrases and return up to ``k`` products."""
        self.generate_description()
        self.extract_keyphrases()
        return self.search_products(k)
class EndpointHandler:
    """Endpoint entry point: image + message in, shopping results out."""

    def __init__(self, path=""):
        # ``path`` is accepted but unused (presumably the model directory the
        # endpoint runtime passes in); models are loaded at module import.
        pass

    @staticmethod
    def _decode_base64_image(image_content):
        """Decode base64 image content, tolerating a ``data:<mime>;base64,`` prefix."""
        if isinstance(image_content, str) and image_content.startswith("data:"):
            # A data URI's payload follows the first comma; raw base64 never contains one.
            image_content = image_content.partition(",")[2]
        return base64.b64decode(image_content)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run the product search for one request.

        Args:
            data: request payload shaped like
                ``{"inputs": {"message": str, "image": str}}``, where
                ``message`` is the user's text and ``image`` is the
                base64-encoded image (a ``data:`` URI prefix is accepted).

        Returns:
            Up to 3 product dictionaries from the shopping search.

        Raises:
            ValueError: if ``message`` or ``image`` is missing or has the wrong type.
        """
        inputs = data.get("inputs") or {}
        message = inputs.get("message")
        image_content = inputs.get("image")
        if not isinstance(message, str):
            raise ValueError("inputs.message must be a string")
        if not isinstance(image_content, (str, bytes)):
            raise ValueError("inputs.image must be base64-encoded image content")
        image_bytes = self._decode_base64_image(image_content)

        # Use a uniquely named file so concurrent requests cannot overwrite each
        # other's image, and delete it once the search is done.
        os.makedirs("input", exist_ok=True)
        fd, image_path = tempfile.mkstemp(suffix=".jpg", dir="input")
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(image_bytes)
            searcher = ProductSearcher(message, image_path)
            return searcher.run(k=3)
        finally:
            # Best-effort cleanup; a failed removal must not mask the result or error.
            with contextlib.suppress(OSError):
                os.remove(image_path)