JMalott committed on
Commit
5e84d25
1 Parent(s): 03dd743

Update min_dalle/min_dalle.py

Browse files
Files changed (1) hide show
  1. min_dalle/min_dalle.py +6 -6
min_dalle/min_dalle.py CHANGED
@@ -17,7 +17,7 @@ torch.backends.cudnn.enabled = True
17
  torch.backends.cudnn.allow_tf32 = True
18
 
19
  MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
20
- IMAGE_TOKEN_COUNT = 128
21
 
22
 
23
  class MinDalle:
@@ -177,7 +177,7 @@ class MinDalle:
177
  progressive_outputs: bool = False,
178
  is_seamless: bool = False,
179
  temperature: float = 1,
180
- top_k: int = 128,
181
  supercondition_factor: int = 16,
182
  is_verbose: bool = False
183
  ) -> Iterator[FloatTensor]:
@@ -239,8 +239,8 @@ class MinDalle:
239
  break
240
  st.session_state.bar.progress(i/IMAGE_TOKEN_COUNT)
241
 
242
- #torch.cuda.empty_cache()
243
- #torch.device('cpu').empty_cache()
244
  with torch.cuda.amp.autocast(dtype=self.dtype):
245
  image_tokens[i + 1], attention_state = self.decoder.forward(
246
  settings=settings,
@@ -252,7 +252,7 @@ class MinDalle:
252
  )
253
 
254
  with torch.cuda.amp.autocast(dtype=torch.float32):
255
- if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 == 128:
256
  yield self.image_grid_from_tokens(
257
  image_tokens=image_tokens[1:].T,
258
  is_seamless=is_seamless,
@@ -270,7 +270,7 @@ class MinDalle:
270
  image_stream = self.generate_raw_image_stream(*args, **kwargs)
271
  for image in image_stream:
272
  grid_size = kwargs['grid_size']
273
- image = image.view([grid_size * 128, grid_size, 128, 3])
274
  image = image.transpose(1, 0)
275
  image = image.reshape([grid_size ** 2, 2 ** 8, 2 ** 8, 3])
276
  yield image
 
17
  torch.backends.cudnn.allow_tf32 = True
18
 
19
  MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
20
+ IMAGE_TOKEN_COUNT = 256
21
 
22
 
23
  class MinDalle:
 
177
  progressive_outputs: bool = False,
178
  is_seamless: bool = False,
179
  temperature: float = 1,
180
+ top_k: int = 256,
181
  supercondition_factor: int = 16,
182
  is_verbose: bool = False
183
  ) -> Iterator[FloatTensor]:
 
239
  break
240
  st.session_state.bar.progress(i/IMAGE_TOKEN_COUNT)
241
 
242
+ #torch.cuda.empty_cache()
243
+ #torch.cpu.empty_cache()
244
  with torch.cuda.amp.autocast(dtype=self.dtype):
245
  image_tokens[i + 1], attention_state = self.decoder.forward(
246
  settings=settings,
 
252
  )
253
 
254
  with torch.cuda.amp.autocast(dtype=torch.float32):
255
+ if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 == 256:
256
  yield self.image_grid_from_tokens(
257
  image_tokens=image_tokens[1:].T,
258
  is_seamless=is_seamless,
 
270
  image_stream = self.generate_raw_image_stream(*args, **kwargs)
271
  for image in image_stream:
272
  grid_size = kwargs['grid_size']
273
+ image = image.view([grid_size * 256, grid_size, 256, 3])
274
  image = image.transpose(1, 0)
275
  image = image.reshape([grid_size ** 2, 2 ** 8, 2 ** 8, 3])
276
  yield image