Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,695 Bytes
b177a48 afdb110 b177a48 afdb110 b177a48 afdb110 b177a48 afdb110 1fe8637 b177a48 b257e01 afdb110 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
from openai import OpenAI
from PIL import Image
import requests
import io
import os
import base64
class OpenaiModel():
    """Callable wrapper that generates an image through the OpenAI images API.

    Only ``model_type == "text2image"`` is implemented. The concrete DALL-E
    variant is picked by looking for a marker substring in ``model_name``.
    """

    # (marker searched for in ``model_name``, API model id, requested size).
    # Entries are tried in order, so "Dalle-3" wins if both markers appear.
    _TEXT2IMAGE_VARIANTS = (
        ("Dalle-3", "dall-e-3", "1024x1024"),
        ("Dalle-2", "dall-e-2", "512x512"),
    )

    # Upper bound (seconds) on the image download so a stalled connection
    # cannot hang the caller indefinitely.
    _DOWNLOAD_TIMEOUT = 60

    def __init__(self, model_name, model_type):
        """Store the model selector strings; no client is created here."""
        self.model_name = model_name
        self.model_type = model_type

    def __call__(self, *args, **kwargs):
        """Generate one image for ``kwargs["prompt"]`` and return it.

        Positional arguments are accepted but ignored.

        Returns:
            The object produced by ``PIL.Image.open`` on the downloaded bytes.

        Raises:
            ValueError: if ``model_type`` is not ``"text2image"`` or the
                ``prompt`` keyword argument is missing.
            NotImplementedError: if ``model_name`` contains neither
                ``"Dalle-3"`` nor ``"Dalle-2"``.
            requests.HTTPError: if downloading the generated image fails.
        """
        if self.model_type != "text2image":
            raise ValueError("model_type must be text2image or image2image")

        # Raise explicitly rather than assert: asserts vanish under ``-O``,
        # which would turn a missing prompt into a bare KeyError further down.
        if "prompt" not in kwargs:
            raise ValueError("prompt is required for text2image model")

        # Resolve the variant before creating a client so that an unsupported
        # name fails fast, without needing credentials or network access.
        for marker, api_model, size in self._TEXT2IMAGE_VARIANTS:
            if marker in self.model_name:
                break
        else:
            raise NotImplementedError(
                f"unsupported text2image model: {self.model_name!r}"
            )

        # Built with no arguments, so it relies on the library's default
        # configuration (presumably the OPENAI_API_KEY env var - confirm).
        client = OpenAI()
        generation = client.images.generate(
            model=api_model,
            prompt=kwargs["prompt"],
            size=size,
            quality="standard",
            n=1,
        )
        result_url = generation.data[0].url

        download = requests.get(result_url, timeout=self._DOWNLOAD_TIMEOUT)
        # Surface HTTP failures here instead of letting PIL choke on an
        # error page with an unrelated "cannot identify image" message.
        download.raise_for_status()
        return Image.open(io.BytesIO(download.content))
def load_openai_model(model_name, model_type):
    """Build and return an ``OpenaiModel`` for the given name and type.

    Both arguments are forwarded unchanged to the ``OpenaiModel`` constructor.
    """
    model = OpenaiModel(model_name=model_name, model_type=model_type)
    return model
if __name__ == "__main__":
    # Manual smoke test: build the Dalle-3 text2image wrapper, request one
    # image for a fixed prompt, and print the returned image object.
    dalle3 = load_openai_model(model_name='Dalle-3', model_type='text2image')
    image = dalle3(prompt='draw a tiger')
    print(image)
|