File size: 8,720 Bytes
e1392d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
# mcts.py
import math
import random
from nltk.corpus import wordnet
from scrapy.crawler import CrawlerRunner
from scrapy.utils.log import configure_logging
from scrapy.utils.project import get_project_settings
from twisted.internet import reactor, defer
from scrapy import signals
import logging
from my_search_engine.my_search_engine.spiders.search_spider import SearchSpider
from sentence_transformers import SentenceTransformer, util
from ranking import train_ranking_model
import time

logger = logging.getLogger(__name__)

class MCTSNode:
    """A node in the Monte Carlo search tree.

    Each node stores the query string it represents (``state``), its tree
    links, visit/value statistics, and a cached UCB score used during
    selection.
    """

    def __init__(self, state, parent=None, action=None):
        self.state = state              # query string this node represents
        self.parent = parent            # None for the root node
        self.action = action            # action that produced this state, if any
        self.children = []
        self.visits = 0                 # times backpropagation passed through here
        self.value = 0                  # cumulative reward over all visits
        self.ucb_score = float('inf')   # unvisited nodes are maximally attractive

    def is_leaf(self):
        """Return True when this node has no expanded children."""
        return len(self.children) == 0

    def add_child(self, child_state, action=None):
        """Create a child node for *child_state*, attach it, and return it."""
        child_node = MCTSNode(child_state, parent=self, action=action)
        self.children.append(child_node)
        return child_node

    def update(self, reward):
        """Record one visit with *reward* and refresh the cached UCB score."""
        self.visits += 1
        self.value += reward
        if self.parent:  # UCB is undefined for the root (no parent visits)
            self.ucb_score = self.calculate_ucb()

    def calculate_ucb(self, exploration_weight=1.41):
        """Return the UCB1 score: value/visits + c*sqrt(ln(parent.visits)/visits).

        Returns +inf for unvisited nodes and for the root, so fresh nodes are
        always selected first. A parent with zero visits is guarded the same
        way — previously ``math.log(0)`` raised ``ValueError`` in that case,
        and +inf is the correct limiting value.
        """
        if self.visits == 0 or not self.parent or self.parent.visits == 0:
            return float('inf')
        exploitation = self.value / self.visits
        exploration = exploration_weight * math.sqrt(math.log(self.parent.visits) / self.visits)
        return exploitation + exploration

class MCTS:
    """Monte Carlo Tree Search over search-query refinements.

    Tree states are query strings. Each simulation runs the query through a
    Scrapy crawl (``web_search``), ranks the scraped items with
    ``train_ranking_model``, and folds the top ranking score together with
    query-shape heuristics into a scalar reward that is backpropagated up
    the tree. ``run`` returns the best refined query found.

    NOTE(review): ``num_simulations`` is stored but never used — ``run``
    iterates ``self.num_iterations`` (hard-coded to 5) instead. Confirm
    which one is intended before relying on the constructor argument.
    """

    def __init__(self, initial_state, num_simulations=20, exploration_weight=1.41):
        self.root = MCTSNode(initial_state)
        self.num_simulations = num_simulations          # currently unused, see class NOTE
        self.exploration_weight = exploration_weight
        # Sentence embedder used for semantic-similarity rewards and for
        # filtering refined query variations.
        self.query_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.results = []
        self.crawler_runner = CrawlerRunner(get_project_settings())
        self.initial_state = initial_state
        self.num_iterations = 5

    def select(self, node):
        """Descend from *node* to a leaf, always following the max-UCB child."""
        while not node.is_leaf():
            if not node.children:
                return node
            node = max(node.children, key=lambda c: c.calculate_ucb(self.exploration_weight))
        return node

    def expand(self, node):
        """Expand *node* with refined-query children and return one at random.

        An unvisited node is returned unexpanded so it gets simulated at
        least once before growing children.
        """
        if node.visits == 0:
            return node
        possible_refinements = self.get_possible_refinements(node.state)
        for refinement in possible_refinements:
            node.add_child(refinement)
        return random.choice(node.children) if node.children else node

    def calculate_combined_reward(self, ranking_score, state):
        """Blend the ranking score with query-shape heuristics into one reward.

        Weights: 0.5 ranking score, 0.2 query length, 0.2 vocabulary
        diversity (unique/total words), 0.1 semantic similarity to the
        original root query.
        """
        state_length_reward = len(state) / 100
        words = state.split()
        # Guard on the word list, not the raw string: a whitespace-only
        # state is truthy but splits to [], which previously divided by zero.
        query_complexity = len(set(words)) / len(words) if words else 0
        semantic_similarity = self.calculate_semantic_similarity(state, self.root.state)

        combined_reward = (
            0.5 * ranking_score +
            0.2 * state_length_reward +
            0.2 * query_complexity +
            0.1 * semantic_similarity
        )
        return combined_reward

    def calculate_semantic_similarity(self, query1, query2):
        """Return the cosine similarity of the two queries' sentence embeddings."""
        embedding1 = self.query_model.encode(query1)
        embedding2 = self.query_model.encode(query2)
        return util.pytorch_cos_sim(embedding1, embedding2).item()

    def backpropagate(self, node, reward):
        """Propagate *reward* from *node* up to the root, updating statistics."""
        while node is not None:
            node.update(reward)
            node = node.parent

    def best_action(self):
        """Return the root child with the highest mean reward (root if childless)."""
        if not self.root.children:
            return self.root

        def score(node):
            # Unvisited children cannot be compared by mean reward.
            if node.visits == 0:
                return float('-inf')
            return node.value / node.visits

        return max(self.root.children, key=score)

    def refine_query(self, query):
        """Rewrite *query* word-by-word with WordNet synonyms and append an
        intent keyword.

        Stop-words are kept verbatim; other words are replaced by a random
        single-word lemma from their first synset when one differs from the
        original word.
        """
        words = query.split()
        refined_query = []

        for word in words:
            if word.lower() not in {"how", "to", "get", "an", "the", "and", "or", "of", "build"}:
                synonyms = wordnet.synsets(word)
                if synonyms:
                    synonym_words = [lemma.name() for lemma in synonyms[0].lemmas()
                                    if len(lemma.name().split()) == 1 and word != lemma.name()]
                    if synonym_words:
                        refined_query.append(random.choice(synonym_words))
                    else:
                        refined_query.append(word)
                else:
                    refined_query.append(word)
            else:
                refined_query.append(word)

        # Bias refinements toward the LLM/NLP domain — presumably the target
        # domain of this project; verify these keywords fit other use cases.
        possible_intent_keywords = ['guide', 'tutorial', 'LLM', 'language model', 'NLP', 'GPT']
        refined_query.append(random.choice(possible_intent_keywords))

        return ' '.join(refined_query)

    def get_related_queries(self, query):
        """Return up to two synonym-substituted variations semantically close
        to *query* (cosine similarity > 0.8), falling back to [query] itself.
        """
        query_embedding = self.query_model.encode(query)
        refined_query_variations = [query]
        words_to_avoid = {'how', 'to', 'get'}
        words = query.split()

        for word in words:
            if word.lower() not in words_to_avoid:
                synonyms = wordnet.synsets(word)
                if synonyms:
                    synonym_words = [lemma.name() for lemma in synonyms[0].lemmas() if lemma.name() != word]
                    if synonym_words:
                        refined_query = query.replace(word, random.choice(synonym_words))
                        refined_query_variations.append(refined_query)

        # Deduplicate before scoring so each variation is embedded once.
        refined_query_variations = list(set(refined_query_variations))
        refined_query_embeddings = [self.query_model.encode(variation) for variation in refined_query_variations]
        similarity_scores = util.pytorch_cos_sim(query_embedding, refined_query_embeddings).tolist()[0]

        similarity_threshold = 0.8
        filtered_queries = [variation for idx, variation in enumerate(refined_query_variations)
                            if similarity_scores[idx] > similarity_threshold]

        return filtered_queries[:2] if filtered_queries else [query]

    def get_possible_refinements(self, query):
        """Combine related-query variations with one random synonym rewrite."""
        refined_queries = self.get_related_queries(query)
        return refined_queries + [self.refine_query(query)]

    @defer.inlineCallbacks
    def web_search(self, query, search_sites=None):
        """Crawl the web for *query* and return the scraped items.

        Returns a Deferred (via inlineCallbacks) firing with a list of items
        collected through Scrapy's ``item_scraped`` signal; an empty list on
        empty query or crawl failure.
        """
        if not query.strip():
            logger.error("Cannot perform web search with an empty query.")
            defer.returnValue([])

        logger.info(f"Starting web search for query: {query}")
        configure_logging(install_root_handler=False)
        logging.basicConfig(level=logging.INFO)

        results = []

        def crawler_results(item, response, spider):
            # Signal handler: collect each scraped item as it arrives.
            logger.info(f"Received result: {item['title']}")
            results.append(item)

        try:
            crawler = self.crawler_runner.create_crawler(SearchSpider)
            crawler.signals.connect(crawler_results, signal=signals.item_scraped)

            # Start crawling, passing query and search_sites to the spider
            yield self.crawler_runner.crawl(crawler, query=query, search_sites=search_sites)
        except Exception as e:
            logger.error(f"Error during web search: {str(e)}")
            # returnValue raises from inside the handler; inlineCallbacks
            # treats it as the generator's result, so the Deferred fires
            # with [] instead of the crawl error.
            defer.returnValue([])

        logger.info(f"Web search completed. Found {len(results)} results.")
        defer.returnValue(results)

    @defer.inlineCallbacks
    def run(self):
        """Run the MCTS loop and return the best refined query string.

        Returns a Deferred firing with the state of the best root child, or
        the root state itself when no child was ever expanded.
        """
        logger.info(f"Starting MCTS run with {self.num_iterations} iterations")
        for i in range(self.num_iterations):
            logger.debug(f"Iteration {i+1}/{self.num_iterations}")
            leaf = self.select(self.root)
            child = self.expand(leaf)
            reward = yield self.simulate(child)
            self.backpropagate(child, reward)

        best_child = self.best_action()
        logger.info(f"MCTS run completed. Best action: {best_child.state}")
        defer.returnValue(best_child.state if best_child != self.root else self.root.state)

    @defer.inlineCallbacks
    def simulate(self, node):
        """Estimate a node's reward: crawl its query, rank results, and blend
        the top predicted score with the query-shape heuristics.
        """
        query_results = yield self.web_search(node.state)
        ranked_results = train_ranking_model(node.state, query_results)

        if ranked_results:
            # assumes ranked_results is sorted best-first with a
            # 'predicted_score' key — TODO confirm against ranking.py
            top_score = ranked_results[0]['predicted_score']
        else:
            top_score = 0

        reward = self.calculate_combined_reward(top_score, node.state)
        defer.returnValue(reward)