File size: 3,554 Bytes
139fefe
 
38ed905
 
 
 
 
139fefe
 
4b4bf28
139fefe
 
38ed905
139fefe
 
 
 
 
4b4bf28
 
 
 
 
 
 
 
 
 
139fefe
 
 
4b4bf28
 
 
 
 
 
 
139fefe
 
 
 
 
8edfef8
139fefe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b4bf28
 
8edfef8
 
 
139fefe
4b4bf28
139fefe
 
8edfef8
 
4b4bf28
8edfef8
 
4b4bf28
 
 
 
 
 
 
8edfef8
4b4bf28
8edfef8
 
 
 
139fefe
 
38ed905
139fefe
 
 
4b4bf28
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from operator import itemgetter

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableBranch
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.base import format_document

from climateqa.engine.reformulation import make_reformulation_chain
from climateqa.engine.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
from climateqa.engine.utils import pass_values, flatten_dict


# Default per-document template: renders only the document's ``page_content``.
# Used as the default ``document_prompt`` of ``_combine_documents`` below.
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")

def _combine_documents(
    docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, sep="\n\n"
):
    """Render documents as numbered one-line entries joined by ``sep``.

    Args:
        docs: Iterable of documents accepted by ``format_document``.
        document_prompt: Prompt used to render each individual document.
        sep: Separator inserted between the rendered entries.

    Returns:
        A single string of the form ``"Doc 1: ...<sep>Doc 2: ..."``; an empty
        string when ``docs`` yields nothing.
    """
    # Every entry is labelled "Doc", whatever the chunk type of the document.
    label = "Doc"

    # Newlines inside an entry are collapsed to spaces so that ``sep`` is the
    # only thing separating one entry from the next.
    return sep.join(
        f"{label} {position}: {format_document(doc, document_prompt)}".replace("\n", " ")
        for position, doc in enumerate(docs, start=1)
    )


def get_text_docs(x):
    """Return, in order, the documents of ``x`` whose chunk type is ``"text"``.

    Each item must expose a ``metadata`` mapping containing ``"chunk_type"``.
    """
    text_docs = []
    for candidate in x:
        if candidate.metadata["chunk_type"] == "text":
            text_docs.append(candidate)
    return text_docs

def get_image_docs(x):
    """Return, in order, the documents of ``x`` whose chunk type is ``"image"``.

    Each item must expose a ``metadata`` mapping containing ``"chunk_type"``.
    """
    return list(
        filter(lambda candidate: candidate.metadata["chunk_type"] == "image", x)
    )


def make_rag_chain(retriever,llm):
    """Assemble the question-answering chain: reformulate, retrieve, answer.

    Args:
        retriever: Runnable piped after ``itemgetter("question")``; whatever it
            returns is stored under ``"docs"``.
        llm: Model used by both answer prompts and handed to
            ``make_reformulation_chain``.

    Returns:
        The composed runnable ``reformulation | retrieval | answer``. Its final
        mapping holds ``"answer"`` (the string-parsed LLM output) next to the
        entries built by ``pass_values`` for ``question``, ``audience``,
        ``language``, ``query`` and ``docs``.

    NOTE(review): the chain input is presumably a mapping with at least
    ``"audience"`` and ``"query"``, and the flattened reformulation output
    presumably supplies ``"question"`` and ``"language"`` -- confirm against
    ``make_reformulation_chain``, ``pass_values`` and ``flatten_dict``.
    """
    # Two prompts: one grounded on retrieved documents, one used without any.
    docs_prompt = ChatPromptTemplate.from_template(answer_prompt_template)
    no_docs_prompt = ChatPromptTemplate.from_template(answer_prompt_without_docs_template)

    # ------- STAGE 0 - Reformulation
    # The reformulation result is nested under "reformulation" and then handed
    # to ``flatten_dict``. Piping the dict into a runnable is what turns the
    # plain mapping into a chain step.
    reformulation_step = (
        {
            "reformulation": make_reformulation_chain(llm),
            **pass_values(["audience","query"]),
        }
        | RunnablePassthrough()
        | flatten_dict
    )

    # ------- STAGE 1 - Retrieval
    # Feed the "question" entry to the retriever and keep the result as "docs".
    retrieval_step = {
        "docs": itemgetter("question") | retriever,
        **pass_values(["question","audience","language","query"]),
    } | RunnablePassthrough()

    # ------- STAGE 2 - LLM inputs
    # The retrieved documents are rendered into one "context" string.
    llm_inputs = {
        "context": lambda inputs: _combine_documents(inputs["docs"]),
        **pass_values(["question","audience","language"]),
    }

    # ------- STAGE 3 - Answer
    # Grounded branch: the prompt receives the rendered context.
    grounded_answer = {
        "answer": llm_inputs | docs_prompt | llm | StrOutputParser(),
        **pass_values(["question","audience","language","query","docs"]),
    }

    # Fallback branch: the prompt receives the stage-1 output as is.
    ungrounded_answer = {
        "answer": no_docs_prompt | llm | StrOutputParser(),
        **pass_values(["question","audience","language","query","docs"]),
    }

    # Use the grounded branch only when the retriever returned something.
    answer_step = RunnableBranch(
        (lambda inputs: len(inputs["docs"]) > 0, grounded_answer),
        ungrounded_answer,
    )

    # ------- FINAL CHAIN
    return reformulation_step | retrieval_step | answer_step



def make_illustration_chain(llm):
    """Build the chain that prompts ``llm`` with the image documents.

    Args:
        llm: Model that receives the filled-in images prompt.

    Returns:
        A runnable yielding the string-parsed LLM output. It reads ``"docs"``
        from its input, keeps only the documents whose chunk type is
        ``"image"`` and renders them into the ``"images"`` prompt entry.

    NOTE(review): the input is presumably a mapping that also carries
    ``question``, ``audience``, ``language`` and ``answer`` (the keys given to
    ``pass_values``) -- confirm against the caller.
    """
    images_prompt = ChatPromptTemplate.from_template(answer_prompt_images_template)

    prompt_inputs = {
        "images": lambda inputs: _combine_documents(get_image_docs(inputs["docs"])),
        **pass_values(["question","audience","language","answer"]),
    }

    return prompt_inputs | images_prompt | llm | StrOutputParser()