# mcts.py
import logging
import math
import random

from nltk.corpus import wordnet
from scrapy import signals
from scrapy.crawler import CrawlerRunner
from scrapy.utils.log import configure_logging
from scrapy.utils.project import get_project_settings
from sentence_transformers import SentenceTransformer, util
from twisted.internet import defer

from my_search_engine.my_search_engine.spiders.search_spider import SearchSpider
from ranking import train_ranking_model

logger = logging.getLogger(__name__)


class MCTSNode:
    def __init__(self, state, parent=None, action=None):
        self.state = state
        self.parent = parent
        self.action = action
        self.children = []
        self.visits = 0
        self.value = 0
        self.ucb_score = float('inf')

    def is_leaf(self):
        return len(self.children) == 0

    def add_child(self, child_state, action=None):
        child_node = MCTSNode(child_state, parent=self, action=action)
        self.children.append(child_node)
        return child_node

    def update(self, reward):
        self.visits += 1
        self.value += reward
        if self.parent:  # Only calculate UCB if not root
            self.ucb_score = self.calculate_ucb()

    def calculate_ucb(self, exploration_weight=1.41):
        if self.visits == 0 or not self.parent:
            return float('inf')
        exploitation = self.value / self.visits
        exploration = exploration_weight * math.sqrt(math.log(self.parent.visits) / self.visits)
        return exploitation + exploration
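
# UCB1 reference for MCTSNode.calculate_ucb() above:
#     score = value / visits + c * sqrt(ln(parent_visits) / visits)
# with c = 1.41 ~ sqrt(2). Unvisited nodes (and the root) score infinity, so
# every child is tried at least once before exploitation takes over.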


class MCTS:
    def __init__(self, initial_state, num_simulations=20, exploration_weight=1.41, num_iterations=5):
        self.root = MCTSNode(initial_state)
        self.num_simulations = num_simulations
        self.exploration_weight = exploration_weight
        self.query_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.results = []
        self.crawler_runner = CrawlerRunner(get_project_settings())
        self.initial_state = initial_state
        self.num_iterations = num_iterations

    def select(self, node):
        # Descend by always following the child with the highest UCB score
        # until a leaf is reached.
        while not node.is_leaf():
            node = max(node.children, key=lambda c: c.calculate_ucb(self.exploration_weight))
        return node

    def expand(self, node):
        # Only expand nodes that have already been visited; simulate unvisited
        # nodes directly.
        if node.visits == 0:
            return node
        possible_refinements = self.get_possible_refinements(node.state)
        for refinement in possible_refinements:
            node.add_child(refinement)
        return random.choice(node.children) if node.children else node

    def calculate_combined_reward(self, ranking_score, state):
        # Blend the ranking model's score with simple query-shape heuristics.
        state_length_reward = len(state) / 100
        words = state.split()
        query_complexity = len(set(words)) / len(words) if words else 0
        semantic_similarity = self.calculate_semantic_similarity(state, self.root.state)
        combined_reward = (
            0.5 * ranking_score +
            0.2 * state_length_reward +
            0.2 * query_complexity +
            0.1 * semantic_similarity
        )
        return combined_reward

    def calculate_semantic_similarity(self, query1, query2):
        embedding1 = self.query_model.encode(query1)
        embedding2 = self.query_model.encode(query2)
        return util.pytorch_cos_sim(embedding1, embedding2).item()

    def backpropagate(self, node, reward):
        while node is not None:
            node.update(reward)
            node = node.parent

    def best_action(self):
        if not self.root.children:
            return self.root

        def score(node):
            if node.visits == 0:
                return float('-inf')
            return node.value / node.visits

        return max(self.root.children, key=score)

    def refine_query(self, query):
        # Swap non-stop-words for a single-word WordNet synonym when one
        # exists, then append a random intent keyword.
        stop_words = {"how", "to", "get", "an", "the", "and", "or", "of", "build"}
        refined_query = []
        for word in query.split():
            if word.lower() not in stop_words:
                synonyms = wordnet.synsets(word)
                if synonyms:
                    # WordNet joins multiword lemmas with underscores, so
                    # filter on '_' to keep only single-word synonyms.
                    synonym_words = [lemma.name() for lemma in synonyms[0].lemmas()
                                     if '_' not in lemma.name() and lemma.name() != word]
                    if synonym_words:
                        refined_query.append(random.choice(synonym_words))
                    else:
                        refined_query.append(word)
                else:
                    refined_query.append(word)
            else:
                refined_query.append(word)
        possible_intent_keywords = ['guide', 'tutorial', 'LLM', 'language model', 'NLP', 'GPT']
        refined_query.append(random.choice(possible_intent_keywords))
        return ' '.join(refined_query)

    def get_related_queries(self, query):
        # Generate WordNet-based variations and keep only those semantically
        # close to the original query.
        query_embedding = self.query_model.encode(query)
        refined_query_variations = [query]
        words_to_avoid = {'how', 'to', 'get'}
        for word in query.split():
            if word.lower() not in words_to_avoid:
                synonyms = wordnet.synsets(word)
                if synonyms:
                    synonym_words = [lemma.name() for lemma in synonyms[0].lemmas() if lemma.name() != word]
                    if synonym_words:
                        refined_query = query.replace(word, random.choice(synonym_words))
                        refined_query_variations.append(refined_query)
        refined_query_variations = list(set(refined_query_variations))
        # Encode all variations in one batch; cos_sim yields a 1 x N matrix.
        refined_query_embeddings = self.query_model.encode(refined_query_variations)
        similarity_scores = util.pytorch_cos_sim(query_embedding, refined_query_embeddings).tolist()[0]
        similarity_threshold = 0.8
        filtered_queries = [variation for idx, variation in enumerate(refined_query_variations)
                            if similarity_scores[idx] > similarity_threshold]
        return filtered_queries[:2] if filtered_queries else [query]

    def get_possible_refinements(self, query):
        refined_queries = self.get_related_queries(query)
        return refined_queries + [self.refine_query(query)]
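
    # Note: the refinement pipeline above yields at most three candidates per
    # expansion: up to two WordNet variations that clear the 0.8 cosine
    # similarity filter (the unmodified query always clears it), plus one
    # refine_query() rewrite with an intent keyword appended.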

    @defer.inlineCallbacks
    def web_search(self, query, search_sites=None):
        # Run the Scrapy spider for this query and collect scraped items via
        # the item_scraped signal. Requires a running Twisted reactor.
        if not query.strip():
            logger.error("Cannot perform web search with an empty query.")
            defer.returnValue([])
        logger.info(f"Starting web search for query: {query}")
        configure_logging(install_root_handler=False)
        logging.basicConfig(level=logging.INFO)
        results = []

        def crawler_results(item, response, spider):
            logger.info(f"Received result: {item['title']}")
            results.append(item)

        try:
            crawler = self.crawler_runner.create_crawler(SearchSpider)
            crawler.signals.connect(crawler_results, signal=signals.item_scraped)
            # Start crawling, passing query and search_sites to the spider
            yield self.crawler_runner.crawl(crawler, query=query, search_sites=search_sites)
        except Exception as e:
            logger.error(f"Error during web search: {str(e)}")
            defer.returnValue([])
        logger.info(f"Web search completed. Found {len(results)} results.")
        defer.returnValue(results)

    @defer.inlineCallbacks
    def run(self):
        # Standard MCTS loop: select, expand, simulate, backpropagate.
        logger.info(f"Starting MCTS run with {self.num_iterations} iterations")
        for i in range(self.num_iterations):
            logger.debug(f"Iteration {i+1}/{self.num_iterations}")
            leaf = self.select(self.root)
            child = self.expand(leaf)
            reward = yield self.simulate(child)
            self.backpropagate(child, reward)
        best_child = self.best_action()
        logger.info(f"MCTS run completed. Best action: {best_child.state}")
        defer.returnValue(best_child.state)

    @defer.inlineCallbacks
    def simulate(self, node):
        # Score a node by running an actual web search for its query and
        # feeding the results through the ranking model.
        query_results = yield self.web_search(node.state)
        ranked_results = train_ranking_model(node.state, query_results)
        top_score = ranked_results[0]['predicted_score'] if ranked_results else 0
        reward = self.calculate_combined_reward(top_score, node.state)
        defer.returnValue(reward)
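

# Minimal usage sketch (an assumption, not part of the original module):
# CrawlerRunner does not start the Twisted reactor itself, so the caller must
# run and stop it. The query string below is an illustrative placeholder.
if __name__ == "__main__":
    from twisted.internet import reactor

    logging.basicConfig(level=logging.INFO)
    mcts = MCTS("how to build a search engine")

    def on_done(best_query):
        logger.info(f"Best refined query: {best_query}")
        reactor.stop()

    def on_error(failure):
        logger.error(f"MCTS run failed: {failure}")
        reactor.stop()

    mcts.run().addCallbacks(on_done, on_error)
    reactor.run()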