File size: 3,110 Bytes
139fefe
 
38ed905
 
 
 
139fefe
088e816
 
c3b815e
 
38ed905
139fefe
 
 
 
 
4b4bf28
 
 
 
 
 
caf1faa
 
 
 
 
4b4bf28
 
 
139fefe
 
 
4b4bf28
 
 
 
 
 
088e816
139fefe
088e816
 
 
 
 
 
 
4b4bf28
088e816
 
 
 
8edfef8
 
088e816
139fefe
088e816
 
 
 
139fefe
088e816
49acaf1
 
088e816
6b43c86
088e816
4b4bf28
088e816
caf1faa
 
 
 
c3b815e
caf1faa
c3b815e
 
 
 
 
caf1faa
c3b815e
 
caf1faa
c3b815e
caf1faa
 
4b4bf28
 
 
 
c3b815e
481f3b1
c3b815e
481f3b1
c3b815e
 
 
 
481f3b1
c3b815e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from operator import itemgetter

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.base import format_document

from climateqa.engine.chains.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
from climateqa.engine.chains.prompts import papers_prompt_template
from ..utils import rename_chain, pass_values


# Fallback per-document template used by _combine_documents: renders only the
# document's ``page_content`` field.
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")

def _combine_documents(
    docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, sep="\n\n"
):
    """Render a sequence of documents into one numbered context string.

    Each entry of ``docs`` is either a plain ``str`` (used verbatim) or an
    object rendered through ``format_document`` with ``document_prompt``.
    Every rendered entry is prefixed with ``"Doc <n>: "`` (numbering starts
    at 1), has all newlines replaced by single spaces, and the entries are
    joined with ``sep``.

    Args:
        docs: Iterable of strings and/or documents accepted by
            ``format_document``.
        document_prompt: Template passed to ``format_document`` for
            non-string entries.
        sep: Separator placed between rendered entries.

    Returns:
        The joined context string (``""`` when ``docs`` is empty).
    """

    def _render(item):
        # Raw strings pass straight through; anything else is templated.
        if isinstance(item, str):
            return item
        return format_document(item, document_prompt)

    # Every entry is labelled "Doc"; no text/image distinction is made here.
    return sep.join(
        ("Doc " + str(position) + ": " + _render(item)).replace("\n", " ")
        for position, item in enumerate(docs, start=1)
    )


def get_text_docs(x):
    """Return, as a list and in input order, the docs whose chunk type is text.

    Each element of ``x`` must expose ``metadata["chunk_type"]``; a missing key
    raises ``KeyError``.
    """
    return list(filter(lambda doc: doc.metadata["chunk_type"] == "text", x))

def get_image_docs(x):
    """Return, as a list and in input order, the docs whose chunk type is image.

    Each element of ``x`` must expose ``metadata["chunk_type"]``; a missing key
    raises ``KeyError``.
    """
    return list(filter(lambda doc: doc.metadata["chunk_type"] == "image", x))

def make_rag_chain(llm):
    """Build the document-grounded answer chain.

    The chain reads ``"documents"``, ``"query"``, ``"language"`` and
    ``"audience"`` from its input mapping. The documents are flattened into a
    single ``context`` string via ``_combine_documents``; the other three keys
    are forwarded unchanged. The result feeds ``answer_prompt_template``, then
    ``llm``, and the model output is parsed to a plain string.

    Args:
        llm: The language model runnable used to generate the answer.

    Returns:
        A runnable producing the answer text.
    """
    prompt = ChatPromptTemplate.from_template(answer_prompt_template)
    prompt_inputs = {
        "context": lambda state: _combine_documents(state["documents"]),
        # Keys copied verbatim from the input into the prompt variables.
        **{key: itemgetter(key) for key in ("query", "language", "audience")},
    }
    return prompt_inputs | prompt | llm | StrOutputParser()

def make_rag_chain_without_docs(llm):
    """Build the answer chain used when no documents are supplied.

    The input is handed directly to ``answer_prompt_without_docs_template``,
    then to ``llm``, and the output is parsed to a plain string. The input
    presumably must provide every variable that template references; confirm
    against the template.

    Args:
        llm: The language model runnable used to generate the answer.

    Returns:
        A runnable producing the answer text.
    """
    return (
        ChatPromptTemplate.from_template(answer_prompt_without_docs_template)
        | llm
        | StrOutputParser()
    )


def make_rag_node(llm, with_docs=True):
    """Create an async node function that writes the final answer.

    Args:
        llm: The language model runnable passed to the chain builder.
        with_docs: When truthy, use the document-grounded chain
            (``make_rag_chain``); otherwise use ``make_rag_chain_without_docs``.

    Returns:
        An ``async`` callable ``answer_rag(state, config)``. It prints progress
        markers to stdout, awaits ``ainvoke(state, config)`` on the chosen
        chain, and returns ``{"answer": <chain output>}``.
    """
    build_chain = make_rag_chain if with_docs else make_rag_chain_without_docs
    rag_chain = build_chain(llm)

    async def answer_rag(state, config):
        print("---- Answer RAG ----")
        generated = await rag_chain.ainvoke(state, config)
        print(f"\n\nAnswer:\n{generated}")
        update = {"answer": generated}
        return update

    return answer_rag




def make_rag_papers_chain(llm):
    """Build the chain that answers a question from a set of papers.

    The input's ``"docs"`` are flattened into a ``context`` string via
    ``_combine_documents``. ``"question"`` and ``"language"`` are handled by
    ``pass_values`` (presumably forwarded unchanged; confirm in ``..utils``).
    The result feeds ``papers_prompt_template``, then ``llm``, and the output is
    parsed to a string. Finally the chain is wrapped by
    ``rename_chain(..., "answer")``, presumably to expose the output under an
    ``"answer"`` key; verify in ``..utils``.

    Args:
        llm: The language model runnable used to generate the answer.

    Returns:
        The renamed runnable returned by ``rename_chain``.
    """
    prompt = ChatPromptTemplate.from_template(papers_prompt_template)
    forwarded_values = pass_values(["question", "language"])
    prompt_inputs = {
        "context": lambda state: _combine_documents(state["docs"]),
        **forwarded_values,
    }

    answer_chain = prompt_inputs | prompt | llm | StrOutputParser()
    return rename_chain(answer_chain, "answer")






def make_illustration_chain(llm):
    """Build the chain that writes text about the image documents.

    From the input's ``"docs"``, only image-typed documents (via
    ``get_image_docs``) are flattened into an ``images`` string with
    ``_combine_documents``. ``"question"``, ``"audience"``, ``"language"`` and
    ``"answer"`` are handled by ``pass_values`` (presumably forwarded
    unchanged; confirm in ``..utils``). The result feeds
    ``answer_prompt_images_template``, then ``llm``, and the output is parsed
    to a plain string.

    Args:
        llm: The language model runnable used for generation.

    Returns:
        A runnable producing the generated text.
    """
    prompt_with_images = ChatPromptTemplate.from_template(answer_prompt_images_template)

    def _describe_images(state):
        # Only image chunks contribute to the "images" prompt variable.
        return _combine_documents(get_image_docs(state["docs"]))

    input_description_images = {
        "images": _describe_images,
        **pass_values(["question", "audience", "language", "answer"]),
    }

    return input_description_images | prompt_with_images | llm | StrOutputParser()