Update pipeline.py
pipeline.py  CHANGED  (+6 -2)
@@ -8,8 +8,8 @@ from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.vectorstores import FAISS
 from langchain.chains import RetrievalQA
 from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
-from langchain.llms.base import LLM
 import subprocess  # Import subprocess to run shell commands
+from langchain.llms.base import LLM  # Import LLM
 
 # Import the chain builders from our separate files
 from classification_chain import get_classification_chain
@@ -38,6 +38,8 @@ def install_spacy_model():
 # Call the function to install the spaCy model if needed
 install_spacy_model()
 
+# Load the spaCy model globally
+nlp = spacy.load("en_core_web_sm")
 
 # Function to extract the main topic using NER
 def extract_main_topic(query: str) -> str:
@@ -45,7 +47,7 @@ def extract_main_topic(query: str) -> str:
     Extracts the main topic from the user's query using spaCy's NER.
     Returns the first named entity or noun found in the query.
     """
-    doc = nlp(query)
+    doc = nlp(query)  # Use the globally loaded spaCy model
 
     # Try to extract the main topic as a named entity (person, product, etc.)
     main_topic = None
@@ -100,9 +102,11 @@ def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
         def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
             messages = [{"role": "user", "content": prompt}]
            return llm_model(messages, stop_sequences=stop)
+
        @property
        def _llm_type(self) -> str:
            return "custom_gemini"
+
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    gemini_as_llm = GeminiLangChainLLM()
    rag_chain = RetrievalQA.from_chain_type(
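For reference, here is how the pieces touched by this commit fit together once it is applied. The second and third hunks fix a missing-global bug: extract_main_topic() calls nlp(query), but nothing defined nlp at module scope before this change. The sketch below is a minimal reconstruction, not the full file: the body of install_spacy_model() and the entity/noun loops in extract_main_topic() are assumptions inferred from the hunk headers, comments, and docstring, which is all the diff shows of them.

import subprocess
import sys

import spacy


def install_spacy_model():
    # Assumed body: the diff shows only this function's hunk header.
    # Download en_core_web_sm via a shell command if it is not installed yet.
    try:
        spacy.load("en_core_web_sm")
    except OSError:
        subprocess.run(
            [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
            check=True,
        )


# Call the function to install the spaCy model if needed
install_spacy_model()

# Load the spaCy model globally (the fix introduced by this commit)
nlp = spacy.load("en_core_web_sm")


def extract_main_topic(query: str) -> str:
    """
    Extracts the main topic from the user's query using spaCy's NER.
    Returns the first named entity or noun found in the query.
    """
    doc = nlp(query)  # Use the globally loaded spaCy model

    # Try to extract the main topic as a named entity (person, product, etc.)
    main_topic = None
    for ent in doc.ents:
        main_topic = ent.text
        break

    # Assumed fallback: take the first noun if NER found nothing.
    if main_topic is None:
        for token in doc:
            if token.pos_ == "NOUN":
                main_topic = token.text
                break

    return main_topic or query  # Assumed final fallback to the raw query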
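The last hunk only adds blank lines inside build_rag_chain(), which wraps the smolagents LiteLLMModel in a minimal LangChain LLM subclass so RetrievalQA can drive it. The diff cuts off at the RetrievalQA.from_chain_type( call, so the keyword arguments and return statement below are assumptions showing one plausible completion of that call.

from typing import Optional

from langchain.chains import RetrievalQA
from langchain.llms.base import LLM
from langchain.vectorstores import FAISS
from smolagents import LiteLLMModel


def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
    class GeminiLangChainLLM(LLM):
        # Adapter: expose the smolagents model through LangChain's LLM interface.
        def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
            messages = [{"role": "user", "content": prompt}]
            return llm_model(messages, stop_sequences=stop)

        @property
        def _llm_type(self) -> str:
            return "custom_gemini"

    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    gemini_as_llm = GeminiLangChainLLM()
    rag_chain = RetrievalQA.from_chain_type(
        llm=gemini_as_llm,   # everything from here on is assumed;
        chain_type="stuff",  # the diff ends at the open parenthesis
        retriever=retriever,
    )
    return rag_chain

A caller would then build the chain once and query it with something like rag_chain.run(user_query) (a hypothetical call site; classic LangChain RetrievalQA exposes run for single-string queries).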