import os

# HF_HOME must be set before transformers / sentence_transformers (and keybert,
# which imports sentence_transformers) are imported, so the Hugging Face cache
# is placed in the current working directory.
current_dir = os.getcwd()
os.environ["HF_HOME"] = current_dir

import base64
import tempfile
from typing import Any, Dict, List

import torch
from keybert import KeyBERT
from PIL import Image
from sentence_transformers import SentenceTransformer, util
from serpapi import GoogleSearch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Run on the GPU when one is available; fall back to CPU so the module can
# still be imported on CPU-only hosts instead of crashing in ``.to('cuda')``.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Vision-language model used to describe the uploaded image. The revision is
# pinned because trust_remote_code=True executes code from the model repo.
model_id = "vikhyatk/moondream2"
revision = "2024-08-26"
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, revision=revision
)
model.to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)

# Multilingual sentence encoder used to match the user's message against the
# predefined intents in ProductSearcher.
model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
sentence_model = SentenceTransformer(model_name, device=DEVICE)

# Keyword extractor, loaded once and shared by every request rather than
# being re-created (and its model reloaded) for each ProductSearcher.
_keyword_model = KeyBERT()


class ProductSearcher:
    """Describe a product image and search Google Shopping for it.

    Pipeline: pick a prompt by matching the user's message to a predefined
    intent, ask moondream2 to describe the image with that prompt, extract
    keyphrases from the description with KeyBERT, and query SerpApi with them.
    """

    def __init__(self, user_input, image_path):
        """
        Args:
            user_input: The user's free-text message; matched against
                ``predefined_questions`` to choose a prompt.
            image_path: Path of the image file to describe.
        """
        self.user_input = user_input
        self.image_path = image_path
        # Intents, aligned index-by-index with ``self.prompts``:
        # "I want to buy this product", "I want information about the
        # product", "I want to know the price of this".
        self.predefined_questions = [
            "tôi muốn mua sản phẩm này",
            "tôi muốn thông tin về sản phẩm",
            "tôi muốn biết giá cái này",
        ]
        # Exactly one prompt per intent above. Each entry ends with a comma:
        # a missing comma silently concatenates adjacent string literals and
        # breaks the index alignment (IndexError for the third intent).
        self.prompts = [
            (
                "Describe the product in the image with its color. "
                "Only answer in one sentence"
            ),
            (
                "Describe the product in detail and provide information "
                "about the product. If you don't know the product, you can "
                "describe the image"
            ),
            (
                "Estimate the price of the product and provide a detailed "
                "description of the product"
            ),
        ]
        self.description = ""
        self.keyphrases = []
        self.kw_model = _keyword_model

    def get_most_similar_sentence(self):
        """Return the prompt whose intent is most similar to ``user_input``.

        Similarity is the cosine similarity between sentence embeddings of
        the user message and each predefined question.
        """
        user_input_embedding = sentence_model.encode(self.user_input)
        predefined_embeddings = sentence_model.encode(self.predefined_questions)
        similarity_scores = util.pytorch_cos_sim(
            user_input_embedding, predefined_embeddings
        )
        most_similar_index = similarity_scores.argmax().item()
        return self.prompts[most_similar_index]

    def generate_description(self):
        """Describe the image at ``image_path``; store it in ``self.description``."""
        prompt = self.get_most_similar_sentence()
        # The context manager closes the image file once it has been encoded.
        with Image.open(self.image_path) as image:
            enc_image = model.encode_image(image)
        self.description = model.answer_question(enc_image, prompt, tokenizer)
        # Release the image encoding as soon as the answer is produced.
        del enc_image

    def extract_keyphrases(self):
        """Extract keyphrases from ``self.description`` into ``self.keyphrases``."""
        self.keyphrases = self.kw_model.extract_keywords(self.description)

    def search_products(self, k=3):
        """Query Google Shopping (via SerpApi) with the extracted keyphrases.

        Args:
            k: Maximum number of shopping results to return.

        Returns:
            Up to ``k`` entries from the response's ``shopping_results``
            list, or an empty list if the response has none.

        Raises:
            KeyError: If the ``API_KEY`` environment variable is not set.
        """
        # Each keyphrase entry is indexed as (phrase, score); the word
        # "image" is dropped, presumably because it describes the photo
        # rather than the product.
        terms = [keyword[0] for keyword in self.keyphrases if keyword[0] != "image"]
        # Fall back to the raw description so the query is never empty.
        question = " ".join(terms) or self.description
        search = GoogleSearch({
            "engine": "google",
            "q": question,
            "tbm": "shop",
            "api_key": os.environ["API_KEY"],
        })
        results = search.get_dict()
        return results.get("shopping_results", [])[:k]

    def run(self, k=3):
        """Describe the image, extract keyphrases, and return up to ``k`` products."""
        self.generate_description()
        self.extract_keyphrases()
        return self.search_products(k)


class EndpointHandler:
    """Inference-endpoint entry point: image + message in, shopping results out."""

    def __init__(self, path=""):
        # ``path`` is unused: all models are loaded once at module import.
        pass

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj: dict): A dictionary containing the inputs.
                message (:obj: str): The user message.
                image (:obj: str): The base64-encoded image content. An
                    optional ``data:<mime>;base64,`` prefix is stripped.
                k (:obj: int, optional): Number of results to return
                    (default 3).
        Return:
            A list of dictionaries containing the product search results.
        Raises:
            ValueError: If ``message`` or ``image`` is missing.
        """
        inputs = data.get("inputs", {})
        message = inputs.get("message")
        image_content = inputs.get("image")
        k = int(inputs.get("k", 3))

        if message is None:
            raise ValueError("inputs.message is required")
        if not image_content:
            raise ValueError("inputs.image (base64-encoded) is required")
        # Accept data URIs such as "data:image/jpeg;base64,<payload>".
        if isinstance(image_content, str) and image_content.startswith("data:"):
            image_content = image_content.partition(",")[2]

        image_bytes = base64.b64decode(image_content)

        # A unique temp file per request, so concurrent requests cannot
        # overwrite each other's image; it is always removed afterwards.
        fd, image_path = tempfile.mkstemp(suffix=".jpg")
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(image_bytes)
            searcher = ProductSearcher(message, image_path)
            return searcher.run(k=k)
        finally:
            os.remove(image_path)