Spaces:
Runtime error
Runtime error
Update min_dalle/min_dalle.py
Browse files- min_dalle/min_dalle.py +2 -4
min_dalle/min_dalle.py
CHANGED
@@ -237,13 +237,13 @@ class MinDalle:
|
|
237 |
for i in range(IMAGE_TOKEN_COUNT):
|
238 |
if(st.session_state.page != 0):
|
239 |
break
|
|
|
240 |
st.session_state.bar.progress(i/IMAGE_TOKEN_COUNT)
|
241 |
-
|
242 |
|
243 |
#torch.cuda.empty_cache()
|
244 |
#torch.cpu.empty_cache()
|
245 |
#with torch.cuda.amp.autocast(dtype=self.dtype):
|
246 |
-
image_tokens[i + 1], attention_state = self.decoder.forward(
|
247 |
settings=settings,
|
248 |
attention_mask=attention_mask,
|
249 |
encoder_state=encoder_state,
|
@@ -276,8 +276,6 @@ class MinDalle:
|
|
276 |
image = image.transpose(1, 0)
|
277 |
image = image.reshape([grid_size ** 2, 2 ** 8, 2 ** 8, 3])
|
278 |
yield image
|
279 |
-
del image
|
280 |
-
|
281 |
|
282 |
def generate_image(self, *args, **kwargs) -> Image.Image:
|
283 |
image_stream = self.generate_image_stream(
|
|
|
237 |
for i in range(IMAGE_TOKEN_COUNT):
|
238 |
if(st.session_state.page != 0):
|
239 |
break
|
240 |
+
|
241 |
st.session_state.bar.progress(i/IMAGE_TOKEN_COUNT)
|
|
|
242 |
|
243 |
#torch.cuda.empty_cache()
|
244 |
#torch.cpu.empty_cache()
|
245 |
#with torch.cuda.amp.autocast(dtype=self.dtype):
|
246 |
+
del image_tokens[i + 1], attention_state = self.decoder.forward(
|
247 |
settings=settings,
|
248 |
attention_mask=attention_mask,
|
249 |
encoder_state=encoder_state,
|
|
|
276 |
image = image.transpose(1, 0)
|
277 |
image = image.reshape([grid_size ** 2, 2 ** 8, 2 ** 8, 3])
|
278 |
yield image
|
|
|
|
|
279 |
|
280 |
def generate_image(self, *args, **kwargs) -> Image.Image:
|
281 |
image_stream = self.generate_image_stream(
|