JMalott commited on
Commit
05397a1
1 Parent(s): d1ba29f

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +30 -26
utils.py CHANGED
@@ -66,7 +66,6 @@ def layout(*args):
66
  st.markdown(str(foot), unsafe_allow_html=True)
67
 
68
 
69
-
70
  def footer():
71
  myargs = [
72
  "Created by ",
@@ -96,7 +95,6 @@ def footer():
96
  height=600,
97
  )
98
 
99
-
100
  model = False
101
  def generate(prompt,crazy,k):
102
  global model
@@ -113,7 +111,11 @@ def generate(prompt,crazy,k):
113
  set_seed(np.random.randint(0,10000))
114
 
115
  # Sampling
116
- images = model.sampling(prompt=prompt,
 
 
 
 
117
  top_k=2048,
118
  top_p=None,
119
  softmax_temperature=crazy,
@@ -124,7 +126,7 @@ def generate(prompt,crazy,k):
124
  # CLIP Re-ranking
125
  model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
126
  model_clip.to(device=device)
127
- rank = clip_score(prompt=prompt,
128
  images=images,
129
  model_clip=model_clip,
130
  preprocess_clip=preprocess_clip,
@@ -143,35 +145,37 @@ def generate(prompt,crazy,k):
143
 
144
  def drawGrid():
145
  master = {}
146
- order = 0
147
-
148
- #print(st.session_state.results)
149
 
150
  for r in st.session_state.results[::-1]:
151
  _txt = r['prompt']+" "+str(r['crazy'])+" "+str(r['k'])
152
-
153
  if(_txt not in master):
154
  master[_txt] = [r]
155
- order += 1
156
  else:
157
  master[_txt].append(r)
158
 
159
-
160
- for m in master:
161
- #with placeholder.container():
 
 
 
 
 
 
162
 
163
- txt = master[m][0]['prompt']+" (temperature:"+ str(master[m][0]['crazy']) + ", top k:" + str(master[m][0]['k']) + ")"
164
- st.subheader(txt)
165
- col1, col2, col3 = st.columns(3)
166
-
167
- for ix, item in enumerate(master[m]):
168
- if ix % 3 == 0:
169
- with col1:
170
- st.image(item["image"])
171
- if ix % 3 == 1:
172
- with col2:
173
- st.image(item["image"])
174
- if ix % 3 == 2:
175
- with col3:
176
- st.image(item["image"])
 
177
 
 
66
  st.markdown(str(foot), unsafe_allow_html=True)
67
 
68
 
 
69
  def footer():
70
  myargs = [
71
  "Created by ",
 
95
  height=600,
96
  )
97
 
 
98
  model = False
99
  def generate(prompt,crazy,k):
100
  global model
 
111
  set_seed(np.random.randint(0,10000))
112
 
113
  # Sampling
114
+ newPrompt = prompt
115
+ if("architecture" not in prompt.lower() ):
116
+ newPrompt += " architecture"
117
+
118
+ images = model.sampling(prompt=newPrompt,
119
  top_k=2048,
120
  top_p=None,
121
  softmax_temperature=crazy,
 
126
  # CLIP Re-ranking
127
  model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
128
  model_clip.to(device=device)
129
+ rank = clip_score(prompt=newPrompt,
130
  images=images,
131
  model_clip=model_clip,
132
  preprocess_clip=preprocess_clip,
 
145
 
146
  def drawGrid():
147
  master = {}
 
 
 
148
 
149
  for r in st.session_state.results[::-1]:
150
  _txt = r['prompt']+" "+str(r['crazy'])+" "+str(r['k'])
 
151
  if(_txt not in master):
152
  master[_txt] = [r]
 
153
  else:
154
  master[_txt].append(r)
155
 
156
+
157
+ for i in st.session_state.images:
158
+ im = st.empty()
159
+
160
+
161
+ placeholder = st.empty()
162
+ with placeholder.container():
163
+
164
+ for m in master:
165
 
166
+ txt = master[m][0]['prompt']+" (temperature:"+ str(master[m][0]['crazy']) + ", top k:" + str(master[m][0]['k']) + ")"
167
+ st.subheader(txt)
168
+ col1, col2, col3 = st.columns(3)
169
+
170
+ for ix, item in enumerate(master[m]):
171
+ if ix % 3 == 0:
172
+ with col1:
173
+ st.session_state.images.append(st.image(item["image"]))
174
+ if ix % 3 == 1:
175
+ with col2:
176
+ st.session_state.images.append(st.image(item["image"]))
177
+ if ix % 3 == 2:
178
+ with col3:
179
+ st.session_state.images.append(st.image(item["image"]))
180
+
181