File size: 1,909 Bytes
7b856a8 4e3dc76 7b856a8 f5ec828 7b856a8 4e3dc76 7b856a8 4e3dc76 7b856a8 4e3dc76 8200c4e 7b856a8 4e3dc76 8200c4e 4e3dc76 8200c4e 4e3dc76 f5ec828 4e3dc76 8200c4e f5ec828 8200c4e 4e3dc76 7b856a8 8200c4e 4e3dc76 7b856a8 4e3dc76 7b856a8 4e3dc76 5b30d27 7b856a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
# Markdown blurb for the demo; handed to show() further down as `description`.
desc = """
### Self-Ask
Notebook implementation of the self-ask + Google tool use prompt. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/selfask.ipynb)
(Adapted from [Self-Ask repo](https://github.com/ofirpress/self-ask))
"""
# $
from dataclasses import dataclass, replace
from typing import Optional
from minichain import prompt, show, OpenAI, Google, transform
@dataclass
class State:
    """Working state carried through the self-ask loop.

    Attributes:
        question: The question passed to ``selfask``.
        history: Accumulated "Intermediate answer: ..." lines appended by
            ``update`` after each search.
        next_query: Follow-up query taken from a "Follow up:" completion;
            ``None`` when there is nothing to search for.
        final_answer: Text taken from a "So the final answer is:"
            completion; ``None`` until the model produces one.
    """

    question: str
    history: str = ""
    next_query: Optional[str] = None
    final_answer: Optional[str] = None
@prompt(
    OpenAI(stop="\nIntermediate answer:"),
    template_file="selfask.pmpt.tpl",
)
def self_ask(model, state):
    """Ask the OpenAI backend for the next self-ask step.

    The backend is configured to stop generating at a newline followed by
    "Intermediate answer:", so the completion ends before the model would
    write an intermediate answer itself.

    Args:
        model: Callable backend supplied by the ``prompt`` decorator.
        state: Current ``State``; presumably rendered through
            ``selfask.pmpt.tpl`` -- confirm against the template file.

    Returns:
        Whatever ``model(state)`` returns (the raw completion).
    """
    return model(state)
@transform()
def next_step(ask, state):
    """Fold the model's latest completion into ``state``.

    Args:
        ask: Completion text produced by ``self_ask``.
        state: The ``State`` that ``ask`` was generated from.

    Returns:
        A copy of ``state`` with ``next_query`` set when ``ask`` starts with
        "Follow up:", a copy with ``final_answer`` set when it starts with
        "So the final answer is:", and ``state`` unchanged otherwise.
    """
    # Everything after the first colon. partition() yields "" rather than
    # raising when the completion contains no colon at all.
    _, _, res = ask.partition(":")
    if ask.startswith("Follow up:"):
        return replace(state, next_query=res)
    if ask.startswith("So the final answer is:"):
        return replace(state, final_answer=res)
    # Unrecognised completion: keep the state so the following steps still
    # receive a State instead of None.
    return state


@prompt(Google())
def google(model, state):
    """Run the pending follow-up query through the Google backend.

    Args:
        model: Callable backend supplied by the ``prompt`` decorator.
        state: Current ``State``.

    Returns:
        ``model(state.next_query)``, or an empty string when there is no
        pending query.
    """
    if state.next_query is None:
        return ""
    return model(state.next_query)


@transform()
def update(state, result):
    """Append a search result to the history.

    Args:
        state: The ``State`` the search was issued from.
        result: Search result text; an empty value means nothing was searched.

    Returns:
        ``state`` itself when ``result`` is empty; otherwise a fresh ``State``
        with the same question and the result appended to ``history`` as an
        "Intermediate answer:" line. The fresh state has ``next_query`` and
        ``final_answer`` back at their ``None`` defaults.
    """
    if not result:
        return state
    return State(
        state.question,
        state.history + "\nIntermediate answer: " + result + "\n",
    )


def selfask(question):
    """Run three rounds of self-ask followed by a Google lookup.

    Args:
        question: The question to answer.

    Returns:
        The state produced by the last ``update`` step.
    """
    state = State(question)
    # Three rounds, matching the three (self_ask, google) pairs listed as
    # subprompts in the demo configuration.
    for _ in range(3):
        # Both transforms need the current state as well as the prompt output.
        # NOTE(review): assumes transforms accept the state next to a prompt
        # result, the same way prompts accept it -- confirm.
        state = next_step(self_ask(state), state)
        state = update(state, google(state))
    return state
# $
# The demo displays only the part of this file that sits between the two
# dollar-sign marker comments: split on the marker character, take the middle
# piece, then trim the surrounding whitespace and the dangling "#" that
# belonged to the closing marker line.
# The context manager closes the file handle once the text has been read.
with open("selfask.py", "r", encoding="utf-8") as _source_file:
    _demo_code = _source_file.read().split("$")[1].strip().strip("#").strip()

gradio = show(
    selfask,
    examples=["What is the zip code of the city where George Washington was born?"],
    # One (self_ask, google) pair per loop round in selfask.
    subprompts=[self_ask, google] * 3,
    description=desc,
    code=_demo_code,
    out_type="json",
)

if __name__ == "__main__":
    gradio.queue().launch()
|