punnypix / app.py
flobbit's picture
Update app.py
31cd7f0
import os
from typing import Optional, Tuple
import gradio as gr
from threading import Lock
# possibly needed for loading the environment variables locally.
# not needed when hosted on hugging face if using HF secrets
#from dotenv import load_dotenv
#load_dotenv()
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate)
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
# initialize the LLM as part of the conversation chain
# a low temperature of 0.2 keeps the output focused while still allowing a little creativity
def load_chain():
    """Build the LLM chain that rewrites a caption.

    The system prompt is read from the ``CAPTION_PROMPT`` environment
    variable; the human message is the text passed to the chain at run time
    (the ``{text}`` placeholder).

    Returns:
        An ``LLMChain`` combining a ``gpt-3.5-turbo`` chat model
        (temperature 0.2) with the two-message chat prompt.

    Raises:
        ValueError: if ``CAPTION_PROMPT`` is not set.
    """
    template = os.getenv("CAPTION_PROMPT")
    if template is None:
        # Fail early with a clear message instead of handing None to the
        # prompt template.
        raise ValueError("The CAPTION_PROMPT environment variable is not set")

    system_message_prompt = SystemMessagePromptTemplate.from_template(template)
    human_message_prompt = HumanMessagePromptTemplate.from_template("{text}")
    chat_prompt = ChatPromptTemplate.from_messages(
        [system_message_prompt, human_message_prompt]
    )
    llm = ChatOpenAI(temperature=0.2, model_name='gpt-3.5-turbo')
    return LLMChain(llm=llm, prompt=chat_prompt)
# set the api key and load conversation chain once when the api key changes in input box
def set_openai_api_key(api_key: str):
    """Build a chain with ``api_key`` and return it.

    The key is placed in the ``OPENAI_API_KEY`` environment variable only
    while ``load_chain()`` runs; the variable is set to an empty string
    afterwards.

    Args:
        api_key: the OpenAI key to use; may be empty or None.

    Returns:
        The chain returned by ``load_chain()``, or None when ``api_key`` is
        empty or None.
    """
    if not api_key:
        return None

    os.environ["OPENAI_API_KEY"] = api_key
    try:
        return load_chain()
    finally:
        # Blank the variable even if load_chain() raises, so the supplied key
        # does not stay behind in the process environment.
        os.environ["OPENAI_API_KEY"] = ""
# load the hugging face image to text captioner
# the pipeline is created once, when this module is loaded, and is called
# from image_supplied() for every new image
from transformers import pipeline
captioner = pipeline("image-to-text",model="Salesforce/blip-image-captioning-base")
# PIL is used to turn the incoming numpy image into a PIL image
# NOTE(review): this imports only the top-level package; whether the
# PIL.Image submodule is loaded depends on the other imports - confirm
import PIL
# an image has been selected. it comes to this fn as a numpy ndarray
# convert it to a PIL image and feed to the captioner
# return the resulting caption
#
# NOTE: due to an error in gradio, this fn is triggered twice for each img change
# only process the first fn call, and keep the first caption result
async def image_supplied(img, count: int, last_cap: str):
    """Caption a newly selected image.

    Args:
        img: the selected image as a numpy ndarray, or None when there is no
            image.
        count: number of non-empty calls seen so far. It is incremented on
            every call with an image and used to skip the duplicate event.
        last_cap: the caption stored by the previous call.

    Returns:
        A ``(caption, count, last_cap)`` tuple: the text for the caption box,
        the updated call counter and the caption to remember.
    """
    if img is None:
        return "", count, ""

    count += 1
    # The change event fires twice per image (see the note above this
    # function), so every second call just replays the stored caption.
    if count % 2 == 0:
        return last_cap, count, last_cap

    if not img.any():
        # An all-zero image has nothing to caption; still return a complete
        # tuple so all three outputs receive a value.
        return "", count, ""

    # Import the submodule explicitly: a bare ``import PIL`` does not
    # guarantee that ``PIL.Image`` is available.
    from PIL import Image

    im = Image.fromarray(img)
    caption = captioner(im, max_new_tokens=20)
    result = caption[0]['generated_text']
    return result, count, result
# class wrapping the chat
class ChatWrapper:
    """Callable chat handler that runs one request at a time."""

    def __init__(self):
        # Guards the whole body of __call__.
        self.lock = Lock()

    def __call__(
        self, api_key: str, inp: str,
        chain: Optional[LLMChain]
    ):
        """Run the chain on ``inp`` and return its output.

        Args:
            api_key: the OpenAI key assigned to ``openai.api_key`` before the
                chain runs.
            inp: the caption text passed to ``chain.run``.
            chain: the chain to run, or None when no chain has been built
                yet.

        Returns:
            The chain's output, or a message asking for an OpenAI key when no
            chain could be built.
        """
        # Held for the whole call, presumably because the openai module
        # settings assigned below are shared by every request.
        with self.lock:
            # If chain is None, that is because no API key was provided by user.
            if chain is None:
                # attempt to load default rate limited key and initialize chain
                key = openai_api_key_textbox.value
                chain = set_openai_api_key(key)
                # if chain is still None, the supplied key didn't work
                if chain is None:
                    return "Please paste your OpenAI key to use"

            # Set OpenAI key
            import openai
            openai.api_key = api_key
            openai.api_type = 'open_ai'
            openai.api_base = 'https://api.openai.com/v1'

            # Exceptions from the chain propagate to the caller; the lock is
            # released by the with-statement either way.
            return chain.run(inp)
# single chat handler instance; it is registered on the Submit button and on
# the caption textbox's submit event in the layout section
chat = ChatWrapper()
# custom css: sets the page background (image sd1.png over light gray) and
# hides the footer
css = """
.gradio-container {background-color: lightgray; background: url('file=./sd1.png'); background-size: cover}
footer {visibility: hidden}
"""
# Google font requested first in the theme's font list
font_name = "Kalam"
# top-level Blocks container with the custom css and theme
# NOTE(review): the emoji in the title look mis-encoded (UTF-8 bytes decoded
# with a single-byte code page) - confirm against the original file
block = gr.Blocks(title="πŸ“· PunnyPix πŸ“Έ", css=css,
theme=gr.themes.Default(
text_size = 'lg',
font=[gr.themes.GoogleFont(font_name),"Arial","sans-serif"],
spacing_size="sm", radius_size="sm"))
# create app layout
# NOTE(review): the indentation of this section has been lost, so the exact
# nesting of the rows/columns below cannot be read from this text - confirm
# against the original file
with block:
with gr.Row():
with gr.Column():
# page heading and short usage instructions
gr.Markdown("<h2><center>πŸ“· PunnyPix πŸ“Έ</center></h2>")
gr.Markdown("<h4><center>Load image. Edit automated caption. Click 'Submit' to get a funny (hopefully) caption.</center></h4>")
# masked key input, pre-filled from the OPENAI_API_KEY environment variable
# NOTE(review): the default key is the textbox's initial value, so it is
# presumably sent to the browser - confirm this is acceptable
openai_api_key_textbox = gr.Textbox(
label="πŸ”‘ Default key is rate limited. Paste your OpenAI API key (sk-...)",
placeholder="Paste your OpenAI API key (sk-...)",
value = os.getenv("OPENAI_API_KEY"), # default to rate limited key
lines=1,
type="password"
)
with gr.Row():
with gr.Column():
# image input; its change event is wired to image_supplied below
image_box = gr.Image(show_label=False)
with gr.Row():
# editable caption filled in by image_supplied
result_box = gr.Textbox(label="Original caption πŸ—¨οΈ", value="", interactive=True, lines=1, scale=3)
submit = gr.Button(value="Submit", variant="secondary", size='sm', scale=1) #scale button at 1/3 size of two text boxes
# read-only output filled in by the chat handler
caption_box = gr.Textbox(
label="Converted caption πŸ—―οΈ",
value="",
lines=1,
interactive=False,
scale=3
)
# sample images that are loaded into image_box
gr.Examples(
label="Sample images",
examples=[
'carolina.jpg',
'house.jpg',
'viceroy.jpg',
'airplane.jpg',
'swimming.jpg',
'cats2.jpg',
'car.jpg',
'dogs.jpg',
'cows2.jpg',
'mountains.jpg'
],
inputs=image_box
)
# credits link
gr.HTML(
"<center><a style='color: white', href='https://github.com/flobbit1/punnypix'>Powered by LangChain πŸ¦œοΈπŸ”—, Hugging Face transformers, OpenAI</a></center>"
)
#state = gr.State()
# holds the chain returned by set_openai_api_key (see the key textbox change
# handler below); passed to the chat handler as its "chain" argument
agent_state = gr.State()
# once caption has been confirmed (either through enter in box or hitting "submit")
# pass to the chat to process and get result (which goes into caption_box)
submit.click(chat, inputs=[openai_api_key_textbox, result_box, agent_state], outputs=[caption_box])
result_box.submit(chat, inputs=[openai_api_key_textbox, result_box, agent_state], outputs=[caption_box])
#submit.click(chat, inputs=[openai_api_key_textbox, result_box, state, agent_state], outputs=[caption_box, state])
#result_box.submit(chat, inputs=[openai_api_key_textbox, result_box, state, agent_state], outputs=[caption_box, state])
# call counter and last caption, used by image_supplied to skip the
# duplicate change event
count = gr.State(value=0)
last_cap = gr.State(value="")
# if image has changed, feed it to "image_supplied", and pass result to "result_box"
image_box.change(
image_supplied,
inputs=[image_box, count, last_cap],
outputs=[result_box, count, last_cap]
)
# if api key in input box has changed, update the key in app
# (the chain built by set_openai_api_key is stored in agent_state)
openai_api_key_textbox.change(
set_openai_api_key,
inputs=[openai_api_key_textbox],
outputs=[agent_state],
)
# start the app with debug mode enabled
block.launch(debug=True)