darshan8950 committed
Commit 3e17de8
1 Parent(s): 791971b
Files changed (1)
  1. app.py +148 -0
app.py ADDED
@@ -0,0 +1,148 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ # Force JAX onto CPU; comment out the next line to run on GPU/TPU
+ import os
+ os.environ["JAX_PLATFORM_NAME"] = "cpu"
+
+ import random
+
+ import jax
+ import flax.linen as nn
+ from flax.training.common_utils import shard
+ from flax.jax_utils import replicate, unreplicate
+
+ from transformers import BartTokenizer, FlaxBartForConditionalGeneration
+
+ from PIL import Image, ImageDraw, ImageFont
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+ from vqgan_jax.modeling_flax_vqgan import VQModel
+ from dalle_mini.model import CustomFlaxBartForConditionalGeneration
+
+ # ## CLIP Scoring
+ from transformers import CLIPProcessor, FlaxCLIPModel
+
+ import gradio as gr
+
+ from dalle_mini.helpers import captioned_strip
+
+
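+ # Load the tokenizer, DALL·E mini (BART) model, and VQGAN decoder at pinned commit revisions.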
+ DALLE_REPO = 'flax-community/dalle-mini'
+ DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'
+
+ VQGAN_REPO = 'flax-community/vqgan_f16_16384'
+ VQGAN_COMMIT_ID = '90cc46addd2dd8f5be21586a9a23e1b95aa506a9'
+
+ tokenizer = BartTokenizer.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
+ model = CustomFlaxBartForConditionalGeneration.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
+ vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
+
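+ # Map a decoded array with values in [0, 1] to an RGB PIL image.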
+ def custom_to_pil(x):
+     x = np.clip(x, 0., 1.)
+     x = (255 * x).astype(np.uint8)
+     x = Image.fromarray(x)
+     if not x.mode == "RGB":
+         x = x.convert("RGB")
+     return x
+
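+ # generate() samples sequences of image-token indices with the BART decoder;
+ # get_images() decodes those indices back into pixels with the VQGAN.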
+ def generate(input, rng, params):
+     return model.generate(
+         **input,
+         max_length=257,
+         num_beams=1,
+         do_sample=True,
+         prng_key=rng,
+         eos_token_id=50000,
+         pad_token_id=50000,
+         params=params,
+     )
+
+ def get_images(indices, params):
+     return vqgan.decode_code(indices, params=params)
+
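+ # Parallelize generation and decoding across all available devices.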
+ p_generate = jax.pmap(generate, "batch")
+ p_get_images = jax.pmap(get_images, "batch")
+
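+ # Replicate the model parameters on each device for pmap.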
+ bart_params = replicate(model.params)
+ vqgan_params = replicate(vqgan.params)
+
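+ # CLIP model and processor, used to score generated images against the prompt.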
+ clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ print("Initialized FlaxCLIPModel")
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+ print("Initialized CLIPProcessor")
+
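+ # Generate num_images candidates for a prompt; num_images should be a multiple of jax.device_count().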
+ def hallucinate(prompt, num_images=64):
+     prompt = [prompt] * jax.device_count()
+     inputs = tokenizer(prompt, return_tensors='jax', padding="max_length", truncation=True, max_length=128).data
+     inputs = shard(inputs)
+
+     all_images = []
+     for i in range(num_images // jax.device_count()):
+         key = random.randint(0, int(1e7))
+         rng = jax.random.PRNGKey(key)
+         rngs = jax.random.split(rng, jax.local_device_count())
+         indices = p_generate(inputs, rngs, bart_params).sequences
+         indices = indices[:, :, 1:]
+
+         images = p_get_images(indices, vqgan_params)
+         images = np.squeeze(np.asarray(images), 1)
+         for image in images:
+             all_images.append(custom_to_pil(image))
+     return all_images
+
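+ # Keep the k generated images that CLIP scores as closest to the prompt.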
+ def clip_top_k(prompt, images, k=8):
+     inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
+     outputs = clip(**inputs)
+     logits = outputs.logits_per_text
+     scores = np.array(logits[0]).argsort()[-k:][::-1]
+     return [images[score] for score in scores]
+
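+ # Paste images side by side with an optional caption strip (defined but unused;
+ # captioned_strip from dalle_mini.helpers is used instead).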
+ def compose_predictions(images, caption=None):
+     increased_h = 0 if caption is None else 48
+     w, h = images[0].size[0], images[0].size[1]
+     img = Image.new("RGB", (len(images) * w, h + increased_h))
+     for i, img_ in enumerate(images):
+         img.paste(img_, (i * w, increased_h))
+
+     if caption is not None:
+         draw = ImageDraw.Draw(img)
+         font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
+         draw.text((20, 3), caption, (255, 255, 255), font=font)
+     return img
+
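+ # Generate num_candidates images and keep the top k according to CLIP.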
+ def top_k_predictions(prompt, num_candidates=32, k=8):
+     images = hallucinate(prompt, num_images=num_candidates)
+     images = clip_top_k(prompt, images, k=k)
+     return images
+
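+ # Gradio callback: returns an HTML title and a captioned strip of the best predictions.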
+ def run_inference(prompt, num_images=32, num_preds=8):
+     images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
+     predictions = captioned_strip(images)
+     output_title = f"""
+     <b>{prompt}</b>
+     """
+     return (output_title, predictions)
+
+ outputs = [
+     gr.outputs.HTML(label=""),  # To be used as title
+     gr.outputs.Image(label=''),
+ ]
+
+ description = """
+ DALL·E mini is an AI model that generates images from any text prompt you give it:
+ """
+ gr.Interface(run_inference,
+              inputs=[gr.inputs.Textbox(label='What do you want to see?')],
+              outputs=outputs,
+              title='DALL·E mini',
+              description=description,
+              article="<p style='text-align: center'> Created by Boris Dayma et al. 2021 | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a> | <a href='https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA'>Report</a></p>",
+              layout='vertical',
+              theme='huggingface',
+              examples=[['an armchair in the shape of an avocado'], ['snowy mountains by the sea']],
+              allow_flagging=False,
+              live=False,
+              # server_port=8999
+              ).launch(share=True)
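+
+ # A minimal sketch of programmatic use, bypassing the Gradio UI
+ # (assumes the models above are already loaded):
+ #   images = top_k_predictions("an armchair in the shape of an avocado", num_candidates=32, k=8)
+ #   images[0].save("sample.png")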