Doux Thibault committed on
Commit
8ee218a
1 Parent(s): 42ef85f

augment rag docs

Browse files
Files changed (1) hide show
  1. Modules/rag.py +22 -33
Modules/rag.py CHANGED
@@ -7,26 +7,22 @@ mistral_api_key = os.getenv("MISTRAL_API_KEY")
7
 
8
  from langchain_community.document_loaders import PyPDFLoader
9
  from langchain.text_splitter import RecursiveCharacterTextSplitter
10
- from langchain_community.document_loaders import WebBaseLoader
11
- from langchain_community.vectorstores import Chroma, FAISS
12
- from langchain.chains.combine_documents import create_stuff_documents_chain
13
  from langchain_mistralai import MistralAIEmbeddings
14
  from langchain import hub
15
- from langchain.chains import (
16
- create_history_aware_retriever,
17
- create_retrieval_chain,
18
- )
19
  from typing import Literal
20
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
21
- from langchain_core.pydantic_v1 import BaseModel, Field
22
  from langchain_mistralai import ChatMistralAI
23
- from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
24
- from langchain_community.tools import DuckDuckGoSearchRun
25
  from pathlib import Path
26
 
27
- def load_chunk_persist_pdf() -> Chroma:
 
 
 
28
 
29
- pdf_folder_path = os.path.join(os.getcwd(),Path("data/pdf/"))
30
  documents = []
31
  for file in os.listdir(pdf_folder_path):
32
  if file.endswith('.pdf'):
@@ -44,30 +40,23 @@ def load_chunk_persist_pdf() -> Chroma:
44
  vectorstore.persist()
45
  return vectorstore
46
 
47
- vectorstore = load_chunk_persist_pdf()
48
- retriever = vectorstore.as_retriever()
49
- prompt = hub.pull("rlm/rag-prompt")
50
-
51
-
52
- # Data model
53
- class RouteQuery(BaseModel):
54
- """Route a user query to the most relevant datasource."""
55
 
56
- datasource: Literal["vectorstore", "websearch"] = Field(
57
- ...,
58
- description="Given a user question choose to route it to web search or a vectorstore.",
59
- )
60
-
61
- # LLM with function call
62
  llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key=mistral_api_key, temperature=0)
63
 
64
-
65
  prompt = ChatPromptTemplate.from_template(
66
  """
67
  You are a professional AI coach specialized in fitness, bodybuilding and nutrition.
68
  You must adapt to the user : if he is a beginner, use simple words. You are gentle and motivative.
69
  Use the following pieces of retrieved context to answer the question.
70
- If you don't know the answer, just say that you don't know, and to refer to a nutritionist or a doctor.
71
  Use three sentences maximum and keep the answer concise.
72
 
73
  Question: {question}
@@ -77,12 +66,11 @@ prompt = ChatPromptTemplate.from_template(
77
  Answer:
78
  """,
79
  )
80
- from langchain_core.output_parsers import StrOutputParser
81
- from langchain_core.runnables import RunnablePassthrough
82
 
83
  def format_docs(docs):
84
  return "\n\n".join(doc.page_content for doc in docs)
85
 
 
86
 
87
  rag_chain = (
88
  {"context": retriever | format_docs, "question": RunnablePassthrough()}
@@ -92,6 +80,7 @@ rag_chain = (
92
  )
93
 
94
 
95
- # print(rag_chain.invoke("Build a fitness program for me. Be precise in terms of exercises"))
96
 
97
- # print(rag_chain.invoke("I am a 45 years old woman and I have to loose weight for the summer. Provide me with a fitness program"))
 
 
 
7
 
8
  from langchain_community.document_loaders import PyPDFLoader
9
  from langchain.text_splitter import RecursiveCharacterTextSplitter
10
+ from langchain_community.vectorstores import Chroma
 
 
11
  from langchain_mistralai import MistralAIEmbeddings
12
  from langchain import hub
13
+ from langchain_core.output_parsers import StrOutputParser
14
+ from langchain_core.runnables import RunnablePassthrough
 
 
15
  from typing import Literal
16
+ from langchain_core.prompts import ChatPromptTemplate
 
17
  from langchain_mistralai import ChatMistralAI
 
 
18
  from pathlib import Path
19
 
20
+ from langchain.retrievers import (
21
+ MergerRetriever,
22
+ )
23
+ def load_chunk_persist_pdf(task) -> Chroma:
24
 
25
+ pdf_folder_path = os.path.join(os.getcwd(),Path(f"data/pdf/{task}"))
26
  documents = []
27
  for file in os.listdir(pdf_folder_path):
28
  if file.endswith('.pdf'):
 
40
  vectorstore.persist()
41
  return vectorstore
42
 
43
+ zero2hero_vectorstore = load_chunk_persist_pdf("zero2hero")
44
+ bodyweight_vectorstore = load_chunk_persist_pdf("bodyweight")
45
+ nutrition_vectorstore = load_chunk_persist_pdf("nutrition")
46
+ workout_vectorstore = load_chunk_persist_pdf("workout")
47
+ zero2hero_retriever = zero2hero_vectorstore.as_retriever()
48
+ nutrition_retriever = nutrition_vectorstore.as_retriever()
49
+ bodyweight_retriever = bodyweight_vectorstore.as_retriever()
50
+ workout_retriever = workout_vectorstore.as_retriever()
51
 
 
 
 
 
 
 
52
  llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key=mistral_api_key, temperature=0)
53
 
 
54
  prompt = ChatPromptTemplate.from_template(
55
  """
56
  You are a professional AI coach specialized in fitness, bodybuilding and nutrition.
57
  You must adapt to the user : if he is a beginner, use simple words. You are gentle and motivative.
58
  Use the following pieces of retrieved context to answer the question.
59
+ If you don't know the answer, use your common knowledge.
60
  Use three sentences maximum and keep the answer concise.
61
 
62
  Question: {question}
 
66
  Answer:
67
  """,
68
  )
 
 
69
 
70
  def format_docs(docs):
71
  return "\n\n".join(doc.page_content for doc in docs)
72
 
73
+ retriever = MergerRetriever(retrievers=[zero2hero_retriever, bodyweight_retriever, nutrition_retriever, workout_retriever])
74
 
75
  rag_chain = (
76
  {"context": retriever | format_docs, "question": RunnablePassthrough()}
 
80
  )
81
 
82
 
 
83
 
84
+ print(rag_chain.invoke("What supplement could i buy to improve my sleep?"))
85
+
86
+ # print(rag_chain.invoke("I am a 45 years old woman and I have to loose weight for the summer. Provide me with a fitness program, and a nutrition program"))