zhiweili committed
Commit: c9624f8
Parent(s): faaa8ba
change scheduler

Browse files: app_haircolor_pix2pix.py (+4 -5)
app_haircolor_pix2pix.py CHANGED

@@ -12,13 +12,14 @@ from enhance_utils import enhance_image
 from diffusers import (
     StableDiffusionInstructPix2PixPipeline,
     EulerAncestralDiscreteScheduler,
+    DDIMScheduler,
 )
 
 BASE_MODEL = "timbrooks/instruct-pix2pix"
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
-DEFAULT_EDIT_PROMPT = "
+DEFAULT_EDIT_PROMPT = "hair to linen-blonde-hair"
 
 DEFAULT_CATEGORY = "hair"
 
@@ -28,7 +29,7 @@ basepipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
     use_safetensors=True,
 )
 
-basepipeline.scheduler =
+basepipeline.scheduler = DDIMScheduler.from_config(basepipeline.scheduler.config)
 
 basepipeline = basepipeline.to(DEVICE)
 
@@ -46,14 +47,12 @@ def image_to_image(
     run_task_time = 0
     time_cost_str = ''
     run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
-    gray_image = input_image.convert("L")
-    p2p_image = Image.merge("RGB", [gray_image, gray_image, gray_image])
 
     generator = torch.Generator(device=DEVICE).manual_seed(seed)
     generated_image = basepipeline(
         generator=generator,
         prompt=edit_prompt,
-        image=
+        image=input_image,
         guidance_scale=guidance_scale,
         image_guidance_scale=image_guidance_scale,
         num_inference_steps=num_steps,
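
For context, here is a minimal, self-contained sketch of the pattern this commit lands on: load the InstructPix2Pix pipeline, swap its scheduler for DDIM via from_config, and pass the RGB input image directly to the pipeline (no grayscale round-trip). The model id, scheduler swap, and call arguments mirror the diff; the file paths, seed, guidance values, and dtype choice below are illustrative placeholders, not values from the commit.

import torch
from PIL import Image
from diffusers import StableDiffusionInstructPix2PixPipeline, DDIMScheduler

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix",
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,  # dtype choice is illustrative
    use_safetensors=True,
)
# Reuse the existing scheduler config so beta schedule and related settings stay consistent.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(DEVICE)

input_image = Image.open("face.jpg").convert("RGB")  # hypothetical input path
generator = torch.Generator(device=DEVICE).manual_seed(42)  # illustrative seed

edited = pipe(
    prompt="hair to linen-blonde-hair",
    image=input_image,                 # RGB image passed straight through, as in the new code
    generator=generator,
    guidance_scale=7.5,                # text guidance strength (illustrative)
    image_guidance_scale=1.5,          # how closely to follow the input image (illustrative)
    num_inference_steps=20,
).images[0]
edited.save("edited.jpg")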