Spaces: Running on Zero

MohamedRashad committed
Commit · 715c7b0
Parent(s): 920cc4d

Refactor encode_prompt function by removing enable_positive_prompt parameter and related logic; streamline image generation process

app.py CHANGED
@@ -49,19 +49,7 @@ def download_infinity_weights(weights_path):
     except Exception as e:
         print(f"Error downloading weights: {e}")
 
-def extract_key_val(text):
-    pattern = r'<(.+?):(.+?)>'
-    matches = re.findall(pattern, text)
-    key_val = {}
-    for match in matches:
-        key_val[match[0]] = match[1].lstrip()
-    return key_val
-
-def encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt=False):
-    if enable_positive_prompt:
-        print(f'before positive_prompt aug: {prompt}')
-        prompt = aug_with_positive_prompt(prompt)
-        print(f'after positive_prompt aug: {prompt}')
+def encode_prompt(text_tokenizer, text_encoder, prompt):
     print(f'prompt={prompt}')
     captions = [prompt]
     tokens = text_tokenizer(text=captions, max_length=512, padding='max_length', truncation=True, return_tensors='pt') # todo: put this into dataset
@@ -78,14 +66,6 @@ def encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt=False):
     text_cond_tuple = (kv_compact, lens, cu_seqlens_k, Ltext)
     return text_cond_tuple
 
-def aug_with_positive_prompt(prompt):
-    for key in ['man', 'woman', 'men', 'women', 'boy', 'girl', 'child', 'person', 'human', 'adult', 'teenager', 'employee',
-                'employer', 'worker', 'mother', 'father', 'sister', 'brother', 'grandmother', 'grandfather', 'son', 'daughter']:
-        if key in prompt:
-            prompt = prompt + '. very smooth faces, good looking faces, face to the camera, perfect facial features'
-            break
-    return prompt
-
 def enhance_image(image):
     for t in range(1):
         contrast_image = image.copy()
@@ -118,20 +98,31 @@ def gen_one_img(
     gt_ls_Bl=None,
     g_seed=None,
     sampling_per_bits=1,
-    enable_positive_prompt=0,
 ):
     sstt = time.time()
     if not isinstance(cfg_list, list):
         cfg_list = [cfg_list] * len(scale_schedule)
     if not isinstance(tau_list, list):
         tau_list = [tau_list] * len(scale_schedule)
-    text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt)
+    text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt)
     if negative_prompt:
         negative_label_B_or_BLT = encode_prompt(text_tokenizer, text_encoder, negative_prompt)
     else:
         negative_label_B_or_BLT = None
     print(f'cfg: {cfg_list}, tau: {tau_list}')
-    with torch.autocast('cuda', enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
+
+    # Set device if not provided
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    # Set autocast dtype based on bf16 and device support
+    if device == 'cuda' and torch.cuda.is_bf16_supported():
+        autocast_dtype = torch.bfloat16
+    else:
+        autocast_dtype = torch.float32
+
+    torch.cuda.empty_cache()
+
+    with torch.amp.autocast(device_type=device, dtype=autocast_dtype), torch.no_grad():
         stt = time.time()
         _, _, img_list = infinity_test.autoregressive_infer_cfg(
             vae=vae,
@@ -169,7 +160,7 @@ def save_slim_model(infinity_model_path, save_file=None, device='cpu', key='gpt_
     return save_file
 
 def load_tokenizer(t5_path =''):
-    print(f'[Loading tokenizer and text encoder]')
+    print('[Loading tokenizer and text encoder]')
     text_tokenizer: T5TokenizerFast = AutoTokenizer.from_pretrained(t5_path, revision=None, legacy=True)
     text_tokenizer.model_max_length = 512
     text_encoder: T5EncoderModel = T5EncoderModel.from_pretrained(t5_path, torch_dtype=torch.float16)
@@ -195,7 +186,7 @@ def load_infinity(
     use_flex_attn=False,
     bf16=False,
 ):
-    print(f'[Loading Infinity]')
+    print('[Loading Infinity]')
 
     # Set device if not provided
     if device is None:
@@ -241,13 +232,13 @@ def load_infinity(
     infinity_test.eval()
     infinity_test.requires_grad_(False)
 
-    print(f'[Load Infinity weights]')
+    print('[Load Infinity weights]')
     state_dict = torch.load(model_path, map_location=device)
     print(infinity_test.load_state_dict(state_dict))
 
     # Initialize random number generator on the correct device
     infinity_test.rng = torch.Generator(device=device)
-
+
     return infinity_test
 
 def transform(pil_img, tgt_h, tgt_w):
@@ -471,50 +462,45 @@ infinity = load_transformer(vae, args)
 
 # Define the image generation function
 @spaces.GPU
-def generate_image(prompt, cfg, tau, h_div_w, seed…
-    […old try-block body not captured in this page extract…]
-        return image
-    except Exception as e:
-        print(f"Error generating image: {e}")
-        return None
+def generate_image(prompt, cfg, tau, h_div_w, seed):
+    args.prompt = prompt
+    args.cfg = cfg
+    args.tau = tau
+    args.h_div_w = h_div_w
+    args.seed = seed
+
+    # Find the closest h_div_w_template
+    h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w_templates - h_div_w))]
+
+    # Get scale_schedule based on h_div_w_template_
+    scale_schedule = dynamic_resolution_h_w[h_div_w_template_][args.pn]['scales']
+    scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]
+
+    # Generate the image
+    generated_image = gen_one_img(
+        infinity,
+        vae,
+        text_tokenizer,
+        text_encoder,
+        prompt,
+        g_seed=seed,
+        gt_leak=0,
+        gt_ls_Bl=None,
+        cfg_list=cfg,
+        tau_list=tau,
+        scale_schedule=scale_schedule,
+        cfg_insertion_layer=[args.cfg_insertion_layer],
+        vae_type=args.vae_type,
+        sampling_per_bits=args.sampling_per_bits,
+    )
+
+    # Convert the image to RGB and uint8
+    image = generated_image.cpu().numpy()
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    image = np.uint8(image)
+
+    return image
+
 
 # Set up Gradio interface
 with gr.Blocks() as demo:
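The main functional change in gen_one_img is that a hard-coded CUDA autocast is replaced by a device-aware setup. A minimal standalone sketch of the same pattern follows; the Linear model and tensor shapes are hypothetical stand-ins, not the Space's code:

import torch

# Device-aware autocast selection, mirroring the block this commit adds to
# gen_one_img. The model below is a hypothetical stand-in, not Infinity itself.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Prefer bfloat16 mixed precision only where the hardware supports it;
# requesting float32 effectively disables autocast, so the fallback path
# runs in full precision.
if device == 'cuda' and torch.cuda.is_bf16_supported():
    autocast_dtype = torch.bfloat16
else:
    autocast_dtype = torch.float32

model = torch.nn.Linear(16, 16).to(device)
x = torch.randn(2, 16, device=device)

with torch.amp.autocast(device_type=device, dtype=autocast_dtype), torch.no_grad():
    out = model(x)

print(out.dtype)  # bfloat16 under CUDA autocast, float32 otherwise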
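The diff cuts off at the start of the Gradio UI. As a rough guide to how the refactored generate_image(prompt, cfg, tau, h_div_w, seed) plugs into that gr.Blocks context, here is a sketch; all component names, ranges, and defaults are assumptions for illustration, not the Space's actual layout:

import gradio as gr

# Hypothetical wiring for the refactored generate_image; labels, ranges, and
# defaults are illustrative assumptions, not the Space's actual interface.
with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    cfg = gr.Slider(1.0, 10.0, value=3.0, label="CFG scale")
    tau = gr.Slider(0.1, 1.0, value=0.5, label="Tau")
    h_div_w = gr.Slider(0.5, 2.0, value=1.0, label="Height/width ratio")
    seed = gr.Number(value=42, precision=0, label="Seed")
    output = gr.Image(label="Generated image")

    # generate_image is the function defined in app.py above.
    gr.Button("Generate").click(
        generate_image,
        inputs=[prompt, cfg, tau, h_div_w, seed],
        outputs=output,
    )

demo.launch()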