# + tags=["hide_inp"]
desc = """
### Book QA

A chain that answers questions about a book, using Hugging Face embeddings for passage retrieval. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/gatsby.ipynb)

(Adapted from the [LlamaIndex example](https://github.com/jerryjliu/gpt_index/blob/main/examples/gatsby/TestGatsby.ipynb).)
"""
# -

# $

import datasets
import numpy as np
from minichain import prompt, show, HuggingFaceEmbed, OpenAI, transform

# Load data with embeddings (computed beforehand)

gatsby = datasets.load_from_disk("gatsby")
gatsby.add_faiss_index("embeddings")
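
# The "gatsby" dataset loaded above is assumed to be a Hugging Face `datasets`
# dataset with a "passages" column and a precomputed "embeddings" column.
# A minimal sketch of how such a dataset could be built offline
# (hypothetical helper, not the actual preprocessing script):

def build_gatsby_dataset(passages, path="gatsby"):
    from datasets import Dataset
    from sentence_transformers import SentenceTransformer

    # Use the same embedding model as the query-time `embed` prompt below.
    model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
    ds = Dataset.from_dict({"passages": passages})
    ds = ds.map(
        lambda batch: {"embeddings": model.encode(batch["passages"])},
        batched=True,
    )
    ds.save_to_disk(path)
    return ds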

# Fast KNN retrieval prompt

@prompt(HuggingFaceEmbed("sentence-transformers/all-mpnet-base-v2"))
def embed(model, inp):
    return model(inp)

@transform()
def get_neighbors(embedding, k=1):
    res = gatsby.get_nearest_examples("embeddings", np.array(embedding), k)
    return res.examples["passages"]

@prompt(OpenAI(), template_file="gatsby.pmpt.tpl")
def ask(model, query, neighbors):
    return model(dict(question=query, docs=neighbors))
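
# The Jinja template `gatsby.pmpt.tpl` (not shown here) receives the
# `question` and `docs` keys passed above. A plausible sketch of its shape;
# the real template may differ:
#
#     Context passages:
#     {% for doc in docs %}
#     {{doc}}
#     {% endfor %}
#
#     Question: {{question}}
#     Answer: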

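# Compose the chain: embed the query, fetch the nearest passage(s), and
# ask the LLM with them as context.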
def gatsby_q(query):
    n = get_neighbors(embed(query))
    return ask(query, n)


# $


gradio = show(gatsby_q,
              subprompts=[ask],
              examples=["What did Gatsby do before he met Daisy?",
                        "What did the narrator do after getting back to Chicago?"],
              keys={"HF_KEY"},
              description=desc,
              code=open("gatsby.py", "r").read().split("$")[1].strip().strip("#").strip()
              )
if __name__ == "__main__":
    gradio.queue().launch()