MohamedRashad committed
Commit 366fd1c · 1 Parent(s): 27e1ebb

Enable bf16 in load_infinity function and enhance transform function with type hints and error handling; refactor joint_vi_vae_encode_decode for improved performance and error management

Files changed (1)
  1. app.py +158 -67
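
For context, a minimal usage sketch (not part of this commit) of the refactored round-trip helper shown in the diff below. The import path, image path, schedule, and 1024x1024 target size are hypothetical placeholders, and `args` is assumed to be the namespace app.py already builds:

import torch
from app import load_visual_tokenizer, joint_vi_vae_encode_decode  # hypothetical import of this Space's app.py

vae = load_visual_tokenizer(args)                   # assumes `args` is in scope, as in app.py
scale_schedule = [(1, 1, 1), (1, 2, 2), (1, 4, 4)]  # placeholder (t, h, w) schedule
device = "cuda" if torch.cuda.is_available() else "cpu"

gt_img, recons_img, bit_indices = joint_vi_vae_encode_decode(
    vae, "example.jpg", scale_schedule, device, tgt_h=1024, tgt_w=1024,
)
print(gt_img.shape, recons_img.shape)               # HxWx3 uint8 numpy arrays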
app.py CHANGED
@@ -102,7 +102,7 @@ def load_infinity(
     text_channels=2048,
     apply_spatial_patchify=0,
     use_flex_attn=False,
-    bf16=False,
+    bf16=True,
 ):
     print('[Loading Infinity]')
 
@@ -156,45 +156,137 @@ def load_infinity(
 
     # Initialize random number generator on the correct device
     infinity_test.rng = torch.Generator(device=device)
 
     return infinity_test
 
-def transform(pil_img, tgt_h, tgt_w):
+def transform(pil_img: PImage.Image, tgt_h: int, tgt_w: int) -> torch.Tensor:
+    """
+    Transform a PIL image to a tensor with target dimensions while preserving aspect ratio.
+
+    Args:
+        pil_img: PIL Image to transform
+        tgt_h: Target height
+        tgt_w: Target width
+
+    Returns:
+        torch.Tensor: Normalized tensor image in range [-1, 1]
+    """
+    if not isinstance(pil_img, PImage.Image):
+        raise TypeError("Input must be a PIL Image")
+
+    if tgt_h <= 0 or tgt_w <= 0:
+        raise ValueError("Target dimensions must be positive")
+
+    # Calculate resize dimensions so the image covers the target while preserving aspect ratio
     width, height = pil_img.size
-    if width / height <= tgt_w / tgt_h:
-        resized_width = tgt_w
-        resized_height = int(tgt_w / (width / height))
-    else:
-        resized_height = tgt_h
-        resized_width = int((width / height) * tgt_h)
-    pil_img = pil_img.resize((resized_width, resized_height), resample=PImage.LANCZOS)
-    # crop the center out
-    arr = np.array(pil_img)
-    crop_y = (arr.shape[0] - tgt_h) // 2
-    crop_x = (arr.shape[1] - tgt_w) // 2
-    im = to_tensor(arr[crop_y: crop_y + tgt_h, crop_x: crop_x + tgt_w])
-    return im.add(im).add_(-1)
-
-def joint_vi_vae_encode_decode(vae, image_path, scale_schedule, device, tgt_h, tgt_w):
-    pil_image = Image.open(image_path).convert('RGB')
-    inp = transform(pil_image, tgt_h, tgt_w)
-    inp = inp.unsqueeze(0).to(device)
-    scale_schedule = [(item[0], item[1], item[2]) for item in scale_schedule]
-    t1 = time.time()
-    h, z, _, all_bit_indices, _, infinity_input = vae.encode(inp, scale_schedule=scale_schedule)
-    t2 = time.time()
-    recons_img = vae.decode(z)[0]
-    if len(recons_img.shape) == 4:
-        recons_img = recons_img.squeeze(1)
-    print(f'recons: z.shape: {z.shape}, recons_img shape: {recons_img.shape}')
-    t3 = time.time()
-    print(f'vae encode takes {t2-t1:.2f}s, decode takes {t3-t2:.2f}s')
-    recons_img = (recons_img + 1) / 2
-    recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)
-    gt_img = (inp[0] + 1) / 2
-    gt_img = gt_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)
-    print(recons_img.shape, gt_img.shape)
-    return gt_img, recons_img, all_bit_indices
+    scale = max(tgt_w / width, tgt_h / height)
+    # Guarantee the resized image fully covers the target so the center crop is exact
+    new_width = max(tgt_w, int(width * scale))
+    new_height = max(tgt_h, int(height * scale))
+
+    # Resize using LANCZOS for best quality
+    pil_img = pil_img.resize((new_width, new_height), resample=PImage.LANCZOS)
+
+    # Center-crop the resized image to the target size
+    arr = np.array(pil_img, dtype=np.uint8)
+
+    # Calculate crop coordinates
+    y1 = max(0, (new_height - tgt_h) // 2)
+    x1 = max(0, (new_width - tgt_w) // 2)
+    y2 = y1 + tgt_h
+    x2 = x1 + tgt_w
+
+    # Crop and convert to tensor
+    arr = arr[y1:y2, x1:x2]
+
+    # Convert to a normalized tensor in [-1, 1] in one step
+    return torch.from_numpy(arr.transpose(2, 0, 1)).float().div_(127.5).sub_(1)
+
+def joint_vi_vae_encode_decode(
+    vae: 'VAEModel',  # type hint would be more specific with the actual VAE class
+    image_path: str | Path,
+    scale_schedule: list[tuple],
+    device: torch.device | str,
+    tgt_h: int,
+    tgt_w: int
+) -> tuple[np.ndarray, np.ndarray, torch.Tensor]:
+    """
+    Encode and decode an image using a VAE model with joint visual-infinity processing.
+
+    Args:
+        vae: The VAE model instance
+        image_path: Path to input image
+        scale_schedule: List of scale tuples for processing
+        device: Target device for computation
+        tgt_h: Target height for the image
+        tgt_w: Target width for the image
+
+    Returns:
+        tuple containing:
+            - Original image as numpy array (uint8)
+            - Reconstructed image as numpy array (uint8)
+            - Bit indices tensor
+
+    Raises:
+        FileNotFoundError: If the image file doesn't exist
+        RuntimeError: If VAE processing fails
+    """
+    try:
+        # Validate input path
+        if not Path(image_path).exists():
+            raise FileNotFoundError(f"Image not found at {image_path}")
+
+        # Load and preprocess image
+        pil_image = Image.open(image_path).convert('RGB')
+        inp = transform(pil_image, tgt_h, tgt_w)
+        inp = inp.unsqueeze(0).to(device)
+
+        # Normalize scale schedule
+        scale_schedule = [(s[0], s[1], s[2]) for s in scale_schedule]
+
+        # Autocast dtype is keyed off the available device type
+        device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+        # Time the encoding/decoding operations
+        with torch.amp.autocast(device_type, dtype=torch.bfloat16):
+            encode_start = time.perf_counter()
+            h, z, _, all_bit_indices, _, _ = vae.encode(
+                inp,
+                scale_schedule=scale_schedule
+            )
+            encode_time = time.perf_counter() - encode_start
+
+            decode_start = time.perf_counter()
+            recons_img = vae.decode(z)[0]
+            decode_time = time.perf_counter() - decode_start
+
+        # Process reconstruction
+        if recons_img.dim() == 4:
+            recons_img = recons_img.squeeze(1)
+
+        # Log performance metrics
+        print(f'VAE encode: {encode_time:.2f}s, decode: {decode_time:.2f}s')
+        print(f'Reconstruction shape: {recons_img.shape}, z shape: {z.shape}')
+
+        # Convert to numpy arrays (cast back to float32 first, since autocast may yield bf16)
+        recons_img = (recons_img.add(1).div(2)
+                      .permute(1, 2, 0)
+                      .mul(255)
+                      .float()
+                      .cpu()
+                      .numpy()
+                      .astype(np.uint8))
+
+        gt_img = (inp[0].add(1).div(2)
+                  .permute(1, 2, 0)
+                  .mul(255)
+                  .cpu()
+                  .numpy()
+                  .astype(np.uint8))
+
+        return gt_img, recons_img, all_bit_indices
+
+    except Exception as e:
+        print(f"Error in VAE processing: {str(e)}")
+        raise RuntimeError("VAE processing failed") from e
 
 def load_visual_tokenizer(args):
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -219,29 +311,26 @@ def load_visual_tokenizer(args):
     return vae
 
 def load_transformer(vae, args):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
     model_path = args.model_path
+
     if args.checkpoint_type == 'torch':
-        # copy large model to local; save slim to local; and copy slim to nas; load local slim model
        if osp.exists(args.cache_dir):
            local_model_path = osp.join(args.cache_dir, 'tmp', model_path.replace('/', '_'))
        else:
            local_model_path = model_path
+
        if args.enable_model_cache:
            slim_model_path = model_path.replace('ar-', 'slim-')
            local_slim_model_path = local_model_path.replace('ar-', 'slim-')
            os.makedirs(osp.dirname(local_slim_model_path), exist_ok=True)
-            print(f'model_path: {model_path}, slim_model_path: {slim_model_path}')
-            print(f'local_model_path: {local_model_path}, local_slim_model_path: {local_slim_model_path}')
            if not osp.exists(local_slim_model_path):
                if osp.exists(slim_model_path):
-                    print(f'copy {slim_model_path} to {local_slim_model_path}')
                    shutil.copyfile(slim_model_path, local_slim_model_path)
                else:
                    if not osp.exists(local_model_path):
-                        print(f'copy {model_path} to {local_model_path}')
                        shutil.copyfile(model_path, local_model_path)
                    save_slim_model(local_model_path, save_file=local_slim_model_path, device=device)
-                    print(f'copy {local_slim_model_path} to {slim_model_path}')
                    if not osp.exists(slim_model_path):
                        shutil.copyfile(local_slim_model_path, slim_model_path)
                    os.remove(local_model_path)
@@ -249,33 +338,35 @@ def load_transformer(vae, args):
            slim_model_path = local_slim_model_path
        else:
            slim_model_path = model_path
-        print(f'load checkpoint from {slim_model_path}')
-
-    if args.model_type == 'infinity_2b':
-        kwargs_model = dict(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, mlp_ratio=4, block_chunks=8) # 2b model
-    elif args.model_type == 'infinity_layer12':
-        kwargs_model = dict(depth=12, embed_dim=768, num_heads=8, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
-    elif args.model_type == 'infinity_layer16':
-        kwargs_model = dict(depth=16, embed_dim=1152, num_heads=12, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
-    elif args.model_type == 'infinity_layer24':
-        kwargs_model = dict(depth=24, embed_dim=1536, num_heads=16, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
-    elif args.model_type == 'infinity_layer32':
-        kwargs_model = dict(depth=32, embed_dim=2080, num_heads=20, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
-    elif args.model_type == 'infinity_layer40':
-        kwargs_model = dict(depth=40, embed_dim=2688, num_heads=24, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
-    elif args.model_type == 'infinity_layer48':
-        kwargs_model = dict(depth=48, embed_dim=3360, num_heads=28, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
+        print(f'Loading checkpoint from {slim_model_path}')
+    else:
+        raise ValueError(f"Unsupported checkpoint_type: {args.checkpoint_type}")
+
+    model_configs = {
+        'infinity_2b': dict(depth=32, embed_dim=2048, num_heads=16, drop_path_rate=0.1, mlp_ratio=4, block_chunks=8),
+        'infinity_layer12': dict(depth=12, embed_dim=768, num_heads=8, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
+        'infinity_layer16': dict(depth=16, embed_dim=1152, num_heads=12, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
+        'infinity_layer24': dict(depth=24, embed_dim=1536, num_heads=16, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
+        'infinity_layer32': dict(depth=32, embed_dim=2080, num_heads=20, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
+        'infinity_layer40': dict(depth=40, embed_dim=2688, num_heads=24, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
+        'infinity_layer48': dict(depth=48, embed_dim=3360, num_heads=28, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
+    }
+
+    kwargs_model = model_configs.get(args.model_type)
+    if kwargs_model is None:
+        raise ValueError(f"Unsupported model_type: {args.model_type}")
+
    infinity = load_infinity(
        rope2d_each_sa_layer=args.rope2d_each_sa_layer,
        rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
        use_scale_schedule_embedding=args.use_scale_schedule_embedding,
        pn=args.pn,
        use_bit_label=args.use_bit_label,
        add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
        model_path=slim_model_path,
        scale_schedule=None,
        vae=vae,
-        device=None,
+        device=device,
        model_kwargs=kwargs_model,
        text_channels=args.text_channels,
        apply_spatial_patchify=args.apply_spatial_patchify,
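
A small, hedged smoke test (not part of the commit) of the contract the new transform enforces, namely a (3, tgt_h, tgt_w) float tensor in [-1, 1]; the import path and the synthetic 640x480 image are placeholders:

import torch
import PIL.Image as PImage
from app import transform            # hypothetical import of this Space's app.py

img = PImage.new("RGB", (640, 480))  # synthetic placeholder image
t = transform(img, tgt_h=256, tgt_w=256)
assert t.shape == (3, 256, 256)
assert t.min().item() >= -1.0 and t.max().item() <= 1.0
print(t.shape, t.dtype)              # torch.Size([3, 256, 256]) torch.float32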