Phoenix21 committed on
Commit
e8182c5
·
verified ·
1 Parent(s): da1118e

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +6 -2
pipeline.py CHANGED
@@ -8,8 +8,8 @@ from langchain.embeddings import HuggingFaceEmbeddings
8
  from langchain.vectorstores import FAISS
9
  from langchain.chains import RetrievalQA
10
  from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
11
- from langchain.llms.base import LLM
12
  import subprocess # Import subprocess to run shell commands
 
13
 
14
  # Import the chain builders from our separate files
15
  from classification_chain import get_classification_chain
@@ -38,6 +38,8 @@ def install_spacy_model():
38
  # Call the function to install the spaCy model if needed
39
  install_spacy_model()
40
 
 
 
41
 
42
  # Function to extract the main topic using NER
43
  def extract_main_topic(query: str) -> str:
@@ -45,7 +47,7 @@ def extract_main_topic(query: str) -> str:
45
  Extracts the main topic from the user's query using spaCy's NER.
46
  Returns the first named entity or noun found in the query.
47
  """
48
- doc = nlp(query)
49
 
50
  # Try to extract the main topic as a named entity (person, product, etc.)
51
  main_topic = None
@@ -100,9 +102,11 @@ def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
100
  def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
101
  messages = [{"role": "user", "content": prompt}]
102
  return llm_model(messages, stop_sequences=stop)
 
103
  @property
104
  def _llm_type(self) -> str:
105
  return "custom_gemini"
 
106
  retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
107
  gemini_as_llm = GeminiLangChainLLM()
108
  rag_chain = RetrievalQA.from_chain_type(
 
8
  from langchain.vectorstores import FAISS
9
  from langchain.chains import RetrievalQA
10
  from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
 
11
  import subprocess # Import subprocess to run shell commands
12
+ from langchain.llms.base import LLM # Import LLM
13
 
14
  # Import the chain builders from our separate files
15
  from classification_chain import get_classification_chain
 
38
  # Call the function to install the spaCy model if needed
39
  install_spacy_model()
40
 
41
+ # Load the spaCy model globally
42
+ nlp = spacy.load("en_core_web_sm")
43
 
44
  # Function to extract the main topic using NER
45
  def extract_main_topic(query: str) -> str:
 
47
  Extracts the main topic from the user's query using spaCy's NER.
48
  Returns the first named entity or noun found in the query.
49
  """
50
+ doc = nlp(query) # Use the globally loaded spaCy model
51
 
52
  # Try to extract the main topic as a named entity (person, product, etc.)
53
  main_topic = None
 
102
  def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
103
  messages = [{"role": "user", "content": prompt}]
104
  return llm_model(messages, stop_sequences=stop)
105
+
106
  @property
107
  def _llm_type(self) -> str:
108
  return "custom_gemini"
109
+
110
  retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
111
  gemini_as_llm = GeminiLangChainLLM()
112
  rag_chain = RetrievalQA.from_chain_type(