zhiweili commited on
Commit
c9624f8
1 Parent(s): faaa8ba

change scheduler

Browse files
Files changed (1) hide show
  1. app_haircolor_pix2pix.py +4 -5
app_haircolor_pix2pix.py CHANGED
@@ -12,13 +12,14 @@ from enhance_utils import enhance_image
12
  from diffusers import (
13
  StableDiffusionInstructPix2PixPipeline,
14
  EulerAncestralDiscreteScheduler,
 
15
  )
16
 
17
  BASE_MODEL = "timbrooks/instruct-pix2pix"
18
 
19
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
- DEFAULT_EDIT_PROMPT = "change hair to blue"
22
 
23
  DEFAULT_CATEGORY = "hair"
24
 
@@ -28,7 +29,7 @@ basepipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
28
  use_safetensors=True,
29
  )
30
 
31
- basepipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(basepipeline.scheduler.config)
32
 
33
  basepipeline = basepipeline.to(DEVICE)
34
 
@@ -46,14 +47,12 @@ def image_to_image(
46
  run_task_time = 0
47
  time_cost_str = ''
48
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
49
- gray_image = input_image.convert("L")
50
- p2p_image = Image.merge("RGB", [gray_image, gray_image, gray_image])
51
 
52
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
53
  generated_image = basepipeline(
54
  generator=generator,
55
  prompt=edit_prompt,
56
- image=p2p_image,
57
  guidance_scale=guidance_scale,
58
  image_guidance_scale=image_guidance_scale,
59
  num_inference_steps=num_steps,
 
12
  from diffusers import (
13
  StableDiffusionInstructPix2PixPipeline,
14
  EulerAncestralDiscreteScheduler,
15
+ DDIMScheduler,
16
  )
17
 
18
  BASE_MODEL = "timbrooks/instruct-pix2pix"
19
 
20
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
+ DEFAULT_EDIT_PROMPT = "hair to linen-blonde-hair"
23
 
24
  DEFAULT_CATEGORY = "hair"
25
 
 
29
  use_safetensors=True,
30
  )
31
 
32
+ basepipeline.scheduler = DDIMScheduler.from_config(basepipeline.scheduler.config)
33
 
34
  basepipeline = basepipeline.to(DEVICE)
35
 
 
47
  run_task_time = 0
48
  time_cost_str = ''
49
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
 
 
50
 
51
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
52
  generated_image = basepipeline(
53
  generator=generator,
54
  prompt=edit_prompt,
55
+ image=input_image,
56
  guidance_scale=guidance_scale,
57
  image_guidance_scale=image_guidance_scale,
58
  num_inference_steps=num_steps,