# File size: 7,138 Bytes
# 015e873
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import os
from typing import Optional, Tuple

import gradio as gr
from threading import Lock

# possibly needed for loading the environment variables locally.
# not needed when hosted on hugging face if using HF secrets
#from dotenv import load_dotenv
#load_dotenv()

from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate)
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI

# initialize the LLM as part of the conversation chain
# a low temperature (0.2) keeps the output mostly focused while still allowing a little variety
def load_chain():
    """Build the chat chain used to convert a caption.

    The system message template is read from the ``CAPTION_PROMPT``
    environment variable; the human message is the caller's text passed
    through unchanged via the ``{text}`` placeholder.

    Returns:
        An ``LLMChain`` combining the two-message chat prompt with a
        ``gpt-3.5-turbo`` ``ChatOpenAI`` model at temperature 0.2.

    Raises:
        ValueError: if ``CAPTION_PROMPT`` is not set in the environment.
    """
    template = os.getenv("CAPTION_PROMPT")
    if template is None:
        # Fail fast with a clear message instead of passing None into the
        # prompt template parser.
        raise ValueError("CAPTION_PROMPT environment variable is not set")

    system_message_prompt = SystemMessagePromptTemplate.from_template(template)
    human_message_prompt = HumanMessagePromptTemplate.from_template("{text}")
    chat_prompt = ChatPromptTemplate.from_messages(
        [system_message_prompt, human_message_prompt]
    )

    llm = ChatOpenAI(temperature=0.2, model_name='gpt-3.5-turbo')
    return LLMChain(llm=llm, prompt=chat_prompt)

# set the api key and load conversation chain once when the api key changes in input box
def set_openai_api_key(api_key: str):
    """Build a chain with ``api_key`` and return it.

    The key is exported as ``OPENAI_API_KEY`` only while ``load_chain()``
    runs. The variable is blanked again afterwards, including when
    ``load_chain()`` raises, so a user-supplied key is never left behind in
    the process-wide environment.

    Args:
        api_key: the OpenAI API key to build the chain with.

    Returns:
        The chain produced by ``load_chain()``, or ``None`` when ``api_key``
        is empty or ``None``.
    """
    if not api_key:
        return None

    os.environ["OPENAI_API_KEY"] = api_key
    try:
        return load_chain()
    finally:
        # Runs on success and on failure alike.
        os.environ["OPENAI_API_KEY"] = ""

# Build the Hugging Face image-to-text pipeline once, at import time; the
# resulting `captioner` is the module-level object image_supplied() calls.
from transformers import pipeline

captioner = pipeline(
    task="image-to-text",
    model="Salesforce/blip-image-captioning-base",
)

import PIL
import numpy
# an image has been selected. it comes to this fn as a numpy ndarray
# convert it to a PIL image and feed to the captioner
# return the resulting caption
def image_supplied(img: numpy.ndarray):
    """Return an automatic caption for the selected image.

    Args:
        img: the image as a numpy array, or ``None`` when no image is set.

    Returns:
        The ``generated_text`` of the captioner's first result, or ``None``
        when ``img`` is ``None`` or contains no non-zero element (which
        includes an empty array and an all-black image).
    """
    if img is None or not img.any():
        return None

    # `import PIL` alone does not load the Image submodule; import it
    # explicitly so this does not depend on another package having already
    # loaded PIL.Image.
    from PIL import Image

    caption = captioner(Image.fromarray(img), max_new_tokens=20)
    return caption[0]['generated_text']

# class wrapping the chat
class ChatWrapper:
    """Callable that runs one input through the chain, one call at a time."""

    def __init__(self):
        # __call__ mutates module-level `openai` settings, so calls are
        # serialized through this lock.
        self.lock = Lock()

    def __call__(
        self, api_key: str, inp: str,
        chain: "Optional[LLMChain]"
    ):
        """Execute the chat functionality.

        Args:
            api_key: key assigned to ``openai.api_key`` before the chain runs.
            inp: text passed to ``chain.run``.
            chain: a previously built chain, or ``None`` if none was built.

        Returns:
            The chain's output, or a prompt asking for a key when no chain
            could be built.

        Raises:
            Any exception raised while building or running the chain
            propagates unchanged; the lock is released either way.
        """
        with self.lock:
            if chain is None:
                # No chain yet: try to build one from the key textbox's
                # configured value (the default, rate-limited key).
                chain = set_openai_api_key(openai_api_key_textbox.value)

                # Still None means no usable key was available.
                if chain is None:
                    return "Please paste your OpenAI key to use"

            # Point the openai module at the public API with the caller's key.
            import openai
            openai.api_key = api_key
            openai.api_type = 'open_ai'
            openai.api_base = 'https://api.openai.com/v1'

            return chain.run(inp)

# Single handler instance wired to the submit events; its lock serializes calls.
chat = ChatWrapper()

# custom css: background image for the app container, hidden footer
css = """
    .gradio-container {background-color: lightgray; background: url('file=./sd1.png'); background-size: cover}
    footer {visibility: hidden}
"""

font_name = "Kalam"
# The title emoji were mis-decoded (UTF-8 bytes read as a single-byte
# codepage); they are restored to the intended characters here.
block = gr.Blocks(
    title="📷 PunnyPix 📸",
    css=css,
    theme=gr.themes.Default(
        text_size='lg',
        font=[gr.themes.GoogleFont(font_name), "Arial", "sans-serif"],
        spacing_size="sm",
        radius_size="sm",
    ),
)

# create app layout
# (emoji in the user-facing strings below were mis-decoded UTF-8 and have
# been restored to the intended characters)
with block:
    with gr.Row():
        with gr.Column():
            gr.Markdown("<h2><center>📷 PunnyPix 📸</center></h2>")
            gr.Markdown("<h4><center>Load image. Edit automated caption. Click 'Submit' to get a funny (hopefully) caption.</center></h4>")

        # NOTE(review): the default key from the environment is used as the
        # textbox value, so it is sent to the browser (masked, but present in
        # the page data) — confirm this exposure is acceptable.
        openai_api_key_textbox = gr.Textbox(
            label="🔑 Default key is rate limited. Paste your OpenAI API key (sk-...)",
            placeholder="Paste your OpenAI API key (sk-...)",
            value=os.getenv("OPENAI_API_KEY"),  # default to rate limited key
            lines=1,
            type="password"
        )

    with gr.Row():
        with gr.Column():
            image_box = gr.Image(show_label=False)

    with gr.Row():
        result_box = gr.Textbox(label="Original caption 🗨️", value="", interactive=True, lines=1, scale=3)
        # button is scaled to 1/3 the width of each of the two text boxes
        submit = gr.Button(value="Submit", variant="secondary", size='sm', scale=1)
        caption_box = gr.Textbox(
            label="Converted caption 🗯️",
            value="",
            lines=1,
            interactive=False,
            scale=3
        )

    gr.Examples(
        label="Sample images",
        examples=[
            'carolina.jpg',
            'house.jpg',
            'viceroy.jpg',
            'airplane.jpg',
            'swimming.jpg',
            'cats2.jpg',
            'car.jpg',
            'dogs.jpg',
            'cows2.jpg',
            'mountains.jpg'
        ],
        inputs=image_box
    )

    # attributes are separated by whitespace only (a stray comma was removed)
    gr.HTML(
        "<center><a style='color: white' href='https://github.com/flobbit1/punnypix'>Powered by LangChain 🦜️🔗, Hugging Face transformers, OpenAI</a></center>"
    )

    # holds the chain returned by set_openai_api_key for the current key
    agent_state = gr.State()

    # once caption has been confirmed (either through enter in box or hitting "submit")
    # pass to the chat to process and get result (which goes into caption_box)
    submit.click(chat, inputs=[openai_api_key_textbox, result_box, agent_state], outputs=[caption_box])
    result_box.submit(chat, inputs=[openai_api_key_textbox, result_box, agent_state], outputs=[caption_box])

    # if image has changed, feed it to "image_supplied", and pass result to "result_box"
    image_box.change(
        image_supplied,
        inputs=[image_box],
        outputs=[result_box]
    )

    # if api key in input box has changed, rebuild the chain and store it in agent_state
    openai_api_key_textbox.change(
        set_openai_api_key,
        inputs=[openai_api_key_textbox],
        outputs=[agent_state],
    )

block.launch(debug=True)