File size: 3,554 Bytes
139fefe 38ed905 139fefe 4b4bf28 139fefe 38ed905 139fefe 4b4bf28 139fefe 4b4bf28 139fefe 8edfef8 139fefe 4b4bf28 8edfef8 139fefe 4b4bf28 139fefe 8edfef8 4b4bf28 8edfef8 4b4bf28 8edfef8 4b4bf28 8edfef8 139fefe 38ed905 139fefe 4b4bf28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
from operator import itemgetter
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableBranch
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.base import format_document
from climateqa.engine.reformulation import make_reformulation_chain
from climateqa.engine.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
from climateqa.engine.utils import pass_values, flatten_dict
# Default per-document template: renders only the document's page_content.
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")


def _combine_documents(
    docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, sep="\n\n"
):
    r"""Merge documents into one numbered context string, one line per document.

    Each document is rendered with ``format_document(doc, document_prompt)``
    and prefixed with ``"Doc <n>: "``, where ``n`` counts from 1 in iteration
    order. Every ``"\n"`` in the resulting entry is then replaced by a single
    space. The entries are joined with ``sep``.

    Args:
        docs: Iterable of documents to render.
        document_prompt: Prompt template handed to ``format_document``.
        sep: Separator placed between rendered documents.

    Returns:
        The joined string (``""`` when ``docs`` is empty).
    """
    # Every chunk is labelled "Doc". The per-chunk-type label ("Doc" for text,
    # "Image" otherwise, derived from metadata["chunk_type"]) is disabled.
    label = "Doc"
    entries = (
        (f"{label} {number}: " + format_document(doc, document_prompt)).replace("\n", " ")
        for number, doc in enumerate(docs, start=1)
    )
    return sep.join(entries)
def get_text_docs(x):
    """Keep only the documents whose ``metadata["chunk_type"]`` is ``"text"``.

    Args:
        x: Iterable of documents, each exposing a ``metadata`` mapping.

    Returns:
        A new list of the matching documents, in their original order.

    Raises:
        KeyError: If a document's metadata has no ``"chunk_type"`` key.
    """
    wanted = "text"
    return list(filter(lambda document: document.metadata["chunk_type"] == wanted, x))
def get_image_docs(x):
    """Keep only the documents whose ``metadata["chunk_type"]`` is ``"image"``.

    Args:
        x: Iterable of documents, each exposing a ``metadata`` mapping.

    Returns:
        A new list of the matching documents, in their original order.

    Raises:
        KeyError: If a document's metadata has no ``"chunk_type"`` key.
    """
    wanted = "image"
    return list(filter(lambda document: document.metadata["chunk_type"] == wanted, x))
def make_rag_chain(retriever,llm):
    """Build the question-answering RAG chain.

    The returned runnable runs three stages in sequence:

    1. Reformulation: ``make_reformulation_chain(llm)`` runs on the input and
       its output is stored under ``"reformulation"``. The ``"audience"`` and
       ``"query"`` values go through ``pass_values``, and the resulting dict
       is passed to ``flatten_dict``.
       NOTE(review): ``"question"`` and ``"language"`` are presumably produced
       by the reformulation output once flattened; confirm against
       ``make_reformulation_chain`` / ``flatten_dict``.
    2. Retrieval: ``retriever`` is invoked with the ``"question"`` value and
       its result is stored under ``"docs"``. ``"question"``, ``"audience"``,
       ``"language"`` and ``"query"`` go through ``pass_values``.
    3. Answering: if ``"docs"`` has a non-zero length, the documents are
       rendered into a ``"context"`` string with ``_combine_documents``. That
       context is formatted into ``answer_prompt_template`` and sent to
       ``llm``. Otherwise ``answer_prompt_without_docs_template`` is used
       instead. Either way the output carries ``"answer"`` (the parsed string
       reply) plus the ``pass_values`` entries for ``"question"``,
       ``"audience"``, ``"language"``, ``"query"`` and ``"docs"``.

    Args:
        retriever: LangChain runnable invoked with the question (it is piped
            after ``itemgetter("question")``).
        llm: Language-model runnable used for reformulation and answering.

    Returns:
        The composed LangChain runnable.
    """
    # Prompts for the two answering branches.
    docs_prompt = ChatPromptTemplate.from_template(answer_prompt_template)
    no_docs_prompt = ChatPromptTemplate.from_template(answer_prompt_without_docs_template)

    # Stage 1 - reformulate the user query, carrying audience and query along.
    reformulate_step = (
        {
            "reformulation": make_reformulation_chain(llm),
            **pass_values(["audience", "query"]),
        }
        | RunnablePassthrough()
        | flatten_dict
    )

    # Stage 2 - retrieve documents for the question.
    retrieve_step = {
        "docs": itemgetter("question") | retriever,
        **pass_values(["question", "audience", "language", "query"]),
    } | RunnablePassthrough()

    # Stage 3 - answer, with or without retrieved context.
    llm_inputs = {
        "context": lambda state: _combine_documents(state["docs"]),
        **pass_values(["question", "audience", "language"]),
    }

    def with_answer(answer_runnable):
        # Pair the generated answer with the fields forwarded downstream.
        # A fresh key list is built on every call.
        return {
            "answer": answer_runnable,
            **pass_values(["question", "audience", "language", "query", "docs"]),
        }

    grounded_answer = with_answer(llm_inputs | docs_prompt | llm | StrOutputParser())
    ungrounded_answer = with_answer(no_docs_prompt | llm | StrOutputParser())

    def has_docs(state):
        return len(state["docs"]) > 0

    answer_step = RunnableBranch(
        (has_docs, grounded_answer),
        ungrounded_answer,
    )

    return reformulate_step | retrieve_step | answer_step
def make_illustration_chain(llm):
    """Build the chain that fills ``answer_prompt_images_template`` and queries ``llm``.

    The input mapping must contain ``"docs"``. Only documents whose
    ``metadata["chunk_type"]`` is ``"image"`` are rendered, via
    ``_combine_documents``, into the prompt's ``"images"`` variable. The
    ``"question"``, ``"audience"``, ``"language"`` and ``"answer"`` values go
    through ``pass_values``. NOTE(review): they are presumably forwarded
    unchanged from the input, so confirm against ``pass_values``. The formatted
    prompt is sent to ``llm`` and the reply is parsed with
    ``StrOutputParser``.

    Args:
        llm: Language-model runnable used to generate the reply.

    Returns:
        A runnable whose output is the model's reply as a string.
    """
    prompt_with_images = ChatPromptTemplate.from_template(answer_prompt_images_template)
    input_description_images = {
        # Render only the image chunks; text chunks are ignored here.
        "images": lambda x: _combine_documents(get_image_docs(x["docs"])),
        **pass_values(["question", "audience", "language", "answer"]),
    }
    illustration_chain = (
        input_description_images | prompt_with_images | llm | StrOutputParser()
    )
    return illustration_chain