from langchain.output_parsers.structured import StructuredOutputParser, ResponseSchema
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableBranch

from climateqa.engine.prompts import reformulation_prompt_template
from climateqa.engine.utils import pass_values, flatten_dict


# Structured output schema: detected language of the input + the question reformulated in English
response_schemas = [
    ResponseSchema(name="language", description="The detected language of the input message"),
    ResponseSchema(name="question", description="The reformulated question always in English"),
]
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
format_instructions = output_parser.get_format_instructions()

def fallback_default_values(x):
    """If reformulation failed, fall back to the raw query and assume English."""
    if x.get("question") is None:
        x["question"] = x["query"]
        x["language"] = "english"

    return x

def make_reformulation_chain(llm):
    """Build a chain that reformulates a user query into a standalone English question."""

    prompt = PromptTemplate(
        template=reformulation_prompt_template,
        input_variables=["query"],
        partial_variables={"format_instructions": format_instructions},
    )

    # Stop generation at the closing code fence so the structured output parser
    # receives clean JSON.
    chain = prompt | llm.bind(stop=["```"]) | output_parser

    reformulation_chain = (
        {"reformulation": chain, **pass_values(["query"])}  # keep the original query alongside the parsed output
        | RunnablePassthrough()
        | flatten_dict
        | fallback_default_values
    )

    return reformulation_chain
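

# --- Illustrative usage (a minimal sketch, not part of the module) ---
# Assumes a LangChain chat model is available; ChatOpenAI and the model name
# below are only examples, any chat model exposing `.bind()` should work.
if __name__ == "__main__":
    from langchain_openai import ChatOpenAI  # hypothetical choice of provider

    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    reformulation_chain = make_reformulation_chain(llm)

    # The chain returns a flat dict containing the reformulated "question"
    # (in English), the detected "language", and the original "query".
    result = reformulation_chain.invoke(
        {"query": "Quels sont les impacts du changement climatique ?"}
    )
    print(result)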