File size: 4,375 Bytes
139fefe 38ed905 139fefe 4b4bf28 caf1faa 37b1e7a caf1faa 38ed905 139fefe 4b4bf28 caf1faa 4b4bf28 139fefe 4b4bf28 139fefe 8edfef8 139fefe 37b1e7a 139fefe caf1faa 139fefe 37b1e7a 139fefe caf1faa 139fefe 4b4bf28 37b1e7a 8edfef8 37b1e7a caf1faa 139fefe 8edfef8 37b1e7a caf1faa 8edfef8 4b4bf28 8edfef8 4b4bf28 8edfef8 139fefe caf1faa 139fefe 4b4bf28 caf1faa 4b4bf28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
from operator import itemgetter
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableBranch
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.base import format_document
from climateqa.engine.reformulation import make_reformulation_chain
from climateqa.engine.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
from climateqa.engine.prompts import papers_prompt_template
from climateqa.engine.utils import pass_values, flatten_dict,prepare_chain,rename_chain
from climateqa.engine.keywords import make_keywords_chain
# Default per-document prompt: renders a document as just its ``page_content``.
# Used as the default ``document_prompt`` of ``_combine_documents``.
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
def _combine_documents(
    docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, sep="\n\n"
):
    """Render ``docs`` as one numbered context string.

    Each entry becomes ``"Doc <n>: <text>"`` (numbering starts at 1) with its
    newlines replaced by spaces; the entries are then joined with ``sep``.

    Args:
        docs: Iterable of plain strings and/or document objects. Strings are
            used verbatim; anything else goes through ``format_document``.
        document_prompt: Prompt handed to ``format_document`` for non-string
            entries.
        sep: Separator placed between the rendered entries.

    Returns:
        The joined string (empty when ``docs`` is empty).
    """
    # NOTE: every entry is labelled "Doc"; the label does not currently depend
    # on the document's ``chunk_type`` metadata.
    label = "Doc"

    def _render(position, doc):
        # Strings are already text; other entries are formatted via the prompt.
        if isinstance(doc, str):
            text = doc
        else:
            text = format_document(doc, document_prompt)
        numbered = f"{label} {position}: " + text
        # Flatten to a single line so that ``sep`` alone delimits entries.
        return numbered.replace("\n", " ")

    return sep.join(_render(pos, doc) for pos, doc in enumerate(docs, start=1))
def get_text_docs(x):
    """Return, in order, the documents in ``x`` whose ``metadata["chunk_type"]`` is ``"text"``."""

    def _is_text(doc):
        return doc.metadata["chunk_type"] == "text"

    return list(filter(_is_text, x))
def get_image_docs(x):
    """Return, in order, the documents in ``x`` whose ``metadata["chunk_type"]`` is ``"image"``."""

    def _is_image(doc):
        return doc.metadata["chunk_type"] == "image"

    return list(filter(_is_image, x))
def make_rag_chain(retriever,llm):
    """Build the full RAG chain: reformulation, keywords, retrieval, answer.

    Args:
        retriever: Runnable piped after ``itemgetter("question")`` to fetch
            the documents stored under ``"docs"``.
        llm: Language model used by the reformulation, keyword and answer
            steps.

    Returns:
        The composed runnable
        ``reformulation | keywords | find_documents | answer``.
    """
    # Prompts: one grounded in retrieved documents, one for the no-document fallback.
    docs_prompt = ChatPromptTemplate.from_template(answer_prompt_template)
    no_docs_prompt = ChatPromptTemplate.from_template(answer_prompt_without_docs_template)

    # ------- STEP 0 - Reformulation of the incoming query
    reformulation_step = prepare_chain(make_reformulation_chain(llm), "reformulation")

    # ------- STEP 0b - Keywords extracted from the "question" value
    keywords_chain = make_keywords_chain(llm)
    keywords_step = prepare_chain(
        {"keywords": itemgetter("question") | keywords_chain},
        "keywords",
    )

    # ------- STEP 1 - Document retrieval, keyed under "docs"
    retrieval_step = prepare_chain(
        {"docs": itemgetter("question") | retriever} | RunnablePassthrough(),
        "find_documents",
    )

    # ------- STEP 2 - Inputs for the answer prompt
    # "context" is the retrieved documents rendered as one numbered string;
    # the remaining keys are forwarded via ``pass_values``.
    llm_inputs = {
        "context": lambda x: _combine_documents(x["docs"]),
        **pass_values(["question", "audience", "language", "keywords"]),
    }

    # ------- STEP 3 - Answer generation
    answering_llm = rename_chain(llm, "answer")

    grounded_answer = {
        "answer": llm_inputs | docs_prompt | answering_llm | StrOutputParser(),
        **pass_values(["question", "audience", "language", "query", "docs", "keywords"]),
    }

    ungrounded_answer = {
        "answer": no_docs_prompt | answering_llm | StrOutputParser(),
        **pass_values(["question", "audience", "language", "query", "docs", "keywords"]),
    }

    # Use the document-grounded prompt only when retrieval returned something.
    answer_step = RunnableBranch(
        (lambda x: len(x["docs"]) > 0, grounded_answer),
        ungrounded_answer,
    )

    # ------- FINAL CHAIN
    return reformulation_step | keywords_step | retrieval_step | answer_step
def make_rag_papers_chain(llm):
    """Build the chain that answers from a set of papers.

    The chain input must provide ``"docs"`` (rendered into ``"context"`` by
    ``_combine_documents``) plus ``"question"`` and ``"language"``, which are
    forwarded via ``pass_values``.

    Args:
        llm: Language model that receives the formatted papers prompt.

    Returns:
        The runnable ``inputs | prompt | llm | StrOutputParser()`` passed
        through ``rename_chain(..., "answer")``.
    """
    papers_prompt = ChatPromptTemplate.from_template(papers_prompt_template)

    llm_inputs = {
        "context": lambda x: _combine_documents(x["docs"]),
        **pass_values(["question", "language"]),
    }

    papers_chain = llm_inputs | papers_prompt | llm | StrOutputParser()
    return rename_chain(papers_chain, "answer")
def make_illustration_chain(llm):
    """Build the chain that produces text from the retrieved image documents.

    The chain input must provide ``"docs"``; only entries selected by
    ``get_image_docs`` are rendered (through ``_combine_documents``) into the
    ``"images"`` prompt variable. ``"question"``, ``"audience"``,
    ``"language"`` and ``"answer"`` are forwarded via ``pass_values``.

    Args:
        llm: Language model that receives the formatted images prompt.

    Returns:
        The runnable ``inputs | prompt | llm | StrOutputParser()``.
    """
    prompt_with_images = ChatPromptTemplate.from_template(answer_prompt_images_template)

    input_description_images = {
        "images": lambda x: _combine_documents(get_image_docs(x["docs"])),
        **pass_values(["question", "audience", "language", "answer"]),
    }

    illustration_chain = input_description_images | prompt_with_images | llm | StrOutputParser()
    # The original return line ended with a dangling ``|`` operator (a syntax
    # error); the chain is returned as built.
    return illustration_chain