File size: 4,292 Bytes
289dbfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c704ba4
289dbfd
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
# Set the HF_HOME environment variable to the current working directory.
# NOTE(review): this happens before the sentence_transformers / transformers
# imports below, presumably so those libraries see the value when they are
# imported -- keep this ordering unless confirmed otherwise.
# os.path.join with a single argument returns that argument unchanged.
current_dir = os.getcwd()
os.environ['HF_HOME'] = os.path.join(current_dir)
from sentence_transformers import SentenceTransformer, util
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from PIL import Image
from serpapi import GoogleSearch
from keybert import KeyBERT
from typing import Dict, Any, List
import base64
import torch
# Model used by ProductSearcher.generate_description to answer a prompt about
# an image. It is loaded once at import time and pinned to a fixed revision.
# trust_remote_code=True is passed to from_pretrained for this checkpoint.
model_id = "vikhyatk/moondream2"
revision = "2024-08-26"
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, revision=revision
)

# The model is moved to the 'cuda' device unconditionally (no CPU fallback).
model.to('cuda')
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)

# Sentence-embedding model used by ProductSearcher.get_most_similar_sentence
# to compare the user's message with the predefined questions; also on 'cuda'.
model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
sentence_model = SentenceTransformer(model_name, device='cuda')

class ProductSearcher:
    """Describe a product image and look up matching shopping results.

    The user's message is compared with a fixed list of predefined questions;
    the closest one selects the prompt that is sent, together with the image,
    to the module-level vision model. Keywords extracted from the model's
    answer are joined into a Google Shopping query issued through SerpApi.

    Relies on the module-level ``model``, ``tokenizer`` and ``sentence_model``
    objects being loaded.
    """

    # Questions the user message is matched against (runtime strings):
    #   0: "I want to buy this product"
    #   1: "I want information about the product"
    #   2: "I want to know the price of this"
    PREDEFINED_QUESTIONS = (
        "tôi muốn mua sản phẩm này",
        "tôi muốn thông tin về sản phẩm",
        "tôi muốn biết giá cái này",
    )

    # PROMPTS[i] is the prompt used when PREDEFINED_QUESTIONS[i] is the best
    # match, so both tuples must have the same length. Every entry needs its
    # own trailing comma: adjacent string literals are silently concatenated.
    # NOTE(review): "Descibe" in the first prompt looks like a typo; it is kept
    # as-is because changing it changes the text sent to the model.
    PROMPTS = (
        "Descibe product in image with it color. Only answer in one sentence",
        "Describe the product in detail and provide information about the product. If you don't know the product, you can describe the image",
        "Estimate the price of the product and provide a detailed description of the product",
    )

    def __init__(self, user_input, image_path):
        """Store the request data and create the keyword extractor.

        Args:
            user_input: The user's message (text).
            image_path: Path to the product image on disk.
        """
        self.user_input = user_input
        self.image_path = image_path
        # Per-instance copies, kept as lists so existing attribute users
        # continue to see mutable lists.
        self.predefined_questions = list(self.PREDEFINED_QUESTIONS)
        self.prompts = list(self.PROMPTS)
        self.description = ''
        self.keyphrases = []
        self.kw_model = KeyBERT()

    def get_most_similar_sentence(self):
        """Return the prompt paired with the predefined question closest to the user input.

        Closeness is the cosine similarity between sentence embeddings.
        """
        user_input_embedding = sentence_model.encode(self.user_input)
        predefined_embeddings = sentence_model.encode(self.predefined_questions)
        similarity_scores = util.pytorch_cos_sim(user_input_embedding, predefined_embeddings)
        most_similar_index = similarity_scores.argmax().item()
        return self.prompts[most_similar_index]

    def generate_description(self):
        """Ask the vision model about the image and store the answer in ``self.description``."""
        prompt = self.get_most_similar_sentence()
        # Close the image file as soon as it has been encoded.
        with Image.open(self.image_path) as image:
            enc_image = model.encode_image(image)
        self.description = model.answer_question(enc_image, prompt, tokenizer)

    def extract_keyphrases(self):
        """Extract keywords from ``self.description`` into ``self.keyphrases``."""
        self.keyphrases = self.kw_model.extract_keywords(self.description)

    def search_products(self, k=3):
        """Search Google Shopping for the extracted keywords.

        Args:
            k: Maximum number of products to return.

        Returns:
            Up to ``k`` entries from the ``shopping_results`` of the search
            response, or an empty list when there are no usable keywords or
            the response has no shopping results.

        Raises:
            KeyError: If the ``API_KEY`` environment variable is not set
                (only reached when there is something to search for).
        """
        # Each keyphrase entry is indexed as (keyword, ...); the literal word
        # 'image' is excluded from the query.
        terms = [keyword[0] for keyword in self.keyphrases if keyword[0] != 'image']
        if not terms:
            # Nothing to search for: skip the API call instead of sending an
            # empty query.
            return []

        search = GoogleSearch({
            "engine": "google",
            "q": " ".join(terms),
            "tbm": "shop",
            "api_key": os.environ["API_KEY"],
        })
        results = search.get_dict()
        return results.get('shopping_results', [])[:k]

    def run(self, k=3):
        """Run the full pipeline: describe the image, extract keywords, search.

        Args:
            k: Maximum number of products to return.

        Returns:
            The list returned by :meth:`search_products`.
        """
        self.generate_description()
        self.extract_keyphrases()
        return self.search_products(k)



class EndpointHandler:
    """Callable request handler that runs a product search for a message and an image."""

    def __init__(self, path=""):
        # ``path`` is accepted but unused; nothing is initialised per handler.
        pass

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj: dict): A dictionary containing the inputs.
                message (:obj: str): The user message.
                image (:obj: str): The base64-encoded image content.
        Return:
            A list of dictionaries containing the product search results.
        Raises:
            ValueError: If ``message`` or ``image`` is missing, or the image
                is not valid base64 (``binascii.Error`` is a ``ValueError``).
        """
        # Local import so this block does not depend on a file-level change.
        import tempfile

        inputs = data.get("inputs") or {}
        message = inputs.get("message")
        image_content = inputs.get("image")

        # Fail early with a clear message instead of an opaque error from
        # deep inside base64 or the embedding model.
        if message is None:
            raise ValueError("'inputs.message' is required")
        if not image_content:
            raise ValueError("'inputs.image' is required (base64-encoded image)")

        # Decode the base64-encoded image content
        image_bytes = base64.b64decode(image_content)

        # Write the image to a unique temporary file so that concurrent
        # requests do not overwrite each other's image.
        os.makedirs("input", exist_ok=True)
        fd, image_path = tempfile.mkstemp(suffix=".jpg", dir="input")
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(image_bytes)

            # Run the search for this message/image pair and return the results
            searcher = ProductSearcher(message, image_path)
            return searcher.run(k=3)
        finally:
            # Best-effort cleanup: a failure to delete the temporary file must
            # not mask the real result or error.
            try:
                os.remove(image_path)
            except OSError:
                pass