Phoenix21 commited on
Commit
9dc639f
·
verified ·
1 Parent(s): ce9f68f

Create pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +169 -0
pipeline.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pipeline.py
2
+ import os
3
+ import getpass
4
+ import pandas as pd
5
+ from typing import Optional
6
+
7
+ from langchain.docstore.document import Document
8
+ from langchain.embeddings import HuggingFaceEmbeddings
9
+ from langchain.vectorstores import FAISS
10
+ from langchain.chains import RetrievalQA
11
+
12
+ from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
13
+ import litellm
14
+
15
+ # We import the chain builders from our separate files
16
+ from classification_chain import get_classification_chain
17
+ from refusal_chain import get_refusal_chain
18
+ from tailor_chain import get_tailor_chain
19
+ from cleaner_chain import get_cleaner_chain, CleanerChain
20
+
21
+ # We also import the relevant RAG logic here or define it directly
22
+ # (We define build_rag_chain in this file for clarity)
23
+
24
+ ###############################################################################
25
+ # 1) Environment: set up keys if missing
26
+ ###############################################################################
27
+ if not os.environ.get("GEMINI_API_KEY"):
28
+ os.environ["GEMINI_API_KEY"] = getpass.getpass("Enter your Gemini API Key: ")
29
+ if not os.environ.get("GROQ_API_KEY"):
30
+ os.environ["GROQ_API_KEY"] = getpass.getpass("Enter your GROQ API Key: ")
31
+
32
+ ###############################################################################
33
+ # 2) build_or_load_vectorstore
34
+ ###############################################################################
35
+ def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
36
+ if os.path.exists(store_dir):
37
+ print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")
38
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
39
+ vectorstore = FAISS.load_local(store_dir, embeddings)
40
+ return vectorstore
41
+ else:
42
+ print(f"DEBUG: Building new store from CSV: {csv_path}")
43
+ df = pd.read_csv(csv_path)
44
+ df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
45
+ df.columns = df.columns.str.strip()
46
+ if "Answer" in df.columns:
47
+ df.rename(columns={"Answer": "Answers"}, inplace=True)
48
+ if "Question" not in df.columns and "Question " in df.columns:
49
+ df.rename(columns={"Question ": "Question"}, inplace=True)
50
+ if "Question" not in df.columns or "Answers" not in df.columns:
51
+ raise ValueError("CSV must have 'Question' and 'Answers' columns.")
52
+ docs = []
53
+ for _, row in df.iterrows():
54
+ q = str(row["Question"])
55
+ ans = str(row["Answers"])
56
+ doc = Document(page_content=ans, metadata={"question": q})
57
+ docs.append(doc)
58
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
59
+ vectorstore = FAISS.from_documents(docs, embedding=embeddings)
60
+ vectorstore.save_local(store_dir)
61
+ return vectorstore
62
+
63
+ ###############################################################################
64
+ # 3) Build RAG chain for Gemini
65
+ ###############################################################################
66
+ from langchain.llms.base import LLM
67
+ def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
68
+ class GeminiLangChainLLM(LLM):
69
+ def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
70
+ messages = [{"role": "user", "content": prompt}]
71
+ return llm_model(messages, stop_sequences=stop)
72
+ @property
73
+ def _llm_type(self) -> str:
74
+ return "custom_gemini"
75
+ retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
76
+ gemini_as_llm = GeminiLangChainLLM()
77
+ rag_chain = RetrievalQA.from_chain_type(
78
+ llm=gemini_as_llm,
79
+ chain_type="stuff",
80
+ retriever=retriever,
81
+ return_source_documents=True
82
+ )
83
+ return rag_chain
84
+
85
+ ###############################################################################
86
+ # 4) Initialize all the separate chains
87
+ ###############################################################################
88
+ # Classification chain
89
+ classification_chain = get_classification_chain()
90
+ # Refusal chain
91
+ refusal_chain = get_refusal_chain()
92
+ # Tailor chain
93
+ tailor_chain = get_tailor_chain()
94
+ # Cleaner chain
95
+ cleaner_chain = get_cleaner_chain()
96
+
97
+ ###############################################################################
98
+ # 5) Build our vectorstores + RAG chains
99
+ ###############################################################################
100
+ wellness_csv = "AIChatbot.csv"
101
+ brand_csv = "BrandAI.csv"
102
+ wellness_store_dir = "faiss_wellness_store"
103
+ brand_store_dir = "faiss_brand_store"
104
+
105
+ wellness_vectorstore = build_or_load_vectorstore(wellness_csv, wellness_store_dir)
106
+ brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)
107
+
108
+ gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
109
+ wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
110
+ brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
111
+
112
+ ###############################################################################
113
+ # 6) Tools / Agents for web search
114
+ ###############################################################################
115
+ search_tool = DuckDuckGoSearchTool()
116
+ web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)
117
+ managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
118
+ manager_agent = CodeAgent(tools=[], model=gemini_llm, managed_agents=[managed_web_agent])
119
+
120
+ def do_web_search(query: str) -> str:
121
+ print("DEBUG: Attempting web search for more info...")
122
+ search_query = f"Give me relevant info: {query}"
123
+ response = manager_agent.run(search_query)
124
+ return response
125
+
126
+ ###############################################################################
127
+ # 7) Orchestrator: run_with_chain
128
+ ###############################################################################
129
+ def run_with_chain(query: str) -> str:
130
+ print("DEBUG: Starting run_with_chain...")
131
+ # 1) Classify
132
+ class_result = classification_chain.invoke({"query": query})
133
+ classification = class_result.get("text", "").strip()
134
+ print("DEBUG: Classification =>", classification)
135
+
136
+ # If OutOfScope => refusal => tailor => return
137
+ if classification == "OutOfScope":
138
+ refusal_text = refusal_chain.run({})
139
+ final_refusal = tailor_chain.run({"response": refusal_text})
140
+ return final_refusal.strip()
141
+
142
+ # If Wellness => wellness RAG => if insufficient => web => unify => tailor
143
+ if classification == "Wellness":
144
+ rag_result = wellness_rag_chain({"query": query})
145
+ csv_answer = rag_result["result"].strip()
146
+ if not csv_answer:
147
+ web_answer = do_web_search(query)
148
+ else:
149
+ lower_ans = csv_answer.lower()
150
+ if any(phrase in lower_ans for phrase in ["i do not know", "not sure", "no context", "cannot answer"]):
151
+ web_answer = do_web_search(query)
152
+ else:
153
+ web_answer = ""
154
+ final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer)
155
+ final_answer = tailor_chain.run({"response": final_merged})
156
+ return final_answer.strip()
157
+
158
+ # If Brand => brand RAG => tailor => return
159
+ if classification == "Brand":
160
+ rag_result = brand_rag_chain({"query": query})
161
+ csv_answer = rag_result["result"].strip()
162
+ final_merged = cleaner_chain.merge(kb=csv_answer, web="")
163
+ final_answer = tailor_chain.run({"response": final_merged})
164
+ return final_answer.strip()
165
+
166
+ # fallback
167
+ refusal_text = refusal_chain.run({})
168
+ final_refusal = tailor_chain.run({"response": refusal_text})
169
+ return final_refusal.strip()