File size: 8,158 Bytes
c1295ee
 
 
23a0a97
c1295ee
 
 
 
 
 
 
 
 
 
 
 
 
c675a94
e287ecd
95c34a4
c675a94
c1295ee
 
 
 
 
 
 
95c34a4
 
c1295ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e287ecd
c1295ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f9fc8a8
c1295ee
f9fc8a8
c1295ee
 
 
 
 
 
 
 
 
 
 
 
 
95c34a4
c1295ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c675a94
c1295ee
 
 
 
 
 
c675a94
c1295ee
 
c675a94
c1295ee
 
c675a94
c1295ee
 
 
c675a94
 
f9fc8a8
46617f4
c675a94
 
 
 
 
 
f655d6e
e287ecd
c675a94
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import requests
import structlog
import openai
import os
import random
import tiktoken
import enum
import time
import retrying
import IPython.display as display
from base64 import b64decode
import base64
from io import BytesIO
import PIL
import PIL.Image
import PIL.ImageDraw
import PIL.ImageFont
import gradio as gr
import cachetools.func
from huggingface_hub import hf_hub_download


logger = structlog.getLogger()
# Required: a missing WEATHER_API variable fails fast at import with KeyError.
weather_api_key = os.environ['WEATHER_API']
openai.api_key = os.environ.get("OPENAI_KEY", None)


def _read_nonempty_lines(path):
    """Return the whitespace-stripped, non-blank lines of the text file at ``path``.

    The file is closed deterministically. Blank lines (e.g. a trailing newline)
    are skipped so they can never be picked as an empty animal/art style.
    """
    with open(path, encoding="utf-8") as fh:
        return [line.strip() for line in fh if line.strip()]


animals = _read_nonempty_lines('animals.txt')
art_styles = _read_nonempty_lines('art_styles.txt')
font_path = hf_hub_download("ybelkada/fonts", "Arial.TTF")



class Chat:
    """Conversation with the OpenAI chat-completion API.

    Holds the message history (starting with a system prompt) and drops the
    oldest non-system messages whenever the history exceeds a token budget.
    """

    class Model(enum.Enum):
        """Model identifiers usable as the ``model`` argument of ``_msg``."""
        GPT3_5 = "gpt-3.5-turbo"
        GPT_4  = "gpt-4"

    def __init__(self, system, max_length=4096//2):
        """Start a conversation.

        Args:
            system: System prompt; always kept as the first history entry.
            max_length: Token budget for the history sent with each request,
                as counted by ``num_tokens_from_messages``.
        """
        self._system = system
        self._max_length = max_length
        self._history = [
            {"role": "system", "content": self._system},
        ]

    @classmethod
    def num_tokens_from_text(cls, text, model="gpt-3.5-turbo"):
        """Return the number of tokens ``text`` encodes to under ``model``'s tokenizer."""
        encoding = tiktoken.encoding_for_model(model)
        return len(encoding.encode(text))

    @classmethod
    def num_tokens_from_messages(cls, messages, model="gpt-3.5-turbo"):
        """Return the estimated prompt token count for a list of chat messages.

        Adds a fixed per-message overhead and a reply-priming overhead on top of
        the encoded length of every field value of every message.
        """
        encoding = tiktoken.encoding_for_model(model)
        num_tokens = 0
        for message in messages:
            num_tokens += 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
            for key, value in message.items():
                num_tokens += len(encoding.encode(value))
                if key == "name":  # if there's a name, the role is omitted
                    num_tokens += -1  # role is always required and always 1 token
        num_tokens += 2  # every reply is primed with <im_start>assistant
        return num_tokens

    @retrying.retry(stop_max_attempt_number=5, wait_fixed=2000)
    def _msg(self, *args, model=Model.GPT3_5.value, **kwargs):
        """Send the whole current history to the API (up to 5 attempts, 2 s apart)."""
        return openai.ChatCompletion.create(
            *args,
            model=model,
            messages=self._history,
            **kwargs
        )

    def message(self, next_msg=None, **kwargs):
        """Append ``next_msg`` as a user turn, query the model and return its reply.

        The reply is also appended to the history as an assistant turn.
        Extra keyword arguments are forwarded to ``_msg`` / the API call.
        """
        if next_msg is not None:
            self._history.append({"role": "user", "content": next_msg})
        # Trim *after* appending, so the request that is actually sent respects
        # the budget. Index 0 (system prompt) and the newest message are never
        # dropped, hence the ``> 2`` guard.
        # TODO: Optimize this if slow through easy caching
        while len(self._history) > 2 and self.num_tokens_from_messages(self._history) > self._max_length:
            logger.info(f'Popping message: {self._history.pop(1)}')
        logger.info('requesting openai...')
        resp = self._msg(**kwargs)
        logger.info('received openai...')
        text = resp.choices[0].message.content
        self._history.append({"role": "assistant", "content": text})
        return text

class Weather:
    """Fetch and summarise the weatherapi.com forecast for one location."""

    def __init__(self, zip_code='10001', api_key=weather_api_key):
        self.zip_code = zip_code
        self.api_key = api_key

    # ``get_info`` is cached with ``self`` as the key. Comparing and hashing by
    # value lets a fresh ``Weather(zip)`` hit the cache for the same location;
    # with identity hashing every new instance was a guaranteed miss.
    def __eq__(self, other):
        if not isinstance(other, Weather):
            return NotImplemented
        return (self.zip_code, self.api_key) == (other.zip_code, other.api_key)

    def __hash__(self):
        return hash((self.zip_code, self.api_key))

    def get_weather(self, days=1, timeout=10):
        """Return the decoded JSON forecast response for ``self.zip_code``.

        Args:
            days: Number of forecast days to request.
            timeout: Seconds to wait for the HTTP request before giving up.

        Raises:
            requests.HTTPError: If the API answers with an error status.
        """
        # Pass the (user-supplied) zip code via ``params`` so it is URL-encoded
        # instead of being able to inject extra query parameters.
        resp = requests.get(
            "https://api.weatherapi.com/v1/forecast.json",
            params={"q": self.zip_code, "days": days, "lang": "en", "aqi": "yes", "key": self.api_key},
            headers={'accept': 'application/json'},
            timeout=timeout,
        )
        resp.raise_for_status()
        return resp.json()

    @cachetools.func.ttl_cache(maxsize=128, ttl=15*60)
    def get_info(self):
        """Return current conditions, today's summary and the current/next hourly slots.

        Returns:
            dict with keys ``"now"``, ``"day"``, ``"curr_hour"`` and ``"next_hour"``.
            Results are cached for 15 minutes.
        """
        # Two days are requested so the slot after 23:00 (tomorrow 00:00) exists.
        weather = self.get_weather(days=2)
        forecast_days = weather["forecast"]["forecastday"]
        hours = [hour for day in forecast_days for hour in day["hour"]]
        now = time.time()
        # The current slot is the last one that started at or before now
        # (assumes the API lists slots chronologically — TODO confirm).
        curr_idx = 0
        for idx, hour_data in enumerate(hours):
            if hour_data["time_epoch"] <= now:
                curr_idx = idx
        curr_hour = hours[curr_idx] if hours else None
        next_hour = hours[curr_idx + 1] if curr_idx + 1 < len(hours) else curr_hour
        return {
            "now": weather["current"],
            "day": forecast_days[0]["day"],
            "curr_hour": curr_hour,
            "next_hour": next_hour,
        }


class Image:
    """Thin wrapper around the OpenAI image-generation endpoint."""

    class Size(enum.Enum):
        """Square output resolutions accepted by the endpoint."""
        SMALL = "256x256"
        MEDIUM = "512x512"
        LARGE = "1024x1024"

    @classmethod
    @retrying.retry(stop_max_attempt_number=5, wait_fixed=2000)
    def create(cls, prompt, n=1, size=Size.SMALL):
        """Generate ``n`` images for ``prompt`` as base64-encoded JSON entries.

        Retried up to 5 times, 2 s apart.

        Returns:
            The single result entry when ``n == 1``, otherwise the list of entries.
        """
        logger.info('requesting openai.Image...')
        response = openai.Image.create(
            prompt=prompt,
            n=n,
            size=size.value,
            response_format='b64_json',
        )
        logger.info('received openai.Image...')
        entries = response["data"]
        return entries[0] if n == 1 else entries


def overlay_text_on_image(img, text, position, text_color=(255, 255, 255), box_color=(0, 0, 0, 128), decode=False, max_font_size=50):
    """Draw ``text`` in one corner of ``img`` over a translucent background box.

    The font size is the largest value in ``[1, max_font_size]`` at which the
    rendered text fits inside the image when drawn from its top-left corner.
    If no size fits, size 1 is used.

    Args:
        img: A ``PIL.Image.Image``, or a base64-encoded image when ``decode`` is true.
        text: Text to draw.
        position: ``'top-left'``, ``'top-right'``, ``'bottom-left'`` or ``'bottom-right'``.
        text_color: Fill colour of the text.
        box_color: RGBA fill of the box behind the text; the alpha component
            controls how see-through the box is.
        decode: Base64-decode ``img`` and open it with PIL before drawing.
        max_font_size: Largest font size tried.

    Returns:
        A new image, in the same mode as the input, with the box and text
        composited on top. The input image object is not modified.

    Raises:
        ValueError: If ``position`` is not one of the four supported corners.
    """
    corners = ('top-left', 'top-right', 'bottom-left', 'bottom-right')
    if position not in corners:
        raise ValueError("Position should be 'top-left', 'top-right', 'bottom-left' or 'bottom-right'.")

    # Convert the base64 string back to an image
    if decode:
        img = PIL.Image.open(BytesIO(base64.b64decode(img)))

    img_width, img_height = img.size

    # Draw on a transparent RGBA layer and alpha-composite it afterwards.
    # Drawing an RGBA fill directly onto an RGB image ignores the alpha
    # channel, which made the "semi-transparent" box fully opaque.
    base = img.convert('RGBA')
    overlay = PIL.Image.new('RGBA', base.size, (0, 0, 0, 0))
    draw = PIL.ImageDraw.Draw(overlay)

    def text_extent(size):
        """Return (font, right, bottom) of ``text`` drawn at (0, 0) in ``size``."""
        font = PIL.ImageFont.truetype(font_path, size)
        _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
        return font, right, bottom

    # Binary search for the largest fitting size; ``best`` always holds a
    # size verified to fit (or the fallback of 1).
    best, lo, hi = 1, 1, max_font_size
    while lo <= hi:
        mid = (lo + hi) // 2
        _, right, bottom = text_extent(mid)
        if right <= img_width and bottom <= img_height:
            best, lo = mid, mid + 1
        else:
            hi = mid - 1

    # Re-create the font at the chosen size. ``textbbox`` replaces
    # ``textsize``, which was removed in Pillow 10.
    font, text_width, text_height = text_extent(best)

    x = img_width - text_width if position.endswith('right') else 0
    y = img_height - text_height if position.startswith('bottom') else 0

    # Draw a semi-transparent box around the text, then the text itself
    draw.rectangle([x, y, x + text_width, y + text_height], fill=box_color)
    draw.text((x, y), text, font=font, fill=text_color)

    return PIL.Image.alpha_composite(base, overlay).convert(img.mode)


class WeatherDraw:
    """Turn a location's forecast into illustrated images and their prompts."""

    def clean_text(self, weather_info):
        """Ask the chat model for a short caption summarising ``weather_info``."""
        chat = Chat("Given the following weather conditions, write a very small, concise plaintext summary that will overlay on top of an image.")
        text = chat.message(str(weather_info))
        return text

    def generate_image(self, weather_info, **kwargs):
        """Generate one image for ``weather_info`` featuring a random animal.

        A chat model writes a scene description, which is combined with a
        random art style into the image prompt. ``kwargs`` go to ``Image.create``.

        Returns:
            Tuple of (base64-encoded image, prompt used).
        """
        animal = random.choice(animals)
        logger.info(f"Got animal {animal}")
        chat = Chat(f'''
Given the following weather conditions, write a plaintext, short, and vivid description of an
adorable {animal} in the weather conditions doing a human activity matching the weather.
Only write the short description and nothing else.
Do not include specific numbers.'''.replace('\n', ' '))
        description = chat.message(str(weather_info))
        prompt = f'{description}. Adorable, cute, 4k, Award winning, in the style of {random.choice(art_styles)}'
        logger.info(prompt)
        img = Image.create(prompt, **kwargs)
        return img["b64_json"], prompt

    def step_one_forecast(self, weather_info, **kwargs):
        """Return the (base64 image, prompt) pair for one forecast entry."""
        img, txt = self.generate_image(weather_info, **kwargs)
        return img, txt

    @staticmethod
    def _create_collage(images):
        """Paste ``images`` row by row into a near-square grid and return one RGB image.

        Every cell is as large as the largest input image.

        Raises:
            ValueError: If ``images`` is empty.
        """
        import math

        if not images:
            raise ValueError("No images to combine into a collage.")
        cols = math.ceil(math.sqrt(len(images)))
        rows = math.ceil(len(images) / cols)
        cell_w = max(im.width for im in images)
        cell_h = max(im.height for im in images)
        collage = PIL.Image.new('RGB', (cols * cell_w, rows * cell_h))
        for idx, im in enumerate(images):
            row, col = divmod(idx, cols)
            collage.paste(im, (col * cell_w, row * cell_h))
        return collage

    def step(self, zip_code='10001', **kwargs):
        """Build a labelled collage for every forecast entry of ``zip_code``.

        Returns:
            The collage image followed by one prompt string per forecast entry.
        """
        forecast = Weather(zip_code).get_info()
        images, texts = [], []
        # ``label`` rather than ``time`` so the ``time`` module is not shadowed.
        for label, data in forecast.items():
            img, txt = self.step_one_forecast(data, **kwargs)
            images.append(overlay_text_on_image(img, label, 'top-right', decode=True))
            texts.append(txt)
        # ``create_collage`` was called here but is defined nowhere in the
        # file (NameError); the private helper above provides it.
        return self._create_collage(images), *texts


# Define Gradio interface.
# Top-level components (``gr.Textbox``, ``gr.Image``) replace the deprecated
# ``gr.inputs`` / ``gr.outputs`` namespaces, which Gradio 4 removed.
# The four "text" outputs receive the per-forecast prompts returned by ``step``.
iface = gr.Interface(fn=WeatherDraw().step,
                     inputs=gr.Textbox(label="Enter Zipcode"),
                     outputs=[gr.Image(type='pil'), "text", "text", "text", "text"],
                     title="US Zipcode Weather",
                     description="Enter a US Zipcode and get some weather.")

# Run the interface
iface.launch()