Commit
·
9660022
1
Parent(s):
02ead63
Create class_name
Browse files- class_name +31 -0
class_name
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
from io import BytesIO
|
5 |
+
from scipy.stats import truncnorm
|
6 |
+
from skimage.transform import resize
|
7 |
+
from transformers import CLIPProcessor, CLIPModel
|
8 |
+
|
9 |
+
class TextToImageGenerator:
|
10 |
+
def __init__(self):
|
11 |
+
self.clip = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
|
12 |
+
self.processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
|
13 |
+
self.generator = tf.keras.models.load_model('path/to/generator/model')
|
14 |
+
|
15 |
+
def generate_image(self, prompt):
|
16 |
+
encoded_prompt = self.processor(prompt, return_tensors="tf").to_dict()
|
17 |
+
noise = tf.random.normal([1, 256])
|
18 |
+
text_features = self.clip.get_text_features(encoded_prompt)
|
19 |
+
image_features = self.generator([text_features, noise], training=False)[0]
|
20 |
+
image = self._postprocess_image(image_features)
|
21 |
+
return image
|
22 |
+
|
23 |
+
def _postprocess_image(self, image_features):
|
24 |
+
image_features = (image_features + 1) / 2 # scale from [-1, 1] to [0, 1]
|
25 |
+
image_features = np.clip(image_features, 0, 1) # clip any values outside of [0, 1]
|
26 |
+
image = Image.fromarray(np.uint8(image_features * 255))
|
27 |
+
image = image.resize((256, 256))
|
28 |
+
image_buffer = BytesIO()
|
29 |
+
image.save(image_buffer, format='JPEG')
|
30 |
+
image_data = image_buffer.getvalue()
|
31 |
+
return image_data
|