File size: 3,110 Bytes
139fefe 38ed905 139fefe 088e816 c3b815e 38ed905 139fefe 4b4bf28 caf1faa 4b4bf28 139fefe 4b4bf28 088e816 139fefe 088e816 4b4bf28 088e816 8edfef8 088e816 139fefe 088e816 139fefe 088e816 49acaf1 088e816 6b43c86 088e816 4b4bf28 088e816 caf1faa c3b815e caf1faa c3b815e caf1faa c3b815e caf1faa c3b815e caf1faa 4b4bf28 c3b815e 481f3b1 c3b815e 481f3b1 c3b815e 481f3b1 c3b815e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
from operator import itemgetter
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.base import format_document
from climateqa.engine.chains.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
from climateqa.engine.chains.prompts import papers_prompt_template
from ..utils import rename_chain, pass_values
# Default template used by _combine_documents to render a non-string document:
# the output is just the document's page_content.
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")
def _combine_documents(
    docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, sep="\n\n"
):
    """Render documents into a single numbered context string.

    Each entry is prefixed with ``"Doc <n>: "`` (``n`` is 1-based). All
    newlines in an entry are replaced by spaces, so each document occupies
    one line of the result.

    Args:
        docs: Iterable of items to combine. ``str`` items are used verbatim;
            any other item is rendered with
            ``format_document(doc, document_prompt)``.
        document_prompt: Template used to render non-string items.
        sep: Separator inserted between rendered entries.

    Returns:
        The rendered entries joined with ``sep`` (``""`` when ``docs`` is empty).
    """
    # Every entry gets the same label; there is no text/image distinction,
    # so the label is computed once instead of on every iteration.
    chunk_type = "Doc"

    def _render(doc):
        # Strings are taken as already formatted; anything else goes
        # through the document prompt.
        if isinstance(doc, str):
            return doc
        return format_document(doc, document_prompt)

    doc_strings = [
        (f"{chunk_type} {i}: " + _render(doc)).replace("\n", " ")
        for i, doc in enumerate(docs, start=1)
    ]
    return sep.join(doc_strings)
def get_text_docs(x):
    """Return the documents in ``x`` whose ``metadata["chunk_type"]`` is ``"text"``.

    Input order is preserved. A document whose metadata has no
    ``chunk_type`` key is skipped instead of raising ``KeyError``.
    """
    # NOTE(review): assumes ``doc.metadata`` is a dict (as for LangChain
    # Documents) — confirm for any other document type passed here.
    return [doc for doc in x if doc.metadata.get("chunk_type") == "text"]
def get_image_docs(x):
    """Return the documents in ``x`` whose ``metadata["chunk_type"]`` is ``"image"``.

    Input order is preserved. A document whose metadata has no
    ``chunk_type`` key is skipped instead of raising ``KeyError``.
    """
    # NOTE(review): assumes ``doc.metadata`` is a dict (as for LangChain
    # Documents) — confirm for any other document type passed here.
    return [doc for doc in x if doc.metadata.get("chunk_type") == "image"]
def make_rag_chain(llm):
    """Build the document-grounded answer chain.

    The returned runnable reads a mapping with ``documents``, ``query``,
    ``language`` and ``audience`` keys. It renders the documents into a
    ``context`` string with ``_combine_documents`` and passes that string,
    together with the three pass-through fields, into
    ``answer_prompt_template``. The prompt is then sent to ``llm`` and the
    reply is parsed to a plain string.
    """
    prompt = ChatPromptTemplate.from_template(answer_prompt_template)

    # "context" is derived from the documents; the remaining fields are
    # copied straight from the input mapping.
    prompt_inputs = {"context": lambda x: _combine_documents(x["documents"])}
    prompt_inputs.update(
        {field: itemgetter(field) for field in ("query", "language", "audience")}
    )

    return prompt_inputs | prompt | llm | StrOutputParser()
def make_rag_chain_without_docs(llm):
    """Build the answer chain used when no documents are supplied.

    The input mapping is formatted directly into
    ``answer_prompt_without_docs_template``. The prompt is sent to ``llm``
    and the reply is parsed to a plain string.
    """
    return (
        ChatPromptTemplate.from_template(answer_prompt_without_docs_template)
        | llm
        | StrOutputParser()
    )
def make_rag_node(llm, with_docs=True):
    """Create an async node that answers from the current state.

    Args:
        llm: Language model used by the underlying chain.
        with_docs: When truthy, the chain from ``make_rag_chain`` is used;
            otherwise the chain from ``make_rag_chain_without_docs``.

    Returns:
        An ``async (state, config)`` callable. It invokes the chosen chain
        with ``state`` and ``config``, prints the answer, and returns
        ``{"answer": answer}``.
    """
    build_chain = make_rag_chain if with_docs else make_rag_chain_without_docs
    rag_chain = build_chain(llm)

    async def answer_rag(state, config):
        print("---- Answer RAG ----")
        answer = await rag_chain.ainvoke(state, config)
        print(f"\n\nAnswer:\n{answer}")
        return {"answer": answer}

    return answer_rag
def make_rag_papers_chain(llm):
    """Build the answer chain that uses ``papers_prompt_template``.

    The input mapping's ``docs`` are rendered into ``context`` with
    ``_combine_documents``. The ``question`` and ``language`` entries come
    from ``pass_values``, presumably as pass-through getters (see ``..utils``).
    The parsed string output is wrapped with ``rename_chain(..., "answer")``.
    """
    prompt = ChatPromptTemplate.from_template(papers_prompt_template)
    forwarded = pass_values(["question", "language"])
    input_documents = {
        "context": lambda x: _combine_documents(x["docs"]),
        **forwarded,
    }
    return rename_chain(input_documents | prompt | llm | StrOutputParser(), "answer")
def make_illustration_chain(llm):
    """Build the chain that uses ``answer_prompt_images_template``.

    Only the image documents in the input's ``docs`` (as selected by
    ``get_image_docs``) are rendered into the ``images`` field with
    ``_combine_documents``. The ``question``, ``audience``, ``language`` and
    ``answer`` entries come from ``pass_values``. The prompt is sent to
    ``llm`` and the reply is parsed to a plain string.
    """
    forwarded = pass_values(["question", "audience", "language", "answer"])
    input_description_images = {
        "images": lambda x: _combine_documents(get_image_docs(x["docs"])),
        **forwarded,
    }
    prompt_with_images = ChatPromptTemplate.from_template(answer_prompt_images_template)
    return input_description_images | prompt_with_images | llm | StrOutputParser()
|