lixiang46 commited on
Commit
02843f1
1 Parent(s): 6e2cf60
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -6,8 +6,8 @@ from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
6
  from kolors.pipelines import pipeline_stable_diffusion_xl_chatglm_256_ipadapter, pipeline_stable_diffusion_xl_chatglm_256
7
  from kolors.models.modeling_chatglm import ChatGLMModel
8
  from kolors.models.tokenization_chatglm import ChatGLMTokenizer
9
- from kolors.models.unet_2d_condition import UNet2DConditionModel
10
- from diffusers import AutoencoderKL, EulerDiscreteScheduler
11
  import gradio as gr
12
  import numpy as np
13
 
@@ -17,12 +17,12 @@ ckpt_IPA_dir = '/home/lixiang46/Kolors/weights/Kolors-IP-Adapter-Plus'
17
  # ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
18
  # ckpt_IPA_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors-IP-Adapter-Plus")
19
 
20
- # Load models
21
  text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder', torch_dtype=torch.float16).half().to(device)
22
  tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
23
  vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
24
  scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
25
- unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
 
26
  image_encoder = CLIPVisionModelWithProjection.from_pretrained(f'{ckpt_IPA_dir}/image_encoder',ignore_mismatched_sizes=True).to(dtype=torch.float16, device=device)
27
  ip_img_size = 336
28
  clip_image_processor = CLIPImageProcessor(size=ip_img_size, crop_size=ip_img_size)
@@ -30,7 +30,7 @@ clip_image_processor = CLIPImageProcessor(size=ip_img_size, crop_size=ip_img_siz
30
  pipe_t2i = pipeline_stable_diffusion_xl_chatglm_256.StableDiffusionXLPipeline(
31
  vae=vae,text_encoder=text_encoder,
32
  tokenizer=tokenizer,
33
- unet=unet,
34
  scheduler=scheduler,
35
  force_zeros_for_empty_prompt=False
36
  ).to(device)
@@ -39,7 +39,7 @@ pipe_i2i = pipeline_stable_diffusion_xl_chatglm_256_ipadapter.StableDiffusionXLP
39
  vae=vae,
40
  text_encoder=text_encoder,
41
  tokenizer=tokenizer,
42
- unet=unet,
43
  scheduler=scheduler,
44
  image_encoder=image_encoder,
45
  feature_extractor=clip_image_processor,
@@ -126,7 +126,7 @@ with gr.Blocks(css=css) as demo:
126
  )
127
  run_button = gr.Button("Run", scale=0)
128
  with gr.Row():
129
- ip_adapter_image = gr.Image(label="IP-Adapter Image", type="pil")
130
  with gr.Accordion("Advanced Settings", open=False):
131
  negative_prompt = gr.Text(
132
  label="Negative prompt",
 
6
  from kolors.pipelines import pipeline_stable_diffusion_xl_chatglm_256_ipadapter, pipeline_stable_diffusion_xl_chatglm_256
7
  from kolors.models.modeling_chatglm import ChatGLMModel
8
  from kolors.models.tokenization_chatglm import ChatGLMTokenizer
9
+ from kolors.models import unet_2d_condition
10
+ from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
11
  import gradio as gr
12
  import numpy as np
13
 
 
17
  # ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
18
  # ckpt_IPA_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors-IP-Adapter-Plus")
19
 
 
20
  text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder', torch_dtype=torch.float16).half().to(device)
21
  tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
22
  vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
23
  scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
24
+ unet_t2i = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
25
+ unet_i2i = unet_2d_condition.UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
26
  image_encoder = CLIPVisionModelWithProjection.from_pretrained(f'{ckpt_IPA_dir}/image_encoder',ignore_mismatched_sizes=True).to(dtype=torch.float16, device=device)
27
  ip_img_size = 336
28
  clip_image_processor = CLIPImageProcessor(size=ip_img_size, crop_size=ip_img_size)
 
30
  pipe_t2i = pipeline_stable_diffusion_xl_chatglm_256.StableDiffusionXLPipeline(
31
  vae=vae,text_encoder=text_encoder,
32
  tokenizer=tokenizer,
33
+ unet=unet_t2i,
34
  scheduler=scheduler,
35
  force_zeros_for_empty_prompt=False
36
  ).to(device)
 
39
  vae=vae,
40
  text_encoder=text_encoder,
41
  tokenizer=tokenizer,
42
+ unet=unet_i2i,
43
  scheduler=scheduler,
44
  image_encoder=image_encoder,
45
  feature_extractor=clip_image_processor,
 
126
  )
127
  run_button = gr.Button("Run", scale=0)
128
  with gr.Row():
129
+ ip_adapter_image = gr.Image(label="IP-Adapter Image (optional)", type="pil")
130
  with gr.Accordion("Advanced Settings", open=False):
131
  negative_prompt = gr.Text(
132
  label="Negative prompt",