JMalott commited on
Commit
eda6154
1 Parent(s): cc89b51

Update min_dalle/min_dalle.py

Browse files
Files changed (1) hide show
  1. min_dalle/min_dalle.py +2 -4
min_dalle/min_dalle.py CHANGED
@@ -238,8 +238,6 @@ class MinDalle:
238
  if(st.session_state.page != 0):
239
  break
240
  st.session_state.bar.progress(i/IMAGE_TOKEN_COUNT)
241
-
242
- print(i)
243
 
244
  torch.cuda.empty_cache()
245
  with torch.cuda.amp.autocast(dtype=self.dtype):
@@ -252,8 +250,8 @@ class MinDalle:
252
  token_index=token_indices[[i]]
253
  )
254
 
255
- with torch.cuda.amp.autocast(dtype=torch.float32):
256
- if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 == 256:
257
  yield self.image_grid_from_tokens(
258
  image_tokens=image_tokens[1:].T,
259
  is_seamless=is_seamless,
 
238
  if(st.session_state.page != 0):
239
  break
240
  st.session_state.bar.progress(i/IMAGE_TOKEN_COUNT)
 
 
241
 
242
  torch.cuda.empty_cache()
243
  with torch.cuda.amp.autocast(dtype=self.dtype):
 
250
  token_index=token_indices[[i]]
251
  )
252
 
253
+ with torch.cuda.amp.autocast(dtype=torch.float16):
254
+ if ((i + 1) % 16 == 0 and progressive_outputs) or i + 1 == 256:
255
  yield self.image_grid_from_tokens(
256
  image_tokens=image_tokens[1:].T,
257
  is_seamless=is_seamless,