Spaces:
Sleeping
Sleeping
import requests | |
import structlog | |
import openai | |
import os | |
import random | |
import tiktoken | |
import enum | |
import time | |
import retrying | |
import IPython.display as display | |
from base64 import b64decode | |
import base64 | |
from io import BytesIO | |
import PIL | |
import PIL.Image | |
import PIL.ImageDraw | |
import PIL.ImageFont | |
import gradio as gr | |
import cachetools.func | |
logger = structlog.getLogger()
# A missing WEATHER_API variable raises KeyError at import time; the OpenAI
# key is optional here and is set to None when OPENAI_KEY is unset.
weather_api_key = os.environ['WEATHER_API']
openai.api_key = os.environ.get("OPENAI_KEY", None)


def _read_nonblank_lines(path):
    """Return the stripped, non-empty lines of the UTF-8 text file at *path*.

    Blank lines are dropped so that ``random.choice`` over the result can
    never yield an empty string.
    """
    with open(path, encoding='utf-8') as f:
        stripped = (line.strip() for line in f)
        return [line for line in stripped if line]


animals = _read_nonblank_lines('animals.txt')
art_styles = _read_nonblank_lines('art_styles.txt')
class Chat:
    """Stateful chat session backed by ``openai.ChatCompletion``.

    The conversation history always begins with the system prompt. Before
    each request, the oldest non-system turns are dropped until the history
    fits within ``max_length`` tokens.
    """

    class Model(enum.Enum):
        # Model identifiers usable as ``_msg``'s ``model`` keyword.
        GPT3_5 = "gpt-3.5-turbo"
        GPT_4 = "gpt-4"

    def __init__(self, system, max_length=4096//2):
        """Start a conversation.

        Args:
            system: System prompt sent as the first message of every request.
            max_length: Token budget for the request history; older turns are
                trimmed once it is exceeded. (Presumably half of a 4096-token
                context so the reply has room — confirm.)
        """
        self._system = system
        self._max_length = max_length
        self._history = [
            {"role": "system", "content": self._system},
        ]

    @classmethod
    def num_tokens_from_text(cls, text, model="gpt-3.5-turbo"):
        """Returns the number of tokens used by some text."""
        encoding = tiktoken.encoding_for_model(model)
        return len(encoding.encode(text))

    @classmethod
    def num_tokens_from_messages(cls, messages, model="gpt-3.5-turbo"):
        """Returns the number of tokens used by a list of messages."""
        encoding = tiktoken.encoding_for_model(model)
        num_tokens = 0
        for message in messages:
            num_tokens += 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
            for key, value in message.items():
                num_tokens += len(encoding.encode(value))
                if key == "name":  # if there's a name, the role is omitted
                    num_tokens += -1  # role is always required and always 1 token
        num_tokens += 2  # every reply is primed with <im_start>assistant
        return num_tokens

    def _msg(self, *args, model=Model.GPT3_5.value, **kwargs):
        """Send the current history to the ChatCompletion API; return the raw response."""
        return openai.ChatCompletion.create(
            *args,
            model=model,
            messages=self._history,
            **kwargs
        )

    def _trim_history(self):
        """Drop the oldest turns until the history fits ``self._max_length`` tokens.

        The system prompt (index 0) and the newest message are always kept, so
        a single oversized message is still sent rather than silently discarded.
        """
        if len(self._history) <= 2:
            return  # nothing droppable; skip tokenizing entirely
        total = self.num_tokens_from_messages(self._history)
        while len(self._history) > 2 and total > self._max_length:
            popped = self._history.pop(1)
            # num_tokens_from_messages adds the 2 reply-priming tokens once per
            # call, so the popped message's own cost is its count minus 2.
            total -= self.num_tokens_from_messages([popped]) - 2
            logger.info(f'Popping message: {popped}')

    def message(self, next_msg=None, **kwargs):
        """Append ``next_msg`` as a user turn (if given), query the model, return its reply.

        Extra keyword arguments are forwarded through ``_msg`` to
        ``openai.ChatCompletion.create`` (e.g. ``model=Chat.Model.GPT_4.value``).
        The reply is appended to the history as an assistant turn.
        """
        if next_msg is not None:
            self._history.append({"role": "user", "content": next_msg})
        # Trim after appending so the new user turn counts toward the budget.
        self._trim_history()
        logger.info('requesting openai...')
        resp = self._msg(**kwargs)
        logger.info('received openai...')
        text = resp.choices[0].message.content
        self._history.append({"role": "assistant", "content": text})
        return text
class Weather:
    """Client for the weatherapi.com forecast endpoint, bound to one location."""

    def __init__(self, zip_code='10001', api_key=weather_api_key):
        """
        Args:
            zip_code: Location query sent as the ``q`` parameter.
            api_key: weatherapi.com key; defaults to the module-level
                ``weather_api_key`` read from the environment at import time.
        """
        self.zip_code = zip_code
        self.api_key = api_key

    def get_weather(self, timeout=10):
        """Fetch a one-day forecast (with air-quality data) as decoded JSON.

        Query parameters go through ``params`` so the user-supplied zip code
        is URL-encoded rather than spliced into the URL.

        Args:
            timeout: Seconds to wait for the HTTP response before giving up.

        Raises:
            requests.HTTPError: if the API answers with an error status.
        """
        url = "https://api.weatherapi.com/v1/forecast.json"
        params = {
            "q": self.zip_code,
            "days": 1,
            "lang": "en",
            "aqi": "yes",
            "key": self.api_key,
        }
        headers = {'accept': 'application/json'}
        resp = requests.get(url, params=params, headers=headers, timeout=timeout)
        resp.raise_for_status()
        return resp.json()

    def get_info(self):
        """Summarize today's forecast into four snapshots.

        Returns:
            dict with keys ``now`` (the ``current`` block), ``day`` (today's
            ``day`` aggregate), and ``curr_hour`` / ``next_hour`` — the first
            and last hourly entries whose ``time_epoch`` lies within one hour
            of now. Both are ``None`` if no hourly entry is that close, and
            they are the same entry if only one is.
        """
        weather = self.get_weather()
        today = weather["forecast"]["forecastday"][0]
        now = time.time()  # sample the clock once so every hour is compared to the same instant
        nearby = [hour for hour in today["hour"] if abs(hour["time_epoch"] - now) < 60 * 60]
        return {
            "now": weather["current"],
            "day": today["day"],
            "curr_hour": nearby[0] if nearby else None,
            "next_hour": nearby[-1] if nearby else None,
        }
class Image:
    """Wrapper around the OpenAI image-generation endpoint."""

    class Size(enum.Enum):
        SMALL = "256x256"
        MEDIUM = "512x512"
        LARGE = "1024x1024"

    @classmethod
    def create(cls, prompt, n=1, size=Size.SMALL):
        """Generate ``n`` images for ``prompt``, requesting base64 JSON payloads.

        Args:
            prompt: Text description of the image.
            n: Number of images to request.
            size: An ``Image.Size`` member or its string value (e.g. ``"512x512"``).

        Returns:
            The single entry of the response's ``data`` list when ``n == 1``,
            otherwise the whole list. Entries presumably carry a ``b64_json``
            field given ``response_format='b64_json'``.

        Raises:
            ValueError: if ``size`` is not a valid ``Image.Size`` member or value.
        """
        size = cls.Size(size)  # normalizes a raw string to its member; members pass through
        logger.info('requesting openai.Image...')
        resp = openai.Image.create(prompt=prompt, n=n, size=size.value, response_format='b64_json')
        logger.info('received openai.Image...')
        if n == 1:
            return resp["data"][0]
        return resp["data"]
def overlay_text_on_image(img, text, position, text_color=(255, 255, 255), box_color=(0, 0, 0, 128), decode=False,
                          font_path="/System/Library/Fonts/NewYork.ttf", max_font_size=50):
    """Draw ``text`` in a corner of ``img`` on top of a filled backing box.

    The largest font size in ``[1, max_font_size]`` whose rendered text fits
    inside the image is chosen by binary search.

    Args:
        img: A PIL image, or a base64-encoded image when ``decode`` is true.
        text: The label to draw.
        position: One of 'top-left', 'top-right', 'bottom-left', 'bottom-right'.
        text_color: Fill color for the text.
        box_color: Fill color for the backing box. On RGB images its alpha
            component is blended into the pixels, making the box translucent.
        decode: Treat ``img`` as a base64 string and decode it first.
        font_path: TrueType font to use. If it cannot be loaded (the default is
            a macOS system font path), Pillow's bundled default font is used
            instead, which does not honor the requested size.
        max_font_size: Upper bound for the font-size search.

    Returns:
        The annotated image. When ``decode`` is false this is ``img`` itself,
        modified in place.

    Raises:
        ValueError: if ``position`` is not one of the four corners.
    """
    if position not in ('top-left', 'top-right', 'bottom-left', 'bottom-right'):
        raise ValueError("Position should be 'top-left', 'top-right', 'bottom-left' or 'bottom-right'.")

    # Convert the base64 string back to an image
    if decode:
        img = PIL.Image.open(BytesIO(base64.b64decode(img)))
    img_width, img_height = img.size
    # For RGB images an "RGBA" drawing context blends each fill's alpha into
    # the pixels; Pillow requires other image modes to draw in their own mode.
    draw = PIL.ImageDraw.Draw(img, "RGBA" if img.mode == "RGB" else None)

    def load_font(size):
        # Fall back to Pillow's bundled font instead of crashing when the
        # TrueType file is unavailable on this host.
        try:
            return PIL.ImageFont.truetype(font_path, size)
        except OSError:
            return PIL.ImageFont.load_default()

    def fits(candidate):
        left, top, right, bottom = draw.textbbox((0, 0), text, font=candidate)
        return right - left <= img_width and bottom - top <= img_height

    # Binary search for the largest font size that still fits the image.
    font = None
    lo, hi = 1, max_font_size
    while lo <= hi:
        size = (lo + hi) // 2
        candidate = load_font(size)
        if fits(candidate):
            font, lo = candidate, size + 1
        else:
            hi = size - 1
    if font is None:  # nothing fits; use the smallest size rather than failing
        font = load_font(1)

    # textbbox's (left, top) is the glyphs' offset from the drawing origin;
    # subtracting it when drawing makes the text land exactly inside the box.
    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    text_width, text_height = right - left, bottom - top
    x = 0 if position.endswith('left') else img_width - text_width
    y = 0 if position.startswith('top') else img_height - text_height

    # Draw the backing box, then the text on top of it
    draw.rectangle([x, y, x + text_width, y + text_height], fill=box_color)
    draw.text((x - left, y - top), text, font=font, fill=text_color)
    return img
class WeatherDraw:
    """Turn a zip code's forecast into labelled, AI-generated animal illustrations."""

    def clean_text(self, weather_info):
        """Ask the chat model for a very short plaintext summary of ``weather_info``."""
        chat = Chat("Given the following weather conditions, write a very small, concise plaintext summary that will overlay on top of an image.")
        return chat.message(str(weather_info))

    def generate_image(self, weather_info, **kwargs):
        """Describe a random animal enjoying ``weather_info`` and render it.

        A random entry of ``animals`` seeds the chat prompt, and a random entry
        of ``art_styles`` is appended to the image prompt. ``kwargs`` are
        forwarded to ``Image.create``.

        Returns:
            ``(b64_image, prompt)`` — the ``b64_json`` payload and the prompt used.
        """
        animal = random.choice(animals)
        logger.info(f"Got animal {animal}")
        chat = Chat(
            " Given the following weather conditions, write a plaintext, short, and vivid description of an"
            f" adorable {animal} in the weather conditions doing a human activity matching the weather."
            " Only write the short description and nothing else."
            " Do not include specific numbers."
        )
        description = chat.message(str(weather_info))
        prompt = f'{description}. Adorable, cute, 4k, Award winning, in the style of {random.choice(art_styles)}'
        logger.info(prompt)
        img = Image.create(prompt, **kwargs)
        return img["b64_json"], prompt

    def step_one_forecast(self, weather_info, **kwargs):
        """Return ``(b64_image, prompt)`` for a single forecast snapshot."""
        return self.generate_image(weather_info, **kwargs)

    def step(self, zip_code='10001', **kwargs):
        """Generate one labelled image per forecast snapshot for ``zip_code``.

        Returns:
            ``(collage, *prompts)`` — a collage of the images, each labelled
            with its snapshot key in the top-right corner, followed by the
            image prompts in the same order as ``Weather.get_info``'s keys.
        """
        # A Gradio textbox submits "" when left empty; fall back to the default zip.
        zip_code = str(zip_code or '').strip() or '10001'
        forecast = Weather(zip_code).get_info()
        images, prompts = [], []
        for label, data in forecast.items():  # ``label``, not ``time``: don't shadow the module
            b64_image, prompt = self.step_one_forecast(data, **kwargs)
            images.append(overlay_text_on_image(b64_image, label, 'top-right', decode=True))
            prompts.append(prompt)
        # NOTE(review): create_collage is not defined in this file as shown —
        # confirm it is provided elsewhere, otherwise this raises NameError.
        return create_collage(*images), *prompts
# Define Gradio interface: a zip-code textbox in; the collage image plus one
# text field per forecast snapshot out (matching WeatherDraw.step's return).
# gr.Textbox / gr.Image replace the deprecated gr.inputs / gr.outputs namespaces.
iface = gr.Interface(
    fn=WeatherDraw().step,
    inputs=gr.Textbox(label="Enter Zipcode"),
    outputs=[gr.Image(type='pil'), "text", "text", "text", "text"],
    title="US Zipcode Weather",
    description="Enter a US Zipcode and get some weather.",
)

# Run the interface only when executed as a script, not on import
if __name__ == "__main__":
    iface.launch()