|
import os |
|
from typing import Optional, Tuple |
|
|
|
import gradio as gr |
|
from threading import Lock |
|
|
|
|
|
|
|
|
|
|
|
|
|
from langchain.prompts.chat import ( |
|
ChatPromptTemplate, |
|
SystemMessagePromptTemplate, |
|
HumanMessagePromptTemplate) |
|
from langchain.chains import LLMChain |
|
from langchain.chat_models import ChatOpenAI |
|
|
|
|
|
|
|
def load_chain():
    """Build the LLM chain used by the app.

    The system message template is read from the ``CAPTION_PROMPT``
    environment variable; the human message is the caller's input passed
    through unchanged as ``{text}``.

    Returns:
        An ``LLMChain`` combining the chat prompt with a ``gpt-3.5-turbo``
        ``ChatOpenAI`` model at temperature 0.2.

    Raises:
        ValueError: If ``CAPTION_PROMPT`` is not set.
    """
    template = os.getenv("CAPTION_PROMPT")
    if template is None:
        # Fail early with an actionable message instead of handing None to
        # the prompt template.
        raise ValueError(
            "The CAPTION_PROMPT environment variable must be set to the "
            "system prompt template."
        )

    system_message_prompt = SystemMessagePromptTemplate.from_template(template)
    human_message_prompt = HumanMessagePromptTemplate.from_template("{text}")
    chat_prompt = ChatPromptTemplate.from_messages(
        [system_message_prompt, human_message_prompt]
    )

    llm = ChatOpenAI(temperature=0.2, model_name="gpt-3.5-turbo")
    return LLMChain(llm=llm, prompt=chat_prompt)
|
|
|
|
|
def set_openai_api_key(api_key: str):
    """Set the api key and return chain.

    The key is placed in the ``OPENAI_API_KEY`` environment variable only
    while ``load_chain()`` runs; afterwards the variable is blanked again.

    Args:
        api_key: OpenAI API key. A falsy value means "no key supplied".

    Returns:
        The chain returned by ``load_chain()``, or ``None`` if no api_key.
    """
    if not api_key:
        return None

    os.environ["OPENAI_API_KEY"] = api_key
    try:
        return load_chain()
    finally:
        # Blank the variable even when load_chain() raises, so the key never
        # lingers in the process environment.
        os.environ["OPENAI_API_KEY"] = ""
|
|
|
|
|
from transformers import pipeline

# Image-to-text pipeline, created once when the module is imported.
captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")

# Import the submodule explicitly: a bare `import PIL` binds the package but
# does not load `PIL.Image`, which image_supplied() relies on.
import PIL.Image
|
|
|
|
|
|
|
|
|
|
|
|
|
async def image_supplied(img, count: int, last_cap: str):
    """Produce a caption for the image currently in the image box.

    Args:
        img: The image as an array (it is handed to ``PIL.Image.fromarray``),
            or ``None`` when no image is present.
        count: Running number of calls that received an image.
        last_cap: Caption returned by the previous captioning run.

    Returns:
        A ``(caption, count, last_cap)`` tuple. The caption is empty when
        there is no image or the image is entirely zero-valued.
    """
    if img is None:
        return "", count, ""

    count += 1
    if (count & 1) == 0:
        # Every second call re-uses the previous caption instead of running
        # the model again (presumably the change event fires twice per
        # upload -- confirm).
        return last_cap, count, last_cap

    if not img.any():
        # An all-zero image used to fall through with `result` unassigned;
        # report an empty caption instead.
        return "", count, ""

    im = PIL.Image.fromarray(img)
    caption = captioner(im, max_new_tokens=20)
    result = caption[0]['generated_text']
    return result, count, result
|
|
|
|
|
class ChatWrapper:
    """Callable that runs the chain on user input, one call at a time."""

    def __init__(self):
        # A single lock per wrapper; __call__ holds it for the whole request.
        self.lock = Lock()

    def __call__(
        self, api_key: str, inp: str,
        chain: Optional[LLMChain]
    ):
        """Execute the chat functionality.

        Args:
            api_key: OpenAI API key assigned to ``openai.api_key`` before the
                chain runs.
            inp: Text passed to ``chain.run``.
            chain: A previously built chain, or ``None`` if none exists yet.

        Returns:
            The output of ``chain.run(inp)``, or a message asking for an
            OpenAI key when no chain could be built.
        """
        # The context manager releases the lock on every exit path,
        # including exceptions, which propagate to the caller unchanged.
        with self.lock:
            if chain is None:
                # No chain yet: try to build one from the `.value` of the
                # module-level key textbox (not from `api_key`).
                chain = set_openai_api_key(openai_api_key_textbox.value)

            if chain is None:
                return "Please paste your OpenAI key to use"

            import openai

            openai.api_key = api_key
            openai.api_type = 'open_ai'
            openai.api_base = 'https://api.openai.com/v1'

            return chain.run(inp)
|
|
|
# Handler instance registered on the submit events further down.
chat = ChatWrapper()

# Page styling: background image on the app container, footer hidden.
css = """
.gradio-container {background-color: lightgray; background: url('file=./sd1.png'); background-size: cover}
footer {visibility: hidden}
"""

font_name = "Kalam"

# The title's emoji had been mis-decoded into Greek letters; they are
# restored here to the characters those bytes encode.
block = gr.Blocks(
    title="📷 PunnyPix 📸",
    css=css,
    theme=gr.themes.Default(
        text_size='lg',
        font=[gr.themes.GoogleFont(font_name), "Arial", "sans-serif"],
        spacing_size="sm",
        radius_size="sm",
    ),
)
|
|
|
|
|
# NOTE(review): the original indentation was lost, so the nesting of the
# Row/Column containers below is reconstructed -- confirm against the
# intended layout.
with block:
    # Header: title, usage hint and the API key field.
    with gr.Row():
        with gr.Column():
            gr.Markdown("<h2><center>📷 PunnyPix 📸</center></h2>")
            gr.Markdown("<h4><center>Load image. Edit automated caption. Click 'Submit' to get a funny (hopefully) caption.</center></h4>")

        openai_api_key_textbox = gr.Textbox(
            label="🔑 Default key is rate limited. Paste your OpenAI API key (sk-...)",
            placeholder="Paste your OpenAI API key (sk-...)",
            value=os.getenv("OPENAI_API_KEY"),
            lines=1,
            type="password",
        )

    # Image input.
    with gr.Row():
        with gr.Column():
            image_box = gr.Image(show_label=False)

    # Editable automatic caption, submit button and the converted caption.
    with gr.Row():
        result_box = gr.Textbox(label="Original caption 🗨️", value="", interactive=True, lines=1, scale=3)
        submit = gr.Button(value="Submit", variant="secondary", size='sm', scale=1)
        caption_box = gr.Textbox(
            label="Converted caption 🗯️",
            value="",
            lines=1,
            interactive=False,
            scale=3,
        )

    gr.Examples(
        label="Sample images",
        examples=[
            'carolina.jpg',
            'house.jpg',
            'viceroy.jpg',
            'airplane.jpg',
            'swimming.jpg',
            'cats2.jpg',
            'car.jpg',
            'dogs.jpg',
            'cows2.jpg',
            'mountains.jpg',
        ],
        inputs=image_box,
    )

    # The stray comma between the style and href attributes was removed.
    gr.HTML(
        "<center><a style='color: white' href='https://github.com/flobbit1/punnypix'>Powered by LangChain 🦜️🔗, Hugging Face transformers, OpenAI</a></center>"
    )

    # Holds the chain returned by set_openai_api_key (None until a key is set).
    agent_state = gr.State()

    # Both the button and pressing Enter in the caption box run the chat handler.
    submit.click(chat, inputs=[openai_api_key_textbox, result_box, agent_state], outputs=[caption_box])
    result_box.submit(chat, inputs=[openai_api_key_textbox, result_box, agent_state], outputs=[caption_box])

    # State passed into and returned from image_supplied on every image change.
    count = gr.State(value=0)
    last_cap = gr.State(value="")

    image_box.change(
        image_supplied,
        inputs=[image_box, count, last_cap],
        outputs=[result_box, count, last_cap],
    )

    # Rebuild the chain whenever the key field changes.
    openai_api_key_textbox.change(
        set_openai_api_key,
        inputs=[openai_api_key_textbox],
        outputs=[agent_state],
    )

block.launch(debug=True)