# minichain / agent.py~  (srush/MiniChain, uploaded via huggingface_hub, rev 8200c4e)
# + tags=["hide_inp"]
desc = """
### Gradio Tool
Chain that ask for a command-line question and then runs the bash command. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/bash.ipynb)
(Adapted from LangChain [BashChain](https://langchain.readthedocs.io/en/latest/modules/chains/examples/llm_bash.html))
"""
# -
# $
from gradio_tools.tools import ImageCaptioningTool, StableDiffusionTool

from minichain import Id, OpenAIStream, prompt, show
@prompt(StableDiffusionTool())
def gen(model, query):
    """Render an image for `query` with the Stable Diffusion tool."""
    image_path = model(query)
    return image_path
@prompt(ImageCaptioningTool())
def caption(model, img_src):
    """Produce a text caption for the image located at `img_src`."""
    description = model(img_src)
    return description


# Tool prompts the agent can dispatch to; index order matters
# (0: image generation, 1: captioning — see `selector`).
tools = [gen, caption]
@prompt(Id(),
        #OpenAIStream(), stream=True,
        template_file="agent.pmpt.tpl")
def agent(model, query):
    # Ask the model (via the agent template) which tool to use for `query`,
    # passing each tool's class name and description.
    # NOTE(review): the model's answer is only printed; the returned
    # (tool_name, tool_input) pair is hard-coded — this looks like a debug
    # stub left in while streaming (commented out below) was disabled.
    print(model(dict(tools=[(str(tool.backend.__class__), tool.backend.description)
                            for tool in tools],
                     input=query
                     )))
    return ("StableDiffusionTool", "Draw a flower")
    # out = ""
    # for t in model.stream(dict(tools=[(str(tool.backend.__class__), tool.backend.description)
    #                                   for tool in tools],
    #                            input=query
    #                            )):
    #     out += t
    #     yield out
    # lines = out.split("\n")
    # response = lines[0].split("?")[1].strip()
    # if response == "Yes":
    #     tool = lines[1].split(":")[1].strip()
    #     yield tool
@prompt(dynamic=tools)
def selector(model, input):
    """Dispatch a (tool_name, tool_input) pair to the matching tool prompt."""
    tool_name, tool_input = input
    # Tool indices follow the `tools` list: 0 = StableDiffusionTool,
    # anything else falls through to the captioning tool at index 1.
    tool_num = 0 if tool_name == "StableDiffusionTool" else 1
    return model.tool(tool_input, tool_num=tool_num)
def run(query):
    """Ask the agent which tool to use for `query`, then invoke it."""
    return selector(agent(query))
# Smoke-test the full chain once, but only when executed as a script.
# Previously this ran unconditionally at import time, which triggered real
# tool/LLM calls whenever the module was merely imported (e.g. by the docs
# generator or the Gradio demo below).
if __name__ == "__main__":
    run("make a pic").run()
# $
# Build the Gradio demo around `run`, visualizing the two sub-prompts
# (the agent's tool choice and the selector's dispatch).
gradio = show(run,
              subprompts=[agent, selector],
              examples=['Draw me a flower'],
              out_type="markdown",
              description=desc
              )
if __name__ == "__main__":
    gradio.launch()