JMalott committed
Commit 96ce2b8
1 Parent(s): a3f68a6

Update min_dalle/min_dalle.py

Files changed (1):
  min_dalle/min_dalle.py  +22 -22
min_dalle/min_dalle.py CHANGED

@@ -17,7 +17,7 @@ torch.backends.cudnn.enabled = True
 torch.backends.cudnn.allow_tf32 = True
 
 MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
-IMAGE_TOKEN_COUNT = 240
+IMAGE_TOKEN_COUNT = 256
 
 
 class MinDalle:
@@ -239,28 +239,27 @@ class MinDalle:
                 break
             st.session_state.bar.progress(i/IMAGE_TOKEN_COUNT)
 
-            try:
-                #torch.cuda.empty_cache()
-                #torch.cpu.empty_cache()
-                #with torch.cuda.amp.autocast(dtype=self.dtype):
-                image_tokens[i + 1], attention_state = self.decoder.forward(
-                    settings=settings,
-                    attention_mask=attention_mask,
-                    encoder_state=encoder_state,
-                    attention_state=attention_state,
-                    prev_tokens=image_tokens[i],
-                    token_index=token_indices[[i]]
+
+            #torch.cuda.empty_cache()
+            #torch.cpu.empty_cache()
+            #with torch.cuda.amp.autocast(dtype=self.dtype):
+            image_tokens[i + 1], attention_state = self.decoder.forward(
+                settings=settings,
+                attention_mask=attention_mask,
+                encoder_state=encoder_state,
+                attention_state=attention_state,
+                prev_tokens=image_tokens[i],
+                token_index=token_indices[[i]]
+            )
+
+            # with torch.cuda.amp.autocast(dtype=torch.float32):
+            if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 == 256:
+                yield self.image_grid_from_tokens(
+                    image_tokens=image_tokens[1:].T,
+                    is_seamless=is_seamless,
+                    is_verbose=is_verbose
                 )
-
-                # with torch.cuda.amp.autocast(dtype=torch.float32):
-                if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 == 256:
-                    yield self.image_grid_from_tokens(
-                        image_tokens=image_tokens[1:].T,
-                        is_seamless=is_seamless,
-                        is_verbose=is_verbose
-                    )
-            except Exception as e:
-                print(e)
+
 
     def generate_image_stream(self, *args, **kwargs) -> Iterator[Image.Image]:
         image_stream = self.generate_raw_image_stream(*args, **kwargs)
@@ -277,6 +276,7 @@ class MinDalle:
             image = image.transpose(1, 0)
            image = image.reshape([grid_size ** 2, 2 ** 8, 2 ** 8, 3])
             yield image
+            del image
 
 
     def generate_image(self, *args, **kwargs) -> Image.Image:
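
Note on the IMAGE_TOKEN_COUNT change: each 256x256 output image is decoded from a 16x16 = 256 grid of image tokens, and the stream only emits the final grid when i + 1 == 256. The sketch below is not part of the commit; it reproduces only the yield condition from the diff, assuming the surrounding loop runs for i in range(IMAGE_TOKEN_COUNT) as in upstream min-dalle, to show how progressive outputs are paced and why the old value of 240 would never reach the final yield.

    # Minimal sketch (illustrative only): mirrors the yield condition in the diff.
    IMAGE_TOKEN_COUNT = 256  # one token per cell of the 16x16 grid behind a 256x256 image

    def yield_points(progressive_outputs: bool, token_count: int = IMAGE_TOKEN_COUNT) -> list:
        """Token indices (1-based) at which the stream would yield an image grid."""
        points = []
        for i in range(token_count):
            # same condition as the diff: every 32 tokens while streaming, always at token 256
            if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 == 256:
                points.append(i + 1)
        return points

    print(yield_points(True))        # [32, 64, 96, 128, 160, 192, 224, 256]
    print(yield_points(False))       # [256]
    print(yield_points(True, 240))   # [32, 64, 96, 128, 160, 192, 224] -- the old 240 never reaches token 256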