Phoenix21 committed on
Commit
c299b66
·
verified ·
1 Parent(s): 864c041

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +54 -3
pipeline.py CHANGED
@@ -16,9 +16,6 @@ from refusal_chain import get_refusal_chain
16
  from tailor_chain import get_tailor_chain
17
  from cleaner_chain import get_cleaner_chain, CleanerChain
18
 
19
- # We also import the relevant RAG logic here or define it directly
20
- # (We define build_rag_chain in this file for clarity)
21
-
22
  # 1) Environment: set up keys if missing
23
  if not os.environ.get("GEMINI_API_KEY"):
24
  os.environ["GEMINI_API_KEY"] = getpass.getpass("Enter your Gemini API Key: ")
@@ -55,8 +52,52 @@ def extract_main_topic(query: str) -> str:
55
  return main_topic if main_topic else "this topic"
56
 
57
  # 3) build_or_load_vectorstore (no changes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  # 4) Build RAG chain for Gemini (no changes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  # 5) Initialize all the separate chains
62
  classification_chain = get_classification_chain()
@@ -78,6 +119,16 @@ wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
78
  brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
79
 
80
  # 7) Tools / Agents for web search (no changes)
 
 
 
 
 
 
 
 
 
 
81
 
82
  # 8) Orchestrator: run_with_chain
83
  def run_with_chain(query: str) -> str:
 
16
  from tailor_chain import get_tailor_chain
17
  from cleaner_chain import get_cleaner_chain, CleanerChain
18
 
 
 
 
19
  # 1) Environment: set up keys if missing
20
  if not os.environ.get("GEMINI_API_KEY"):
21
  os.environ["GEMINI_API_KEY"] = getpass.getpass("Enter your Gemini API Key: ")
 
52
  return main_topic if main_topic else "this topic"
53
 
54
  # 3) build_or_load_vectorstore (no changes)
55
def build_or_load_vectorstore(
    csv_path: str,
    store_dir: str,
    model_name: str = "sentence-transformers/multi-qa-mpnet-base-dot-v1",
) -> FAISS:
    """Load the FAISS store saved at ``store_dir`` or build it from a Q&A CSV.

    If ``store_dir`` already exists it is loaded as-is and ``csv_path`` is not
    read. Otherwise the CSV is read, every row becomes one ``Document`` (the
    answer is the page content, the question is stored under the
    ``"question"`` metadata key), the documents are embedded, and the
    resulting store is saved to ``store_dir`` before being returned.

    Args:
        csv_path: Path to a CSV with ``Question`` and ``Answers`` columns
            (``Answer`` is accepted as an alias for ``Answers``).
        store_dir: Directory the FAISS store is loaded from or saved to.
        model_name: Hugging Face embedding model name. The same model must be
            used for building and for later loading of a given store.

    Returns:
        The loaded or newly built FAISS vector store.

    Raises:
        ValueError: If the CSV lacks the required columns or contains no row
            with both a question and an answer.
    """
    if os.path.exists(store_dir):
        print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")
        embeddings = HuggingFaceEmbeddings(model_name=model_name)
        # NOTE(review): newer LangChain releases may require
        # allow_dangerous_deserialization=True here - confirm against the
        # pinned version before upgrading.
        return FAISS.load_local(store_dir, embeddings)

    print(f"DEBUG: Building new store from CSV: {csv_path}")
    df = pd.read_csv(csv_path)
    # Drop index columns left behind by a previous DataFrame.to_csv().
    df = df.loc[:, ~df.columns.str.contains("^Unnamed")]
    # Stripping also normalises headers such as "Question " to "Question".
    df.columns = df.columns.str.strip()
    # Only alias "Answer" when "Answers" is absent; renaming unconditionally
    # would create two "Answers" columns.
    if "Answer" in df.columns and "Answers" not in df.columns:
        df = df.rename(columns={"Answer": "Answers"})
    if "Question" not in df.columns or "Answers" not in df.columns:
        raise ValueError("CSV must have 'Question' and 'Answers' columns.")

    # Rows with a missing question or answer would otherwise be embedded as
    # the literal string "nan".
    df = df.dropna(subset=["Question", "Answers"])
    if df.empty:
        raise ValueError(
            "CSV contains no rows with both 'Question' and 'Answers' values."
        )

    docs = [
        Document(page_content=str(answer), metadata={"question": str(question)})
        for question, answer in zip(df["Question"], df["Answers"])
    ]
    # Create the embedding model only after validation so bad input fails fast.
    embeddings = HuggingFaceEmbeddings(model_name=model_name)
    vectorstore = FAISS.from_documents(docs, embedding=embeddings)
    vectorstore.save_local(store_dir)
    return vectorstore
82
 
83
  # 4) Build RAG chain for Gemini (no changes)
84
def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
    """Build a RetrievalQA chain that answers with ``llm_model`` over ``vectorstore``.

    The model is wrapped in a minimal LangChain ``LLM`` adapter so that
    ``RetrievalQA`` can call it. Retrieval uses similarity search returning the
    top 3 documents, which are combined with the "stuff" chain type. Source
    documents are included in the chain output.

    Args:
        llm_model: Callable model invoked with a list of chat messages and a
            ``stop_sequences`` keyword.
        vectorstore: FAISS store used as the retriever backend.

    Returns:
        The configured ``RetrievalQA`` chain.
    """

    class GeminiLangChainLLM(LLM):
        """LangChain LLM adapter that forwards each prompt to ``llm_model``."""

        def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
            messages = [{"role": "user", "content": prompt}]
            reply = llm_model(messages, stop_sequences=stop)
            # LangChain expects plain text from _call. Presumably some
            # smolagents versions return a message object carrying the text
            # in `.content` rather than a bare string (verify against the
            # installed version); plain strings pass through unchanged.
            return getattr(reply, "content", reply)

        @property
        def _llm_type(self) -> str:
            return "custom_gemini"

    retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 3},
    )
    return RetrievalQA.from_chain_type(
        llm=GeminiLangChainLLM(),
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
101
 
102
  # 5) Initialize all the separate chains
103
  classification_chain = get_classification_chain()
 
119
  brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
120
 
121
  # 7) Tools / Agents for web search (no changes)
122
# Web-search tooling: a code agent equipped with the DuckDuckGo tool is wrapped
# as a managed agent and handed to a tool-less manager agent that delegates to it.
search_tool = DuckDuckGoSearchTool()
web_agent = CodeAgent(
    tools=[search_tool],
    model=gemini_llm,
)
managed_web_agent = ManagedAgent(
    agent=web_agent,
    name="web_search",
    description="Runs web search for you.",
)
manager_agent = CodeAgent(
    tools=[],
    model=gemini_llm,
    managed_agents=[managed_web_agent],
)
126
+
127
def do_web_search(query: str) -> str:
    """Ask the manager agent for web information about ``query``.

    Prints a debug line, wraps ``query`` in a fixed prompt prefix, and returns
    whatever ``manager_agent.run`` produces for that prompt.
    """
    print("DEBUG: Attempting web search for more info...")
    return manager_agent.run(f"Give me relevant info: {query}")
132
 
133
  # 8) Orchestrator: run_with_chain
134
  def run_with_chain(query: str) -> str: