# ToTSearch.py
import random
from typing import List, Dict, Any, Generator
from sentence_transformers import SentenceTransformer, util
import torch
import numpy as np
from twisted.internet import defer
from agent import AutonomousWebAgent
from mcts import MCTS, MCTSNode
import logging
from twisted.internet.defer import Deferred
logger = logging.getLogger(__name__)
class ToTNode:
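    """A single node in the tree of thoughts: wraps one candidate thought and
    tracks its parent, children, visit count, accumulated value, and any web
    search results gathered for it."""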
    def __init__(self, thought, parent=None):
        self.thought = thought
        self.parent = parent
        self.children = []
        self.visits = 0
        self.value = 0
        self.search_results = []
        self.mcts_node = None

    def add_child(self, child_thought):
        child = ToTNode(child_thought, self)
        self.children.append(child)
        return child

    def update(self, reward):
        self.visits += 1
        self.value += reward
class ToTSearch:
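    """Tree-of-Thoughts search: generates candidate thoughts with the agent,
    expands and scores them against the query, runs MCTS-backed simulations
    over the tree, and synthesizes a final answer from the best path."""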
    def __init__(self, agent: AutonomousWebAgent, model='all-MiniLM-L6-v2',
                 max_depth=3, num_thoughts=3, num_simulations=100):
        self.agent = agent
        self.model = SentenceTransformer(model)
        self.max_depth = max_depth
        self.num_thoughts = num_thoughts
        self.num_simulations = num_simulations
        self.mcts = MCTS(initial_state="", num_simulations=num_simulations)
    def generate_thoughts(self, query: str) -> List[str]:
        prompt = f"""Given the query "{query}", generate {self.num_thoughts} distinct thoughts or approaches to address it.
        Each thought should be a complete sentence and offer a unique perspective or solution path."""
        thoughts = self.agent.generate_text(prompt).split('\n')
        return [thought.strip() for thought in thoughts if thought.strip()]
    def expand_thought(self, thought: str) -> List[str]:
        prompt = f"""Expand on the following thought: "{thought}"
        Generate {self.num_thoughts} more specific sub-thoughts or considerations.
        Each sub-thought should be a complete sentence and offer additional detail or a new angle."""
        expansions = self.agent.generate_text(prompt).split('\n')
        return [exp.strip() for exp in expansions if exp.strip()]
    def evaluate_thought(self, thought: str, query: str) -> float:
        thought_embedding = self.model.encode(thought)
        query_embedding = self.model.encode(query)
        return util.pytorch_cos_sim(thought_embedding, query_embedding).item()
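    # evaluate_thought: util.pytorch_cos_sim returns a 1x1 tensor, so .item()
    # yields a plain float in [-1, 1]; higher means more semantically on-topic.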
    @defer.inlineCallbacks
    def search_and_augment(self, thought: str) -> Generator[Deferred, Any, List[Dict[str, Any]]]:
        search_results = yield self.agent.retrieve_from_web(thought)
        for result in search_results:
            result['originating_thought'] = thought
        defer.returnValue(search_results)
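    # search_and_augment: each returned result is expected to be a dict with at
    # least a 'content' key, since synthesize_results ranks on result['content'].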
    def select(self, node: ToTNode) -> ToTNode:
        while node.children:
            # Choose a node with zero visits or select based on the value/visits ratio
            if any(child.visits == 0 for child in node.children):
                zero_visit_nodes = [child for child in node.children if child.visits == 0]
                selected_node = random.choice(zero_visit_nodes)
                logger.debug(f"Selected node with 0 visits: {selected_node.thought}")
                return selected_node
            else:
                selected_node = max(node.children, key=lambda child: (child.value / child.visits) if child.visits > 0 else float('-inf'))
                logger.debug(f"Selected node based on value/visits ratio: {selected_node.thought}, value: {selected_node.value}, visits: {selected_node.visits}")
                return selected_node
        return node
    def expand(self, node: ToTNode, query: str) -> ToTNode:
        if not node.children and len(node.thought.split()) > 2:
            expansions = self.expand_thought(node.thought)
            for expansion in expansions:
                node.add_child(expansion)
        return random.choice(node.children) if node.children else node
    @defer.inlineCallbacks
    def simulate(self, node: ToTNode, query: str):
        # Walk down the tree at random until a leaf or max_depth is reached
        current_node = node
        depth = 0
        while depth < self.max_depth:
            if not current_node.children:
                break
            current_node = random.choice(current_node.children)
            depth += 1
        logger.debug(f"Simulating for thought: {current_node.thought}")
        search_results = yield self.search_and_augment(current_node.thought)
        current_node.search_results = search_results
        logger.debug(f"Search results count: {len(search_results)}")
        thought_reward = self.agent.calculate_reward(current_node.thought, query)
        logger.debug(f"Thought reward: {thought_reward}")
        mcts_node = MCTSNode(current_node.thought)
        current_node.mcts_node = mcts_node
        mcts_total_reward = 0
        for _ in range(self.num_simulations):
            mcts_reward = yield self.mcts.simulate(mcts_node)
            mcts_total_reward += mcts_reward
            self.mcts.backpropagate(mcts_node, mcts_reward)
        logger.debug(f"MCTS node visits: {mcts_node.visits}, total reward: {mcts_total_reward}")
        # Compute the MCTS value before combining, guarding against division by zero
        if mcts_node.visits > 0:
            mcts_value = mcts_total_reward / mcts_node.visits
            logger.debug(f"MCTS value: {mcts_value}")
        else:
            mcts_value = 0
            logger.warning("MCTS node has 0 visits, assigning value 0")
        combined_reward = (thought_reward + mcts_value) / 2
        logger.debug(f"Combined reward: {combined_reward}")
        defer.returnValue(combined_reward)
    def backpropagate(self, node: ToTNode, reward: float):
        while node:
            node.update(reward)
            node = node.parent
    @defer.inlineCallbacks
    def tot_search(self, query: str) -> Generator[Deferred, Any, ToTNode]:
        root = ToTNode(query)
        for _ in range(self.num_simulations):
            node = self.select(root)
            node = self.expand(node, query)
            reward = yield self.simulate(node, query)
            self.backpropagate(node, reward)
            # Update agent's experience replay
            state = self.agent.extract_features(node.thought, query)
            next_state = self.agent.extract_features(node.children[0].thought if node.children else node.thought, query)
            self.agent.remember_worker(state, 0, reward, next_state, False)
        # Perform agent's replay to update RL models
        self.agent.replay_worker()
        self.agent.replay_manager()
        defer.returnValue(root)
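    # Note: num_simulations drives both the outer ToT iterations above and the
    # MCTS rollouts inside simulate(), so total rollouts scale with its square.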
    def get_best_path(self, root: ToTNode) -> List[str]:
        path = [root.thought]
        current = root
        while current.children:
            current = max(current.children, key=lambda child: child.value / child.visits if child.visits > 0 else float('-inf'))
            path.append(current.thought)
        return path
    @defer.inlineCallbacks
    def synthesize_results(self, root: ToTNode, query: str) -> Generator[Deferred, Any, str]:
        best_path = self.get_best_path(root)
        all_results = []

        def collect_results(node):
            all_results.extend(node.search_results)
            for child in node.children:
                collect_results(child)

        collect_results(root)
        # Sort results by relevance to the query
        all_results.sort(key=lambda x: self.evaluate_thought(x['content'], query), reverse=True)
        top_results = all_results[:5]  # Adjust the number as needed
        # Note: summary_prompt is assembled here but not passed to the agent;
        # the final answer below comes from the agent's RAG pipeline instead.
        summary_prompt = f"Synthesize the following information into a coherent answer for the query '{query}':\n\n"
        summary_prompt += f"Thought path: {' -> '.join(best_path)}\n\n"
        for result in top_results:
            summary_prompt += f"- {result['content'][:200]}...\n"
        # Use the agent's RAG capabilities for final answer generation
        final_answer = yield self.agent.generate_rag_response(query, top_results)
        # Save the generated answer and thought path to the agent's knowledge base
        self.agent.add_document_to_kb(
            title=f"ToT Search Result: {query}",
            content=final_answer,
            metadata={"thought_path": best_path}
        )
        defer.returnValue(final_answer)
    @defer.inlineCallbacks
    def search(self, query: str) -> Generator[Deferred, Any, str]:
        logger.info(f"Starting ToT search for query: {query}")
        root = yield self.tot_search(query)
        final_answer = yield self.synthesize_results(root, query)
        logger.info(f"ToT search completed for query: {query}")
        defer.returnValue(final_answer)
# Usage example (must run inside an @defer.inlineCallbacks function, since
# search() returns a Deferred):
# tot_search = ToTSearch(agent)
# final_answer = yield tot_search.search("What are the latest advancements in renewable energy?")
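# A fuller driver sketch (illustrative: assumes AutonomousWebAgent() needs no
# constructor arguments and that the Twisted reactor drives the event loop;
# adapt both to your actual setup):
#
# from twisted.internet import reactor
#
# @defer.inlineCallbacks
# def main():
#     agent = AutonomousWebAgent()
#     tot = ToTSearch(agent, max_depth=3, num_thoughts=3, num_simulations=10)
#     answer = yield tot.search("What are the latest advancements in renewable energy?")
#     print(answer)
#     reactor.stop()
#
# reactor.callWhenRunning(main)
# reactor.run()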