File size: 3,096 Bytes
139fefe 38ed905 139fefe 088e816 38ed905 139fefe 4b4bf28 caf1faa 4b4bf28 139fefe 4b4bf28 088e816 139fefe 088e816 4b4bf28 088e816 8edfef8 088e816 139fefe 088e816 139fefe 088e816 49acaf1 088e816 6b43c86 088e816 4b4bf28 088e816 caf1faa 088e816 caf1faa 088e816 caf1faa 088e816 caf1faa 088e816 caf1faa 4b4bf28 088e816 481f3b1 088e816 481f3b1 088e816 481f3b1 088e816 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
from operator import itemgetter
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.base import format_document
from climateqa.engine.chains.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
from climateqa.engine.chains.prompts import papers_prompt_template
# Default per-document prompt: renders a document using only its ``page_content``.
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")
def _combine_documents(
    docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, sep="\n\n"
):
    """Flatten a sequence of documents into one numbered context string.

    Args:
        docs: Iterable whose items are either plain strings (used verbatim)
            or document objects rendered through ``format_document``.
        document_prompt: Prompt passed to ``format_document`` for non-string
            items.
        sep: Separator placed between the rendered entries.

    Returns:
        A single string where each entry is prefixed with ``"Doc <n>: "``
        (numbering starts at 1) and has its newlines replaced by spaces.
    """

    def _render(position, item):
        # Strings are passed through untouched; anything else goes through
        # the document prompt.
        if isinstance(item, str):
            text = item
        else:
            text = format_document(item, document_prompt)
        # The label is always "Doc"; the text/image distinction based on
        # metadata["chunk_type"] is currently disabled.
        labelled = f"Doc {position}: " + text
        # Collapse each entry onto one line so `sep` is the only delimiter
        # between entries.
        return labelled.replace("\n", " ")

    return sep.join(
        _render(position, item) for position, item in enumerate(docs, start=1)
    )
def get_text_docs(x):
    """Return the items of ``x`` whose ``metadata["chunk_type"]`` is ``"text"``.

    Order is preserved. A ``KeyError`` propagates if an item's metadata has
    no ``"chunk_type"`` entry.
    """
    selected = []
    for candidate in x:
        if candidate.metadata["chunk_type"] == "text":
            selected.append(candidate)
    return selected
def get_image_docs(x):
    """Return the items of ``x`` whose ``metadata["chunk_type"]`` is ``"image"``.

    Order is preserved. A ``KeyError`` propagates if an item's metadata has
    no ``"chunk_type"`` entry.
    """
    return list(
        filter(lambda candidate: candidate.metadata["chunk_type"] == "image", x)
    )
def make_rag_chain(llm):
    """Build the document-grounded answering chain.

    The chain takes a mapping with the keys ``"documents"``, ``"query"``,
    ``"language"`` and ``"audience"``, renders ``answer_prompt_template``
    with them (documents are first flattened by ``_combine_documents``),
    sends the prompt to ``llm`` and parses the output to a string.

    Args:
        llm: The language model runnable placed after the prompt.

    Returns:
        The composed runnable chain.
    """
    # Mapping from prompt variables to the way each one is extracted from
    # the chain input.
    prompt_inputs = {
        "context": lambda x: _combine_documents(x["documents"]),
        "query": itemgetter("query"),
        "language": itemgetter("language"),
        "audience": itemgetter("audience"),
    }
    answer_prompt = ChatPromptTemplate.from_template(answer_prompt_template)
    return prompt_inputs | answer_prompt | llm | StrOutputParser()
def make_rag_chain_without_docs(llm):
    """Build the answering chain used when no documents are supplied.

    The chain input is fed straight into ``answer_prompt_without_docs_template``
    (no key extraction happens here), then passed to ``llm`` and parsed to a
    string.

    Args:
        llm: The language model runnable placed after the prompt.

    Returns:
        The composed runnable chain.
    """
    no_docs_prompt = ChatPromptTemplate.from_template(
        answer_prompt_without_docs_template
    )
    return no_docs_prompt | llm | StrOutputParser()
def make_rag_node(llm, with_docs=True):
    """Create an async node function that answers from the given state.

    Args:
        llm: The language model handed to the underlying chain factory.
        with_docs: When truthy the chain from ``make_rag_chain`` is used,
            otherwise the one from ``make_rag_chain_without_docs``.

    Returns:
        A coroutine function ``answer_rag(state, config)`` that invokes the
        chain, prints progress and the answer to stdout, and returns
        ``{"answer": <answer>}``.
    """
    # Choose the factory first, then build the chain once; the closure below
    # reuses it for every call.
    chain_factory = make_rag_chain if with_docs else make_rag_chain_without_docs
    rag_chain = chain_factory(llm)

    async def answer_rag(state, config):
        print("---- Answer RAG ----")
        answer = await rag_chain.ainvoke(state, config)
        print(f"\n\nAnswer:\n{answer}")
        return {"answer": answer}

    return answer_rag
# def make_rag_papers_chain(llm):
# prompt = ChatPromptTemplate.from_template(papers_prompt_template)
# input_documents = {
# "context":lambda x : _combine_documents(x["docs"]),
# **pass_values(["question","language"])
# }
# chain = input_documents | prompt | llm | StrOutputParser()
# chain = rename_chain(chain,"answer")
# return chain
# def make_illustration_chain(llm):
# prompt_with_images = ChatPromptTemplate.from_template(answer_prompt_images_template)
# input_description_images = {
# "images":lambda x : _combine_documents(get_image_docs(x["docs"])),
# **pass_values(["question","audience","language","answer"]),
# }
# illustration_chain = input_description_images | prompt_with_images | llm | StrOutputParser()
# return illustration_chain