JMalott committed on
Commit
0f96e76
1 Parent(s): dfa8d4a

Update min_dalle/min_dalle.py

Browse files
Files changed (1) hide show
  1. min_dalle/min_dalle.py +17 -22
min_dalle/min_dalle.py CHANGED
@@ -10,12 +10,11 @@ from typing import Iterator
10
  from .text_tokenizer import TextTokenizer
11
  from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer
12
  import streamlit as st
13
- import gc
14
 
15
  torch.set_grad_enabled(False)
16
  torch.set_num_threads(os.cpu_count())
17
- torch.backends.cudnn.enabled = False
18
- torch.backends.cudnn.allow_tf16 = False
19
 
20
  MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
21
  IMAGE_TOKEN_COUNT = 256
@@ -25,7 +24,7 @@ class MinDalle:
25
  def __init__(
26
  self,
27
  models_root: str = 'pretrained',
28
- dtype: torch.dtype = torch.float16,
29
  device: str = None,
30
  is_mega: bool = True,
31
  is_reusable: bool = True,
@@ -188,7 +187,7 @@ class MinDalle:
188
  if len(tokens) > self.text_token_count:
189
  tokens = tokens[:self.text_token_count]
190
  if is_verbose: print("{} text tokens".format(len(tokens)), tokens)
191
- text_tokens = numpy.ones((2, 64), dtype=numpy.int16)
192
  text_tokens[0, :2] = [tokens[0], tokens[-1]]
193
  text_tokens[1, :len(tokens)] = tokens
194
  text_tokens = torch.tensor(
@@ -232,37 +231,33 @@ class MinDalle:
232
  token_indices = torch.arange(IMAGE_TOKEN_COUNT, device=self.device)
233
  settings = torch.tensor(
234
  [temperature, top_k, supercondition_factor],
235
- dtype=torch.float16,
236
  device=self.device
237
  )
238
-
239
-
240
  for i in range(IMAGE_TOKEN_COUNT):
241
-
242
  if(st.session_state.page != 0):
243
  break
244
  st.session_state.bar.progress(i/IMAGE_TOKEN_COUNT)
245
 
246
- #torch.cuda.empty_cache()
247
  #torch.cpu.empty_cache()
248
- #gc.collect()
249
-
250
- image_tokens[i + 1], attention_state = self.decoder.forward(
251
- settings=settings,
252
- attention_mask=attention_mask,
253
- encoder_state=encoder_state,
254
- attention_state=attention_state,
255
- prev_tokens=image_tokens[i],
256
- token_index=token_indices[[i]]
257
- )
258
 
259
- if ((i + 1) % 16 == 0 and progressive_outputs) or i + 1 == 256:
 
260
  yield self.image_grid_from_tokens(
261
  image_tokens=image_tokens[1:].T,
262
  is_seamless=is_seamless,
263
  is_verbose=is_verbose
264
  )
265
-
266
 
267
  def generate_image_stream(self, *args, **kwargs) -> Iterator[Image.Image]:
268
  image_stream = self.generate_raw_image_stream(*args, **kwargs)
 
10
  from .text_tokenizer import TextTokenizer
11
  from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer
12
  import streamlit as st
 
13
 
14
  torch.set_grad_enabled(False)
15
  torch.set_num_threads(os.cpu_count())
16
+ torch.backends.cudnn.enabled = True
17
+ torch.backends.cudnn.allow_tf32 = True
18
 
19
  MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
20
  IMAGE_TOKEN_COUNT = 256
 
24
  def __init__(
25
  self,
26
  models_root: str = 'pretrained',
27
+ dtype: torch.dtype = torch.float32,
28
  device: str = None,
29
  is_mega: bool = True,
30
  is_reusable: bool = True,
 
187
  if len(tokens) > self.text_token_count:
188
  tokens = tokens[:self.text_token_count]
189
  if is_verbose: print("{} text tokens".format(len(tokens)), tokens)
190
+ text_tokens = numpy.ones((2, 64), dtype=numpy.int32)
191
  text_tokens[0, :2] = [tokens[0], tokens[-1]]
192
  text_tokens[1, :len(tokens)] = tokens
193
  text_tokens = torch.tensor(
 
231
  token_indices = torch.arange(IMAGE_TOKEN_COUNT, device=self.device)
232
  settings = torch.tensor(
233
  [temperature, top_k, supercondition_factor],
234
+ dtype=torch.float32,
235
  device=self.device
236
  )
 
 
237
  for i in range(IMAGE_TOKEN_COUNT):
 
238
  if(st.session_state.page != 0):
239
  break
240
  st.session_state.bar.progress(i/IMAGE_TOKEN_COUNT)
241
 
242
+ torch.cuda.empty_cache()
243
  #torch.cpu.empty_cache()
244
+ with torch.cuda.amp.autocast(dtype=self.dtype):
245
+ image_tokens[i + 1], attention_state = self.decoder.forward(
246
+ settings=settings,
247
+ attention_mask=attention_mask,
248
+ encoder_state=encoder_state,
249
+ attention_state=attention_state,
250
+ prev_tokens=image_tokens[i],
251
+ token_index=token_indices[[i]]
252
+ )
 
253
 
254
+ # with torch.cuda.amp.autocast(dtype=torch.float32):
255
+ if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 == 256:
256
  yield self.image_grid_from_tokens(
257
  image_tokens=image_tokens[1:].T,
258
  is_seamless=is_seamless,
259
  is_verbose=is_verbose
260
  )
 
261
 
262
  def generate_image_stream(self, *args, **kwargs) -> Iterator[Image.Image]:
263
  image_stream = self.generate_raw_image_stream(*args, **kwargs)