Commit
·
8cfdc75
1
Parent(s):
452abeb
finish
Browse files- parti_prompts.py +6 -4
parti_prompts.py
CHANGED
@@ -35,15 +35,16 @@ def get_karlo_eval(ckpt):
|
|
35 |
return karlo_eval
|
36 |
|
37 |
def get_if_eval(ckpt):
|
38 |
-
pipe_low = DiffusionPipeline.from_pretrained(ckpt, safety_checker=None, torch_dtype=torch.float16)
|
39 |
pipe_low.enable_model_cpu_offload()
|
40 |
|
41 |
-
pipe_up = DiffusionPipeline.from_pretrained("DeepFloyd/IF-II-L-v1.0", safety_checker=None, text_encoder=pipe_low.text_encoder, torch_dtype=torch.float16)
|
42 |
pipe_up.enable_model_cpu_offload()
|
43 |
|
44 |
def if_eval(prompt, generator=None):
|
45 |
-
|
46 |
-
images =
|
|
|
47 |
return images
|
48 |
|
49 |
return if_eval
|
@@ -69,6 +70,7 @@ if __name__ == "__main__":
|
|
69 |
args = parser.parse_args()
|
70 |
|
71 |
dataset = load_dataset("nateraw/parti-prompts")["train"]
|
|
|
72 |
|
73 |
eval_fn = MODELS[args.model_repo_or_id](args.model_repo_or_id)
|
74 |
|
|
|
35 |
return karlo_eval
|
36 |
|
37 |
def get_if_eval(ckpt):
|
38 |
+
pipe_low = DiffusionPipeline.from_pretrained(ckpt, safety_checker=None, watermarker=None, torch_dtype=torch.float16, variant="fp16")
|
39 |
pipe_low.enable_model_cpu_offload()
|
40 |
|
41 |
+
pipe_up = DiffusionPipeline.from_pretrained("DeepFloyd/IF-II-L-v1.0", safety_checker=None, watermarker=None, text_encoder=pipe_low.text_encoder, torch_dtype=torch.float16, variant="fp16")
|
42 |
pipe_up.enable_model_cpu_offload()
|
43 |
|
44 |
def if_eval(prompt, generator=None):
|
45 |
+
prompt_embeds, negative_prompt_embeds = pipe_low.encode_prompt(prompt)
|
46 |
+
images = pipe_low(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, num_inference_steps=NUM_INFERENCE_STEPS, generator=generator, output_type="pt").images
|
47 |
+
images = pipe_up(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, image=images, num_inference_steps=NUM_INFERENCE_STEPS, generator=generator).images
|
48 |
return images
|
49 |
|
50 |
return if_eval
|
|
|
70 |
args = parser.parse_args()
|
71 |
|
72 |
dataset = load_dataset("nateraw/parti-prompts")["train"]
|
73 |
+
# dataset = dataset.select(range(4))
|
74 |
|
75 |
eval_fn = MODELS[args.model_repo_or_id](args.model_repo_or_id)
|
76 |
|