|
|
|
import random
|
|
from typing import List, Dict, Any, Generator
|
|
from sentence_transformers import SentenceTransformer, util
|
|
import torch
|
|
import numpy as np
|
|
from twisted.internet import defer
|
|
from agent import AutonomousWebAgent
|
|
from mcts import MCTS, MCTSNode
|
|
import logging
|
|
from twisted.internet.defer import Deferred
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class ToTNode:
    """A single node in a Tree-of-Thoughts search tree.

    Each node wraps one natural-language "thought", tracks visit/value
    statistics for backpropagation, and can carry the web search results
    and MCTS node associated with that thought.
    """

    def __init__(self, thought, parent=None):
        self.thought = thought
        self.parent = parent
        self.children = []          # child ToTNode instances
        self.visits = 0             # number of times this node was updated
        self.value = 0              # cumulative reward across updates
        self.search_results = []    # web results retrieved for this thought
        self.mcts_node = None       # MCTSNode attached during simulation

    def add_child(self, child_thought):
        """Create a child node for *child_thought*, attach it, and return it."""
        new_node = ToTNode(child_thought, self)
        self.children.append(new_node)
        return new_node

    def update(self, reward):
        """Record one visit and accumulate *reward* into this node's value."""
        self.visits = self.visits + 1
        self.value = self.value + reward
|
|
|
|
class ToTSearch:
    """Tree-of-Thoughts search driver.

    Combines LLM-generated thoughts (via ``agent.generate_text``), semantic
    similarity scoring (sentence-transformers), web retrieval, and MCTS
    rollouts to explore a tree of approaches to a query, then synthesizes a
    final RAG answer from the best path.

    All methods decorated with ``@defer.inlineCallbacks`` return Twisted
    Deferreds and must be yielded from a reactor context.
    """

    def __init__(self, agent: AutonomousWebAgent, model='all-MiniLM-L6-v2', max_depth=3, num_thoughts=3, num_simulations=100):
        """
        Args:
            agent: the web agent used for generation, retrieval, and RL updates.
            model: sentence-transformers model name for thought/query similarity.
            max_depth: maximum random-descent depth during simulation.
            num_thoughts: number of thoughts/expansions requested per prompt.
            num_simulations: iterations for both the ToT loop and inner MCTS.
        """
        self.agent = agent
        self.model = SentenceTransformer(model)
        self.max_depth = max_depth
        self.num_thoughts = num_thoughts
        self.num_simulations = num_simulations
        self.mcts = MCTS(initial_state="", num_simulations=num_simulations)

    def generate_thoughts(self, query: str) -> List[str]:
        """Ask the agent for ``num_thoughts`` distinct approaches to *query*.

        Returns one thought per non-empty line of the agent's response.
        """
        prompt = f"""Given the query "{query}", generate {self.num_thoughts} distinct thoughts or approaches to address it.
Each thought should be a complete sentence and offer a unique perspective or solution path."""
        thoughts = self.agent.generate_text(prompt).split('\n')
        return [thought.strip() for thought in thoughts if thought.strip()]

    def expand_thought(self, thought: str) -> List[str]:
        """Ask the agent for ``num_thoughts`` more specific sub-thoughts of *thought*."""
        prompt = f"""Expand on the following thought: "{thought}"
Generate {self.num_thoughts} more specific sub-thoughts or considerations.
Each sub-thought should be a complete sentence and offer additional detail or a new angle."""
        expansions = self.agent.generate_text(prompt).split('\n')
        return [exp.strip() for exp in expansions if exp.strip()]

    def evaluate_thought(self, thought: str, query: str) -> float:
        """Return the cosine similarity between *thought* and *query* embeddings."""
        thought_embedding = self.model.encode(thought)
        query_embedding = self.model.encode(query)
        return util.pytorch_cos_sim(thought_embedding, query_embedding).item()

    @defer.inlineCallbacks
    def search_and_augment(self, thought: str) -> Generator[Deferred, Any, List[Dict[str, Any]]]:
        """Retrieve web results for *thought*, tagging each with its origin."""
        search_results = yield self.agent.retrieve_from_web(thought)
        for result in search_results:
            result['originating_thought'] = thought
        defer.returnValue(search_results)

    def select(self, node: ToTNode) -> ToTNode:
        """Select a child to explore: unvisited children first (random),
        otherwise the child with the best value/visits ratio.

        Returns *node* itself when it has no children.
        """
        while node.children:
            if any(child.visits == 0 for child in node.children):
                zero_visit_nodes = [child for child in node.children if child.visits == 0]
                selected_node = random.choice(zero_visit_nodes)
                logger.debug(f"Selected node with 0 visits: {selected_node.thought}")
                return selected_node
            else:
                selected_node = max(node.children, key=lambda child: (child.value / child.visits) if child.visits > 0 else float('-inf'))
                logger.debug(f"Selected node based on value/visits ratio: {selected_node.thought}, value: {selected_node.value}, visits: {selected_node.visits}")
                return selected_node
        return node

    def expand(self, node: ToTNode, query: str) -> ToTNode:
        """Expand a leaf node with sub-thoughts and return one child at random.

        Expansion is skipped for very short thoughts (<= 2 words), which are
        treated as too terse to elaborate. Returns *node* when no children exist.
        """
        if not node.children and len(node.thought.split()) > 2:
            expansions = self.expand_thought(node.thought)
            for expansion in expansions:
                node.add_child(expansion)
        return random.choice(node.children) if node.children else node

    @defer.inlineCallbacks
    def simulate(self, node: ToTNode, query: str):
        """Roll out from *node*, gather web evidence, run MCTS, and return
        a combined reward (mean of agent reward and MCTS average value).
        """
        # Random descent to a leaf, capped at max_depth.
        current_node = node
        depth = 0
        while depth < self.max_depth:
            if not current_node.children:
                break
            current_node = random.choice(current_node.children)
            depth += 1

        logger.debug(f"Simulating for thought: {current_node.thought}")

        # Augment the leaf with web evidence for later synthesis.
        search_results = yield self.search_and_augment(current_node.thought)
        current_node.search_results = search_results
        logger.debug(f"Search results count: {len(search_results)}")

        ranked_results = self.agent.calculate_reward(current_node.thought, query)
        logger.debug(f"Ranked results: {ranked_results}")

        # Inner MCTS rollouts rooted at the leaf thought.
        mcts_node = MCTSNode(current_node.thought)
        current_node.mcts_node = mcts_node
        mcts_total_reward = 0
        for _ in range(self.num_simulations):
            mcts_reward = yield self.mcts.simulate(mcts_node)
            mcts_total_reward += mcts_reward
            self.mcts.backpropagate(mcts_node, mcts_reward)

        logger.debug(f"MCTS node visits: {mcts_node.visits}, total reward: {mcts_total_reward}")

        # BUG FIX: the original computed ``combined_reward`` from
        # ``mcts_value`` BEFORE ``mcts_value`` was assigned, raising
        # NameError whenever visits > 0 and ranked_results != 0, and then
        # unconditionally recomputed it afterwards. Compute the average
        # first (guarding division by zero), then combine exactly once.
        if mcts_node.visits > 0:
            mcts_value = mcts_total_reward / mcts_node.visits
            logger.debug(f"MCTS value: {mcts_value}")
        else:
            mcts_value = 0
            logger.warning(f"MCTS node has 0 visits, assigning value 0")

        combined_reward = (ranked_results + mcts_value) / 2
        logger.debug(f"Combined reward: {combined_reward}")

        defer.returnValue(combined_reward)

    def backpropagate(self, node: ToTNode, reward: float):
        """Propagate *reward* from *node* up to the root, updating stats."""
        while node:
            node.update(reward)
            node = node.parent

    @defer.inlineCallbacks
    def tot_search(self, query: str) -> Generator[Deferred, Any, ToTNode]:
        """Run the full select/expand/simulate/backpropagate loop and
        return the root of the resulting thought tree.

        Also feeds each (state, reward, next_state) transition into the
        agent's worker replay buffer and triggers replay training.
        """
        root = ToTNode(query)
        for _ in range(self.num_simulations):
            node = self.select(root)
            node = self.expand(node, query)
            reward = yield self.simulate(node, query)
            self.backpropagate(node, reward)

            # RL bookkeeping: record the transition for the worker policy.
            state = self.agent.extract_features(node.thought, query)
            next_state = self.agent.extract_features(node.children[0].thought if node.children else node.thought, query)
            self.agent.remember_worker(state, 0, reward, next_state, False)

            self.agent.replay_worker()
            self.agent.replay_manager()

        defer.returnValue(root)

    def get_best_path(self, root: ToTNode) -> List[str]:
        """Greedily follow the highest value/visits child from *root* down,
        returning the list of thoughts along the way (root included).
        """
        path = [root.thought]
        current = root
        while current.children:
            current = max(current.children, key=lambda child: child.value / child.visits if child.visits > 0 else float('-inf'))
            path.append(current.thought)
        return path

    @defer.inlineCallbacks
    def synthesize_results(self, root: ToTNode, query: str) -> Generator[Deferred, Any, str]:
        """Collect all search results in the tree, rank them against *query*,
        generate a RAG answer from the top 5, and store it in the agent's KB.
        """
        best_path = self.get_best_path(root)
        all_results = []

        def collect_results(node):
            # Depth-first accumulation of every node's search results.
            all_results.extend(node.search_results)
            for child in node.children:
                collect_results(child)

        collect_results(root)

        # Rank by semantic similarity of result content to the query.
        all_results.sort(key=lambda x: self.evaluate_thought(x['content'], query), reverse=True)
        top_results = all_results[:5]

        # NOTE(review): the original also built an unused ``summary_prompt``
        # string here; it was dead code (generate_rag_response is what
        # produces the answer) and has been removed.
        final_answer = yield self.agent.generate_rag_response(query, top_results)

        # Persist the synthesized answer so future searches can reuse it.
        self.agent.add_document_to_kb(
            title=f"ToT Search Result: {query}",
            content=final_answer,
            metadata={"thought_path": best_path}
        )

        defer.returnValue(final_answer)

    @defer.inlineCallbacks
    def search(self, query: str) -> Generator[Deferred, Any, str]:
        """End-to-end entry point: build the thought tree for *query* and
        return the synthesized final answer.
        """
        logger.info(f"Starting ToT search for query: {query}")
        root = yield self.tot_search(query)
        final_answer = yield self.synthesize_results(root, query)
        logger.info(f"ToT search completed for query: {query}")
        defer.returnValue(final_answer)
|
|
|
|
|
|
|
|
|