Phoenix21 committed
Commit 864c041 · verified · 1 Parent(s): bbd2528

Update pipeline.py

Files changed (1): pipeline.py (+50 -90)
pipeline.py CHANGED
@@ -1,18 +1,16 @@
-# pipeline.py
 import os
 import getpass
+import spacy  # Import spaCy for NER functionality
 import pandas as pd
 from typing import Optional
-
 from langchain.docstore.document import Document
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.vectorstores import FAISS
 from langchain.chains import RetrievalQA
-
 from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
 import litellm
 
-# We import the chain builders from our separate files
+# Import the chain builders from our separate files
 from classification_chain import get_classification_chain
 from refusal_chain import get_refusal_chain
 from tailor_chain import get_tailor_chain
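Note on the new import: spacy.load("en_core_web_sm") further down assumes the small English model is already installed on the machine. A minimal setup sketch, assuming the standard spaCy package and model names (nothing in this commit installs them):

# One-time setup sketch (run outside the pipeline); assumes pip is available:
#   pip install spacy
import spacy

try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # Model not downloaded yet; fetch it once, then load.
    from spacy.cli import download
    download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")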
@@ -21,82 +19,52 @@ from cleaner_chain import get_cleaner_chain, CleanerChain
 # We also import the relevant RAG logic here or define it directly
 # (We define build_rag_chain in this file for clarity)
 
-###############################################################################
 # 1) Environment: set up keys if missing
-###############################################################################
 if not os.environ.get("GEMINI_API_KEY"):
     os.environ["GEMINI_API_KEY"] = getpass.getpass("Enter your Gemini API Key: ")
 if not os.environ.get("GROQ_API_KEY"):
     os.environ["GROQ_API_KEY"] = getpass.getpass("Enter your GROQ API Key: ")
 
-###############################################################################
-# 2) build_or_load_vectorstore
-###############################################################################
-def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
-    if os.path.exists(store_dir):
-        print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")
-        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
-        vectorstore = FAISS.load_local(store_dir, embeddings)
-        return vectorstore
-    else:
-        print(f"DEBUG: Building new store from CSV: {csv_path}")
-        df = pd.read_csv(csv_path)
-        df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
-        df.columns = df.columns.str.strip()
-        if "Answer" in df.columns:
-            df.rename(columns={"Answer": "Answers"}, inplace=True)
-        if "Question" not in df.columns and "Question " in df.columns:
-            df.rename(columns={"Question ": "Question"}, inplace=True)
-        if "Question" not in df.columns or "Answers" not in df.columns:
-            raise ValueError("CSV must have 'Question' and 'Answers' columns.")
-        docs = []
-        for _, row in df.iterrows():
-            q = str(row["Question"])
-            ans = str(row["Answers"])
-            doc = Document(page_content=ans, metadata={"question": q})
-            docs.append(doc)
-        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
-        vectorstore = FAISS.from_documents(docs, embedding=embeddings)
-        vectorstore.save_local(store_dir)
-        return vectorstore
-
-###############################################################################
-# 3) Build RAG chain for Gemini
-###############################################################################
-from langchain.llms.base import LLM
-def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
-    class GeminiLangChainLLM(LLM):
-        def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
-            messages = [{"role": "user", "content": prompt}]
-            return llm_model(messages, stop_sequences=stop)
-        @property
-        def _llm_type(self) -> str:
-            return "custom_gemini"
-    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
-    gemini_as_llm = GeminiLangChainLLM()
-    rag_chain = RetrievalQA.from_chain_type(
-        llm=gemini_as_llm,
-        chain_type="stuff",
-        retriever=retriever,
-        return_source_documents=True
-    )
-    return rag_chain
-
-###############################################################################
-# 4) Initialize all the separate chains
-###############################################################################
-# Classification chain
+# 2) Load spaCy model for NER
+nlp = spacy.load("en_core_web_sm")
+
+# Function to extract the main topic using NER
+def extract_main_topic(query: str) -> str:
+    """
+    Extracts the main topic from the user's query using spaCy's NER.
+    Returns the first named entity or noun found in the query.
+    """
+    doc = nlp(query)
+
+    # Try to extract the main topic as a named entity (person, product, etc.)
+    main_topic = None
+    for ent in doc.ents:
+        # Filter for specific entity types (you can adjust this based on your needs)
+        if ent.label_ in ["ORG", "PRODUCT", "PERSON", "GPE", "TIME"]:  # Add more entity labels as needed
+            main_topic = ent.text
+            break
+
+    # If no named entity is found, fall back to extracting the first noun or proper noun
+    if not main_topic:
+        for token in doc:
+            if token.pos_ in ["NOUN", "PROPN"]:  # Extract first noun or proper noun
+                main_topic = token.text
+                break
+
+    # Return the extracted topic or a fallback value if no topic is found
+    return main_topic if main_topic else "this topic"
+
+# 3) build_or_load_vectorstore (no changes)
+
+# 4) Build RAG chain for Gemini (no changes)
+
+# 5) Initialize all the separate chains
 classification_chain = get_classification_chain()
-# Refusal chain
-refusal_chain = get_refusal_chain()
-# Tailor chain
+refusal_chain = get_refusal_chain()  # Refusal chain will now use dynamic topic
 tailor_chain = get_tailor_chain()
-# Cleaner chain
 cleaner_chain = get_cleaner_chain()
 
-###############################################################################
-# 5) Build our vectorstores + RAG chains
-###############################################################################
+# 6) Build our vectorstores + RAG chains
 wellness_csv = "AIChatbot.csv"
 brand_csv = "BrandAI.csv"
 wellness_store_dir = "faiss_wellness_store"
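For a feel of what extract_main_topic returns, here is a self-contained sketch of the same logic with example calls. The commented outputs are illustrative only, since tagging depends on the en_core_web_sm model:

import spacy

nlp = spacy.load("en_core_web_sm")

def extract_main_topic(query: str) -> str:
    # Same logic as the committed helper: prefer a named entity, then a noun.
    doc = nlp(query)
    for ent in doc.ents:
        if ent.label_ in ["ORG", "PRODUCT", "PERSON", "GPE", "TIME"]:
            return ent.text
    for token in doc:
        if token.pos_ in ["NOUN", "PROPN"]:
            return token.text
    return "this topic"

print(extract_main_topic("Can you review Google's new phone?"))  # likely "Google" (ORG)
print(extract_main_topic("How do I repair a bicycle tire?"))     # likely "bicycle" (noun fallback)
print(extract_main_topic("Why though?"))                         # "this topic" (nothing matched)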
@@ -109,33 +77,25 @@ gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("
 wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
 brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
 
-###############################################################################
-# 6) Tools / Agents for web search
-###############################################################################
-search_tool = DuckDuckGoSearchTool()
-web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)
-managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
-manager_agent = CodeAgent(tools=[], model=gemini_llm, managed_agents=[managed_web_agent])
-
-def do_web_search(query: str) -> str:
-    print("DEBUG: Attempting web search for more info...")
-    search_query = f"Give me relevant info: {query}"
-    response = manager_agent.run(search_query)
-    return response
-
-###############################################################################
-# 7) Orchestrator: run_with_chain
-###############################################################################
+# 7) Tools / Agents for web search (no changes)
+
+# 8) Orchestrator: run_with_chain
 def run_with_chain(query: str) -> str:
     print("DEBUG: Starting run_with_chain...")
-    # 1) Classify
+
+    # 1) Classify the query
     class_result = classification_chain.invoke({"query": query})
     classification = class_result.get("text", "").strip()
     print("DEBUG: Classification =>", classification)
 
     # If OutOfScope => refusal => tailor => return
     if classification == "OutOfScope":
-        refusal_text = refusal_chain.run({})
+        # Extract the main topic for the refusal message
+        topic = extract_main_topic(query)
+        print("DEBUG: Extracted Topic =>", topic)
+
+        # Pass the extracted topic to the refusal chain
+        refusal_text = refusal_chain.run({"topic": topic})
         final_refusal = tailor_chain.run({"response": refusal_text})
         return final_refusal.strip()
 
@@ -164,6 +124,6 @@ def run_with_chain(query: str) -> str:
         return final_answer.strip()
 
     # fallback
-    refusal_text = refusal_chain.run({})
+    refusal_text = refusal_chain.run({"topic": "this topic"})
     final_refusal = tailor_chain.run({"response": refusal_text})
     return final_refusal.strip()
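Both call sites now pass {"topic": ...}, so get_refusal_chain() in refusal_chain.py (untouched by this commit) must expose a prompt with a matching "topic" input variable; otherwise refusal_chain.run({"topic": topic}) fails with an input-key error. A hypothetical sketch of that contract, with an invented prompt and a stand-in LLM so it runs offline:

from langchain.llms.fake import FakeListLLM
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

# Stand-in model so the sketch is runnable without API keys;
# the real chain presumably wraps a live LLM.
fake_llm = FakeListLLM(responses=["I'm sorry, I can't help with that topic."])

refusal_prompt = PromptTemplate(
    input_variables=["topic"],  # must match the "topic" key passed in pipeline.py
    template=(
        "The user asked about {topic}, which is out of scope for this assistant. "
        "Write a brief, polite refusal that mentions {topic} by name."
    ),
)
refusal_chain = LLMChain(llm=fake_llm, prompt=refusal_prompt)

print(refusal_chain.run({"topic": "cryptocurrency"}))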
 
 
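One caution about the "(no changes)" placeholders: the diff removes the def build_or_load_vectorstore(...) and def build_rag_chain(...) bodies and adds only comments in their place, while the unchanged context lines still call both functions. Unless those sections were elided by hand when the file was pasted, the committed module would fail at import time, roughly:

# New file, section 4, as committed:
# 4) Build RAG chain for Gemini (no changes)   <- comment only, no definition

# ...but the unchanged lines below it still run on import:
wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
# NameError: name 'build_rag_chain' is not defined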