MohamedRashad committed
Commit 715c7b0 · 1 Parent(s): 920cc4d

Refactor encode_prompt function by removing enable_positive_prompt parameter and related logic; streamline image generation process
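
Besides dropping the positive-prompt augmentation, the commit replaces the hard-coded torch.cuda.amp.autocast call in gen_one_img with a device-aware context. A minimal standalone sketch of that pattern, assuming a recent PyTorch build (the `pass` placeholder stands in for the inference call; the fallback dtype mirrors the diff below):

import torch

# Pick a device, then enable bfloat16 autocast only where the hardware supports it.
# On CPU, float32 effectively disables mixed precision, so CPU-only machines still run.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda' and torch.cuda.is_bf16_supported():
    autocast_dtype = torch.bfloat16
else:
    autocast_dtype = torch.float32

with torch.amp.autocast(device_type=device, dtype=autocast_dtype), torch.no_grad():
    pass  # model inference goes here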

Files changed (1): app.py (+58 −72)
app.py CHANGED
@@ -49,19 +49,7 @@ def download_infinity_weights(weights_path):
     except Exception as e:
         print(f"Error downloading weights: {e}")
 
-def extract_key_val(text):
-    pattern = r'<(.+?):(.+?)>'
-    matches = re.findall(pattern, text)
-    key_val = {}
-    for match in matches:
-        key_val[match[0]] = match[1].lstrip()
-    return key_val
-
-def encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt=False):
-    if enable_positive_prompt:
-        print(f'before positive_prompt aug: {prompt}')
-        prompt = aug_with_positive_prompt(prompt)
-        print(f'after positive_prompt aug: {prompt}')
+def encode_prompt(text_tokenizer, text_encoder, prompt):
     print(f'prompt={prompt}')
     captions = [prompt]
     tokens = text_tokenizer(text=captions, max_length=512, padding='max_length', truncation=True, return_tensors='pt') # todo: put this into dataset
@@ -78,14 +66,6 @@ def encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt=False):
     text_cond_tuple = (kv_compact, lens, cu_seqlens_k, Ltext)
     return text_cond_tuple
 
-def aug_with_positive_prompt(prompt):
-    for key in ['man', 'woman', 'men', 'women', 'boy', 'girl', 'child', 'person', 'human', 'adult', 'teenager', 'employee',
-                'employer', 'worker', 'mother', 'father', 'sister', 'brother', 'grandmother', 'grandfather', 'son', 'daughter']:
-        if key in prompt:
-            prompt = prompt + '. very smooth faces, good looking faces, face to the camera, perfect facial features'
-            break
-    return prompt
-
 def enhance_image(image):
     for t in range(1):
         contrast_image = image.copy()
@@ -118,20 +98,31 @@ def gen_one_img(
     gt_ls_Bl=None,
     g_seed=None,
     sampling_per_bits=1,
-    enable_positive_prompt=0,
 ):
     sstt = time.time()
     if not isinstance(cfg_list, list):
         cfg_list = [cfg_list] * len(scale_schedule)
     if not isinstance(tau_list, list):
         tau_list = [tau_list] * len(scale_schedule)
-    text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt)
+    text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt)
     if negative_prompt:
         negative_label_B_or_BLT = encode_prompt(text_tokenizer, text_encoder, negative_prompt)
     else:
         negative_label_B_or_BLT = None
     print(f'cfg: {cfg_list}, tau: {tau_list}')
-    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True):
+
+    # Set device if not provided
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    # Set autocast dtype based on bf16 and device support
+    if device == 'cuda' and torch.cuda.is_bf16_supported():
+        autocast_dtype = torch.bfloat16
+    else:
+        autocast_dtype = torch.float32
+
+    torch.cuda.empty_cache()
+
+    with torch.amp.autocast(device_type=device, dtype=autocast_dtype), torch.no_grad():
         stt = time.time()
         _, _, img_list = infinity_test.autoregressive_infer_cfg(
             vae=vae,
@@ -169,7 +160,7 @@ def save_slim_model(infinity_model_path, save_file=None, device='cpu', key='gpt_
     return save_file
 
 def load_tokenizer(t5_path =''):
-    print(f'[Loading tokenizer and text encoder]')
+    print('[Loading tokenizer and text encoder]')
     text_tokenizer: T5TokenizerFast = AutoTokenizer.from_pretrained(t5_path, revision=None, legacy=True)
     text_tokenizer.model_max_length = 512
     text_encoder: T5EncoderModel = T5EncoderModel.from_pretrained(t5_path, torch_dtype=torch.float16)
@@ -195,7 +186,7 @@ def load_infinity(
     use_flex_attn=False,
     bf16=False,
 ):
-    print(f'[Loading Infinity]')
+    print('[Loading Infinity]')
 
     # Set device if not provided
     if device is None:
@@ -241,13 +232,13 @@
     infinity_test.eval()
     infinity_test.requires_grad_(False)
 
-    print(f'[Load Infinity weights]')
+    print('[Load Infinity weights]')
     state_dict = torch.load(model_path, map_location=device)
     print(infinity_test.load_state_dict(state_dict))
 
     # Initialize random number generator on the correct device
     infinity_test.rng = torch.Generator(device=device)
-
+
     return infinity_test
 
 def transform(pil_img, tgt_h, tgt_w):
@@ -471,50 +462,45 @@ infinity = load_transformer(vae, args)
 
 # Define the image generation function
 @spaces.GPU
-def generate_image(prompt, cfg, tau, h_div_w, seed, enable_positive_prompt=False):
-    try:
-        args.prompt = prompt
-        args.cfg = cfg
-        args.tau = tau
-        args.h_div_w = h_div_w
-        args.seed = seed
-        args.enable_positive_prompt = enable_positive_prompt
-
-        # Find the closest h_div_w_template
-        h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w_templates - h_div_w))]
-
-        # Get scale_schedule based on h_div_w_template_
-        scale_schedule = dynamic_resolution_h_w[h_div_w_template_][args.pn]['scales']
-        scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]
-
-        # Generate the image
-        generated_image = gen_one_img(
-            infinity,
-            vae,
-            text_tokenizer,
-            text_encoder,
-            prompt,
-            g_seed=seed,
-            gt_leak=0,
-            gt_ls_Bl=None,
-            cfg_list=cfg,
-            tau_list=tau,
-            scale_schedule=scale_schedule,
-            cfg_insertion_layer=[args.cfg_insertion_layer],
-            vae_type=args.vae_type,
-            sampling_per_bits=args.sampling_per_bits,
-            enable_positive_prompt=enable_positive_prompt,
-        )
-
-        # Convert the image to RGB and uint8
-        image = generated_image.cpu().numpy()
-        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-        image = np.uint8(image)
-
-        return image
-    except Exception as e:
-        print(f"Error generating image: {e}")
-        return None
+def generate_image(prompt, cfg, tau, h_div_w, seed):
+    args.prompt = prompt
+    args.cfg = cfg
+    args.tau = tau
+    args.h_div_w = h_div_w
+    args.seed = seed
+
+    # Find the closest h_div_w_template
+    h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w_templates - h_div_w))]
+
+    # Get scale_schedule based on h_div_w_template_
+    scale_schedule = dynamic_resolution_h_w[h_div_w_template_][args.pn]['scales']
+    scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]
+
+    # Generate the image
+    generated_image = gen_one_img(
+        infinity,
+        vae,
+        text_tokenizer,
+        text_encoder,
+        prompt,
+        g_seed=seed,
+        gt_leak=0,
+        gt_ls_Bl=None,
+        cfg_list=cfg,
+        tau_list=tau,
+        scale_schedule=scale_schedule,
+        cfg_insertion_layer=[args.cfg_insertion_layer],
+        vae_type=args.vae_type,
+        sampling_per_bits=args.sampling_per_bits,
+    )
+
+    # Convert the image to RGB and uint8
+    image = generated_image.cpu().numpy()
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    image = np.uint8(image)
+
+    return image
+
 
 # Set up Gradio interface
 with gr.Blocks() as demo:
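
The Gradio wiring that consumes generate_image sits below this hunk and is unchanged by the commit; a hypothetical sketch of how the simplified five-argument signature could be connected (component labels and value ranges here are assumptions, not taken from app.py):

import gradio as gr

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    cfg = gr.Slider(1.0, 10.0, value=3.0, label="CFG scale")
    tau = gr.Slider(0.1, 1.0, value=0.5, label="Tau")
    h_div_w = gr.Slider(0.5, 2.0, value=1.0, label="Height/width ratio")
    seed = gr.Number(value=0, precision=0, label="Seed")
    output = gr.Image(label="Generated image")
    # generate_image now takes exactly these five inputs; with the try/except
    # wrapper removed, exceptions propagate and surface as Gradio errors.
    gr.Button("Generate").click(generate_image, inputs=[prompt, cfg, tau, h_div_w, seed], outputs=output)

demo.launch()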