MohamedRashad committed
Commit 5c095cd · 1 Parent(s): e5b7e3e

Refactor image generation process by removing gen_one_img function; streamline prompt encoding and inference steps
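Editor's note on the net effect of this commit: `generate_image` no longer goes through the `gen_one_img` wrapper; it encodes the prompt itself and calls `autoregressive_infer_cfg` directly. A rough sketch of the new inline flow, distilled from the diff below (keyword arguments are abridged here; the full call is in the `generate_image` hunk):

```python
# Condensed illustration only -- see the app.py diff below for the exact committed code.
# All names used here (encode_prompt, text_tokenizer, text_encoder, infinity, vae,
# scale_schedule, cfg, tau, seed) are the module-level objects and function arguments
# already defined in app.py.
text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
autocast_dtype = torch.bfloat16 if (device == 'cuda' and torch.cuda.is_bf16_supported()) else torch.float32

with torch.amp.autocast(device_type=device, dtype=autocast_dtype), torch.no_grad():
    _, _, img_list = infinity.autoregressive_infer_cfg(
        vae=vae,
        scale_schedule=scale_schedule,
        label_B_or_BLT=text_cond_tuple,
        g_seed=seed,
        cfg_list=[cfg] * len(scale_schedule),  # per-scale guidance, broadcast from the single cfg value
        tau_list=[tau] * len(scale_schedule),  # per-scale temperature, broadcast the same way
        ret_img=True,
        # ... remaining keyword arguments keep the values the old gen_one_img defaults used
    )

image = img_list[0].cpu().numpy()  # converted to RGB uint8 before returning
```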

Files changed (1):
app.py (+43, -98)
app.py CHANGED
@@ -6,10 +6,8 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 import os.path as osp
 import time
-import hashlib
 import argparse
 import shutil
-import re
 import random
 from pathlib import Path
 from typing import List
@@ -22,7 +20,7 @@ import torch.nn.functional as F
 from PIL import Image
 import PIL.Image as PImage
 from torchvision.transforms.functional import to_tensor
-from transformers import AutoTokenizer, T5EncoderModel, T5TokenizerFast
+from transformers import AutoTokenizer, T5EncoderModel
 from huggingface_hub import hf_hub_download
 import gradio as gr
 import spaces
@@ -66,71 +64,6 @@ def encode_prompt(text_tokenizer, text_encoder, prompt):
     text_cond_tuple = (kv_compact, lens, cu_seqlens_k, Ltext)
     return text_cond_tuple
 
-def gen_one_img(
-    infinity_test,
-    vae,
-    text_tokenizer,
-    text_encoder,
-    prompt,
-    cfg_list=[],
-    tau_list=[],
-    negative_prompt='',
-    scale_schedule=None,
-    top_k=900,
-    top_p=0.97,
-    cfg_sc=3,
-    cfg_exp_k=0.0,
-    cfg_insertion_layer=-5,
-    vae_type=0,
-    gumbel=0,
-    softmax_merge_topk=-1,
-    gt_leak=-1,
-    gt_ls_Bl=None,
-    g_seed=None,
-    sampling_per_bits=1,
-):
-    sstt = time.time()
-    if not isinstance(cfg_list, list):
-        cfg_list = [cfg_list] * len(scale_schedule)
-    if not isinstance(tau_list, list):
-        tau_list = [tau_list] * len(scale_schedule)
-    text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt)
-    if negative_prompt:
-        negative_label_B_or_BLT = encode_prompt(text_tokenizer, text_encoder, negative_prompt)
-    else:
-        negative_label_B_or_BLT = None
-    print(f'cfg: {cfg_list}, tau: {tau_list}')
-
-    # Set device if not provided
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-    # Set autocast dtype based on bf16 and device support
-    if device == 'cuda' and torch.cuda.is_bf16_supported():
-        autocast_dtype = torch.bfloat16
-    else:
-        autocast_dtype = torch.float32
-
-    torch.cuda.empty_cache()
-
-    with torch.amp.autocast(device_type=device, dtype=autocast_dtype), torch.no_grad():
-        stt = time.time()
-        _, _, img_list = infinity_test.autoregressive_infer_cfg(
-            vae=vae,
-            scale_schedule=scale_schedule,
-            label_B_or_BLT=text_cond_tuple, g_seed=g_seed,
-            B=1, negative_label_B_or_BLT=negative_label_B_or_BLT, force_gt_Bhw=None,
-            cfg_sc=cfg_sc, cfg_list=cfg_list, tau_list=tau_list, top_k=top_k, top_p=top_p,
-            returns_vemb=1, ratio_Bl1=None, gumbel=gumbel, norm_cfg=False,
-            cfg_exp_k=cfg_exp_k, cfg_insertion_layer=cfg_insertion_layer,
-            vae_type=vae_type, softmax_merge_topk=softmax_merge_topk,
-            ret_img=True, trunk_scale=1000,
-            gt_leak=gt_leak, gt_ls_Bl=gt_ls_Bl, inference_mode=True,
-            sampling_per_bits=sampling_per_bits,
-        )
-    print(f"cost: {time.time() - sstt}, infinity cost={time.time() - stt}")
-    img = img_list[0]
-    return img
-
 def save_slim_model(infinity_model_path, save_file=None, device='cpu', key='gpt_fsdp'):
     print('[Save slim model]')
     full_ckpt = torch.load(infinity_model_path, map_location=device)
@@ -143,15 +76,15 @@ def save_slim_model(infinity_model_path, save_file=None, device='cpu', key='gpt_fsdp'):
     print('[Save slim model] done')
     return save_file
 
-def load_tokenizer(t5_path =''):
+def load_tokenizer(t5_path=''):
     print('[Loading tokenizer and text encoder]')
-    text_tokenizer: T5TokenizerFast = AutoTokenizer.from_pretrained(t5_path, revision=None, legacy=True)
-    text_tokenizer.model_max_length = 512
-    text_encoder: T5EncoderModel = T5EncoderModel.from_pretrained(t5_path, torch_dtype=torch.float16)
-    text_encoder.to('cuda')
-    text_encoder.eval()
-    text_encoder.requires_grad_(False)
-    return text_tokenizer, text_encoder
+    tokenizer = AutoTokenizer.from_pretrained(t5_path, legacy=True)
+    tokenizer.model_max_length = 512
+    encoder = T5EncoderModel.from_pretrained(t5_path, torch_dtype=torch.float16)
+    encoder.eval()
+    encoder.to("cuda" if torch.cuda.is_available() else "cpu")
+    encoder.requires_grad_(False)
+    return tokenizer, encoder
 
 def load_infinity(
     rope2d_each_sa_layer,
@@ -220,8 +153,8 @@ def load_infinity(
     state_dict = torch.load(model_path, map_location=device)
     print(infinity_test.load_state_dict(state_dict))
 
-    # Initialize random number generator on the correct device
-    infinity_test.rng = torch.Generator(device=device)
+    # Initialize random number generator on the correct device
+    infinity_test.rng = torch.Generator(device=device)
 
     return infinity_test
 
@@ -444,6 +377,7 @@ text_tokenizer, text_encoder = load_tokenizer(t5_path="google/flan-t5-xl")
 vae = load_visual_tokenizer(args)
 infinity = load_transformer(vae, args)
 
+
 # Define the image generation function
 @spaces.GPU
 def generate_image(prompt, cfg, tau, h_div_w, seed):
@@ -460,31 +394,42 @@ def generate_image(prompt, cfg, tau, h_div_w, seed):
     scale_schedule = dynamic_resolution_h_w[h_div_w_template_][args.pn]['scales']
     scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]
 
-    # Generate the image
-    generated_image = gen_one_img(
-        infinity,
-        vae,
-        text_tokenizer,
-        text_encoder,
-        prompt,
-        g_seed=seed,
-        gt_leak=0,
-        gt_ls_Bl=None,
-        cfg_list=cfg,
-        tau_list=tau,
-        scale_schedule=scale_schedule,
-        cfg_insertion_layer=[args.cfg_insertion_layer],
-        vae_type=args.vae_type,
-        sampling_per_bits=args.sampling_per_bits,
-    )
+    # Encode the prompt
+    text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt)
 
-    # Convert the image to RGB and uint8
-    image = generated_image.cpu().numpy()
+    # Set device if not provided
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    # Set autocast dtype based on bf16 and device support
+    if device == 'cuda' and torch.cuda.is_bf16_supported():
+        autocast_dtype = torch.bfloat16
+    else:
+        autocast_dtype = torch.float32
+
+    torch.cuda.empty_cache()
+
+    with torch.amp.autocast(device_type=device, dtype=autocast_dtype), torch.no_grad():
+        _, _, img_list = infinity.autoregressive_infer_cfg(
+            vae=vae,
+            scale_schedule=scale_schedule,
+            label_B_or_BLT=text_cond_tuple, g_seed=seed,
+            B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None,
+            cfg_sc=3, cfg_list=[cfg] * len(scale_schedule), tau_list=[tau] * len(scale_schedule), top_k=900, top_p=0.97,
+            returns_vemb=1, ratio_Bl1=None, gumbel=0, norm_cfg=False,
+            cfg_exp_k=0.0, cfg_insertion_layer=[args.cfg_insertion_layer],
+            vae_type=args.vae_type, softmax_merge_topk=-1,
+            ret_img=True, trunk_scale=1000,
+            gt_leak=0, gt_ls_Bl=None, inference_mode=True,
+            sampling_per_bits=args.sampling_per_bits,
+        )
+
+    img = img_list[0]
+    image = img.cpu().numpy()
     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
     image = np.uint8(image)
 
     return image
-
+
 
 markdown_description = """### Instructions:
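For reviewers who want to sanity-check the refactor: the public entry point is unchanged, `generate_image(prompt, cfg, tau, h_div_w, seed)` still returns an RGB uint8 array, so the Gradio wiring later in app.py is untouched by this commit. A hypothetical local smoke test (the argument values below are illustrative, not defaults taken from the app):

```python
# Hypothetical manual check; the cfg, tau, h_div_w and seed values are illustrative only.
img = generate_image(
    prompt="a watercolor painting of a lighthouse at dawn",
    cfg=3,        # guidance strength, broadcast over every step of the scale schedule
    tau=0.5,      # sampling temperature, broadcast the same way
    h_div_w=1.0,  # target height/width ratio, used to pick the scale-schedule template
    seed=42,      # forwarded to autoregressive_infer_cfg as g_seed
)
print(img.shape, img.dtype)  # expected: (H, W, 3), uint8
```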