JMalott commited on
Commit
21adb73
1 Parent(s): 96ce2b8

Update min_dalle/min_dalle.py

Browse files
Files changed (1) hide show
  1. min_dalle/min_dalle.py +2 -4
min_dalle/min_dalle.py CHANGED
@@ -237,13 +237,13 @@ class MinDalle:
237
  for i in range(IMAGE_TOKEN_COUNT):
238
  if(st.session_state.page != 0):
239
  break
 
240
  st.session_state.bar.progress(i/IMAGE_TOKEN_COUNT)
241
-
242
 
243
  #torch.cuda.empty_cache()
244
  #torch.cpu.empty_cache()
245
  #with torch.cuda.amp.autocast(dtype=self.dtype):
246
- image_tokens[i + 1], attention_state = self.decoder.forward(
247
  settings=settings,
248
  attention_mask=attention_mask,
249
  encoder_state=encoder_state,
@@ -276,8 +276,6 @@ class MinDalle:
276
  image = image.transpose(1, 0)
277
  image = image.reshape([grid_size ** 2, 2 ** 8, 2 ** 8, 3])
278
  yield image
279
- del image
280
-
281
 
282
  def generate_image(self, *args, **kwargs) -> Image.Image:
283
  image_stream = self.generate_image_stream(
 
237
  for i in range(IMAGE_TOKEN_COUNT):
238
  if(st.session_state.page != 0):
239
  break
240
+
241
  st.session_state.bar.progress(i/IMAGE_TOKEN_COUNT)
 
242
 
243
  #torch.cuda.empty_cache()
244
  #torch.cpu.empty_cache()
245
  #with torch.cuda.amp.autocast(dtype=self.dtype):
246
+ del image_tokens[i + 1], attention_state = self.decoder.forward(
247
  settings=settings,
248
  attention_mask=attention_mask,
249
  encoder_state=encoder_state,
 
276
  image = image.transpose(1, 0)
277
  image = image.reshape([grid_size ** 2, 2 ** 8, 2 ** 8, 3])
278
  yield image
 
 
279
 
280
  def generate_image(self, *args, **kwargs) -> Image.Image:
281
  image_stream = self.generate_image_stream(