Spaces:
Runtime error
Runtime error
Update min_dalle/min_dalle.py
Browse files- min_dalle/min_dalle.py +6 -6
min_dalle/min_dalle.py
CHANGED
@@ -17,7 +17,7 @@ torch.backends.cudnn.enabled = True
|
|
17 |
torch.backends.cudnn.allow_tf32 = True
|
18 |
|
19 |
MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
|
20 |
-
IMAGE_TOKEN_COUNT =
|
21 |
|
22 |
|
23 |
class MinDalle:
|
@@ -177,7 +177,7 @@ class MinDalle:
|
|
177 |
progressive_outputs: bool = False,
|
178 |
is_seamless: bool = False,
|
179 |
temperature: float = 1,
|
180 |
-
top_k: int =
|
181 |
supercondition_factor: int = 16,
|
182 |
is_verbose: bool = False
|
183 |
) -> Iterator[FloatTensor]:
|
@@ -239,8 +239,8 @@ class MinDalle:
|
|
239 |
break
|
240 |
st.session_state.bar.progress(i/IMAGE_TOKEN_COUNT)
|
241 |
|
242 |
-
#torch.cuda.empty_cache()
|
243 |
-
#torch.
|
244 |
with torch.cuda.amp.autocast(dtype=self.dtype):
|
245 |
image_tokens[i + 1], attention_state = self.decoder.forward(
|
246 |
settings=settings,
|
@@ -252,7 +252,7 @@ class MinDalle:
|
|
252 |
)
|
253 |
|
254 |
with torch.cuda.amp.autocast(dtype=torch.float32):
|
255 |
-
if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 ==
|
256 |
yield self.image_grid_from_tokens(
|
257 |
image_tokens=image_tokens[1:].T,
|
258 |
is_seamless=is_seamless,
|
@@ -270,7 +270,7 @@ class MinDalle:
|
|
270 |
image_stream = self.generate_raw_image_stream(*args, **kwargs)
|
271 |
for image in image_stream:
|
272 |
grid_size = kwargs['grid_size']
|
273 |
-
image = image.view([grid_size *
|
274 |
image = image.transpose(1, 0)
|
275 |
image = image.reshape([grid_size ** 2, 2 ** 8, 2 ** 8, 3])
|
276 |
yield image
|
|
|
17 |
torch.backends.cudnn.allow_tf32 = True
|
18 |
|
19 |
MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
|
20 |
+
IMAGE_TOKEN_COUNT = 256
|
21 |
|
22 |
|
23 |
class MinDalle:
|
|
|
177 |
progressive_outputs: bool = False,
|
178 |
is_seamless: bool = False,
|
179 |
temperature: float = 1,
|
180 |
+
top_k: int = 256,
|
181 |
supercondition_factor: int = 16,
|
182 |
is_verbose: bool = False
|
183 |
) -> Iterator[FloatTensor]:
|
|
|
239 |
break
|
240 |
st.session_state.bar.progress(i/IMAGE_TOKEN_COUNT)
|
241 |
|
242 |
+
#torch.cuda.empty_cache()
|
243 |
+
#torch.cpu.empty_cache()
|
244 |
with torch.cuda.amp.autocast(dtype=self.dtype):
|
245 |
image_tokens[i + 1], attention_state = self.decoder.forward(
|
246 |
settings=settings,
|
|
|
252 |
)
|
253 |
|
254 |
with torch.cuda.amp.autocast(dtype=torch.float32):
|
255 |
+
if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 == 256:
|
256 |
yield self.image_grid_from_tokens(
|
257 |
image_tokens=image_tokens[1:].T,
|
258 |
is_seamless=is_seamless,
|
|
|
270 |
image_stream = self.generate_raw_image_stream(*args, **kwargs)
|
271 |
for image in image_stream:
|
272 |
grid_size = kwargs['grid_size']
|
273 |
+
image = image.view([grid_size * 256, grid_size, 256, 3])
|
274 |
image = image.transpose(1, 0)
|
275 |
image = image.reshape([grid_size ** 2, 2 ** 8, 2 ** 8, 3])
|
276 |
yield image
|