Spaces:
Running
Running
Update pipeline.py
Browse files- pipeline.py +4 -21
pipeline.py
CHANGED
@@ -32,10 +32,10 @@ gemini_llm = ChatGoogleGenerativeAI(
|
|
32 |
model="gemini-1.5-pro",
|
33 |
temperature=0,
|
34 |
max_retries=2,
|
35 |
-
#
|
36 |
)
|
37 |
|
38 |
-
# Initialize
|
39 |
pydantic_agent = ManagedAgent(
|
40 |
llm=ChatGoogleGenerativeAI(
|
41 |
model="gemini-1.5-pro",
|
@@ -191,29 +191,12 @@ def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
|
|
191 |
except Exception as e:
|
192 |
raise RuntimeError(f"Error building/loading vector store: {str(e)}")
|
193 |
|
194 |
-
class GeminiLangChainLLM(LLM):
|
195 |
-
def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
|
196 |
-
"""Call the Gemini model using ChatGoogleGenerativeAI and ensure string output."""
|
197 |
-
try:
|
198 |
-
# Construct message list for the Gemini model
|
199 |
-
messages = [("human", prompt)]
|
200 |
-
ai_msg = gemini_llm.invoke(messages)
|
201 |
-
return ai_msg.content.strip() if ai_msg and ai_msg.content else str(prompt)
|
202 |
-
except Exception as e:
|
203 |
-
print(f"Error in GeminiLangChainLLM._call: {e}")
|
204 |
-
return str(prompt) # Fallback to returning the prompt
|
205 |
-
|
206 |
-
@property
|
207 |
-
def _llm_type(self) -> str:
|
208 |
-
return "custom_gemini"
|
209 |
-
|
210 |
def build_rag_chain(vectorstore: FAISS) -> RetrievalQA:
|
211 |
-
"""Build RAG chain
|
212 |
try:
|
213 |
retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
|
214 |
-
gemini_llm_instance = GeminiLangChainLLM()
|
215 |
chain = RetrievalQA.from_chain_type(
|
216 |
-
llm=
|
217 |
chain_type="stuff",
|
218 |
retriever=retriever,
|
219 |
return_source_documents=True
|
|
|
32 |
model="gemini-1.5-pro",
|
33 |
temperature=0,
|
34 |
max_retries=2,
|
35 |
+
# Additional parameters or safety_settings can be added here if needed
|
36 |
)
|
37 |
|
38 |
+
# Initialize ManagedAgent for web search using Gemini
|
39 |
pydantic_agent = ManagedAgent(
|
40 |
llm=ChatGoogleGenerativeAI(
|
41 |
model="gemini-1.5-pro",
|
|
|
191 |
except Exception as e:
|
192 |
raise RuntimeError(f"Error building/loading vector store: {str(e)}")
|
193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
def build_rag_chain(vectorstore: FAISS) -> RetrievalQA:
|
195 |
+
"""Build RAG chain using the Gemini LLM directly without a custom class."""
|
196 |
try:
|
197 |
retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
|
|
|
198 |
chain = RetrievalQA.from_chain_type(
|
199 |
+
llm=gemini_llm, # Directly use the ChatGoogleGenerativeAI instance
|
200 |
chain_type="stuff",
|
201 |
retriever=retriever,
|
202 |
return_source_documents=True
|