|
import os |
|
# Keep the Hugging Face cache (model downloads) in the current working
# directory. This must run before sentence_transformers / transformers are
# imported below: the HF libraries resolve their cache location from HF_HOME
# when they are first imported.
current_dir = os.getcwd()
os.environ['HF_HOME'] = current_dir
|
from sentence_transformers import SentenceTransformer, util |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline |
|
from PIL import Image |
|
from serpapi import GoogleSearch |
|
from keybert import KeyBERT |
|
from typing import Dict, Any, List |
|
import base64 |
|
import torch |
|
# Vision-language model used to answer a prompt about the uploaded image.
model_id = "vikhyatk/moondream2"
# Pin the remote-code revision so upstream changes to the repository cannot
# silently change behaviour (trust_remote_code executes code from the repo).
revision = "2024-08-26"

# Prefer the GPU, but fall back to CPU so the module can still be imported on
# hosts without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, revision=revision
)
model.to(device)

tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)

# Multilingual sentence encoder used to match the user's message against
# ProductSearcher's predefined (Vietnamese) questions.
model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
sentence_model = SentenceTransformer(model_name, device=device)
|
|
|
class ProductSearcher:
    """Find shopping results for the product shown in an image.

    Pipeline (see ``run``):
      1. Pick a captioning prompt by matching ``user_input`` against a few
         predefined (Vietnamese) questions with the multilingual sentence model.
      2. Ask the moondream2 vision-language model that prompt about the image.
      3. Extract keyphrases from the generated description with KeyBERT.
      4. Query Google Shopping through SerpApi with those keyphrases.
    """

    # KeyBERT() loads its own embedding model. Building one per instance
    # (i.e. per request when used from EndpointHandler) reloads it every time,
    # so a single lazily-created instance is shared by all ProductSearchers.
    _shared_kw_model = None

    def __init__(self, user_input, image_path):
        """
        Args:
            user_input: Free-text user message used to choose the prompt.
            image_path: Anything ``PIL.Image.open`` accepts (a filesystem
                path or a binary file-like object).
        """
        self.user_input = user_input
        self.image_path = image_path
        # predefined_questions[i] selects prompts[i]: the two lists must have
        # the same length and order.
        self.predefined_questions = [
            "tôi muốn mua sản phẩm này",  # "I want to buy this product"
            "tôi muốn thông tin về sản phẩm",  # "I want information about the product"
            "tôi muốn biết giá cái này",  # "I want to know the price of this"
        ]
        # Each entry needs its trailing comma: without it Python silently
        # concatenates adjacent string literals into a single prompt.
        self.prompts = [
            "Describe product in image with its color. Only answer in one sentence",
            "Describe the product in detail and provide information about the product. If you don't know the product, you can describe the image",
            "Estimate the price of the product and provide a detailed description of the product",
        ]
        self.description = ''
        self.keyphrases = []
        self.kw_model = self._get_kw_model()

    @classmethod
    def _get_kw_model(cls):
        """Return the shared KeyBERT instance, creating it on first use."""
        if cls._shared_kw_model is None:
            cls._shared_kw_model = KeyBERT()
        return cls._shared_kw_model

    def get_most_similar_sentence(self):
        """Return the prompt paired with the predefined question closest to ``user_input``.

        Closeness is cosine similarity between ``sentence_model`` embeddings.
        """
        user_input_embedding = sentence_model.encode(self.user_input)
        predefined_embeddings = sentence_model.encode(self.predefined_questions)
        similarity_scores = util.pytorch_cos_sim(user_input_embedding, predefined_embeddings)
        # One score per predefined question for the single user input, so the
        # flat argmax is the index of the best-matching question.
        most_similar_index = similarity_scores.argmax().item()
        return self.prompts[most_similar_index]

    def generate_description(self):
        """Describe the image with moondream2 using the prompt chosen for ``user_input``.

        The model's answer is stored in ``self.description``.
        """
        prompt = self.get_most_similar_sentence()
        # PIL opens images lazily, so encode inside the ``with`` block; the
        # context manager then releases the file handle Image.open acquired.
        with Image.open(self.image_path) as image:
            enc_image = model.encode_image(image)
        self.description = model.answer_question(enc_image, prompt, tokenizer)
        # Drop the image embedding as soon as it is no longer needed.
        del enc_image

    def extract_keyphrases(self):
        """Store KeyBERT's ``(keyphrase, score)`` pairs for ``self.description``."""
        self.keyphrases = self.kw_model.extract_keywords(self.description)

    def search_products(self, k=3):
        """Query Google Shopping (via SerpApi) with the extracted keyphrases.

        Args:
            k: Maximum number of shopping results to return.

        Returns:
            Up to ``k`` entries of SerpApi's ``shopping_results`` list, or an
            empty list when no usable keyphrases were extracted (no request is
            made in that case).

        Raises:
            KeyError: if a search is attempted and the ``API_KEY`` environment
                variable is not set.
        """
        # The generic word 'image' is excluded from the query (presumably
        # because captions refer to "the image" rather than the product).
        terms = [keyword[0] for keyword in self.keyphrases if keyword[0] != 'image']
        if not terms:
            # Nothing meaningful to search for; skip the API call entirely.
            return []
        question = " ".join(terms)
        search = GoogleSearch({
            "engine": "google",
            "q": question,
            "tbm": "shop",
            "api_key": os.environ["API_KEY"],
        })
        results = search.get_dict()
        return results.get('shopping_results', [])[:k]

    def run(self, k=3):
        """Run the full pipeline and return up to ``k`` shopping results."""
        self.generate_description()
        self.extract_keyphrases()
        return self.search_products(k)
|
|
|
|
|
|
|
class EndpointHandler:
    """Request entry point: decodes the payload and runs ProductSearcher on it."""

    def __init__(self, path=""):
        # All models are loaded at module import time, so there is nothing to
        # set up per handler. ``path`` is accepted but unused (presumably
        # required by the hosting framework's handler interface).
        pass

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj: dict): A dictionary containing the inputs.
                message (:obj: str): The user message.
                image (:obj: str): The base64-encoded image content. A
                    ``data:<mime>;base64,`` URL prefix is also accepted.

        Return:
            A list of dictionaries containing the product search results.

        Raises:
            ValueError: if ``message`` or ``image`` is missing, or
                (``binascii.Error``) if the image is not decodable base64.
        """
        import io

        inputs = data.get("inputs", {})
        message = inputs.get("message")
        image_content = inputs.get("image")

        if not image_content:
            raise ValueError("inputs.image (base64-encoded image) is required")
        if message is None:
            raise ValueError("inputs.message is required")

        # Strip a data-URL header such as "data:image/jpeg;base64,". Neither
        # ':' nor ',' occurs in plain base64, so raw payloads are unaffected.
        if isinstance(image_content, str) and image_content.startswith("data:") and "," in image_content:
            image_content = image_content.split(",", 1)[1]
        image_bytes = base64.b64decode(image_content)

        # Keep the image in memory rather than writing it to a fixed file on
        # disk: a shared path lets concurrent requests overwrite each other's
        # image. PIL.Image.open accepts a file-like object in place of a path.
        searcher = ProductSearcher(message, io.BytesIO(image_bytes))
        return searcher.run(k=3)
|
|
|
|