Ashoka74 committed on
Commit e401952
1 Parent(s): 489b7a4

Create inference_i2mv_sdxl.py

Files changed (1)
  1. inference_i2mv_sdxl.py +260 -0
inference_i2mv_sdxl.py ADDED
@@ -0,0 +1,260 @@
+import argparse
+
+import numpy as np
+import torch
+from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel
+from PIL import Image
+from torchvision import transforms
+from tqdm import tqdm
+from transformers import AutoModelForImageSegmentation
+
+from mvadapter.pipelines.pipeline_mvadapter_i2mv_sdxl import MVAdapterI2MVSDXLPipeline
+from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
+from mvadapter.utils import (
+    get_orthogonal_camera,
+    get_plucker_embeds_from_cameras_ortho,
+    make_image_grid,
+)
+
+
+def prepare_pipeline(
+    base_model,
+    vae_model,
+    unet_model,
+    lora_model,
+    adapter_path,
+    scheduler,
+    num_views,
+    device,
+    dtype,
+):
+    # Load vae and unet if provided
+    pipe_kwargs = {}
+    if vae_model is not None:
+        pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model)
+    if unet_model is not None:
+        pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model)
+
+    # Prepare pipeline
+    pipe: MVAdapterI2MVSDXLPipeline
+    pipe = MVAdapterI2MVSDXLPipeline.from_pretrained(base_model, **pipe_kwargs)
+
+    # Load scheduler if provided
+    scheduler_class = None
+    if scheduler == "ddpm":
+        scheduler_class = DDPMScheduler
+    elif scheduler == "lcm":
+        scheduler_class = LCMScheduler
+
+    pipe.scheduler = ShiftSNRScheduler.from_scheduler(
+        pipe.scheduler,
+        shift_mode="interpolated",
+        shift_scale=8.0,
+        scheduler_class=scheduler_class,
+    )
+    pipe.init_custom_adapter(num_views=num_views)
+    pipe.load_custom_adapter(
+        adapter_path, weight_name="mvadapter_i2mv_sdxl.safetensors"
+    )
+
+    pipe.to(device=device, dtype=dtype)
+    pipe.cond_encoder.to(device=device, dtype=dtype)
+
+    # load lora if provided
+    if lora_model is not None:
+        model_, name_ = lora_model.rsplit("/", 1)
+        pipe.load_lora_weights(model_, weight_name=name_)
+
+    # vae slicing for lower memory usage
+    pipe.enable_vae_slicing()
+
+    return pipe
+
+
+def remove_bg(image, net, transform, device):
+    image_size = image.size
+    input_images = transform(image).unsqueeze(0).to(device)
+    with torch.no_grad():
+        preds = net(input_images)[-1].sigmoid().cpu()
+    pred = preds[0].squeeze()
+    pred_pil = transforms.ToPILImage()(pred)
+    mask = pred_pil.resize(image_size)
+    image.putalpha(mask)
+    return image
+
+
+def preprocess_image(image: Image.Image, height, width):
+    image = np.array(image)
+    alpha = image[..., 3] > 0
+    H, W = alpha.shape
+    # get the bounding box of alpha
+    y, x = np.where(alpha)
+    y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
+    x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
+    image_center = image[y0:y1, x0:x1]
+    # resize the longer side to H * 0.9
+    H, W, _ = image_center.shape
+    if H > W:
+        W = int(W * (height * 0.9) / H)
+        H = int(height * 0.9)
+    else:
+        H = int(H * (width * 0.9) / W)
+        W = int(width * 0.9)
+    image_center = np.array(Image.fromarray(image_center).resize((W, H)))
+    # pad to H, W
+    start_h = (height - H) // 2
+    start_w = (width - W) // 2
+    image = np.zeros((height, width, 4), dtype=np.uint8)
+    image[start_h : start_h + H, start_w : start_w + W] = image_center
+    image = image.astype(np.float32) / 255.0
+    image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
+    image = (image * 255).clip(0, 255).astype(np.uint8)
+    image = Image.fromarray(image)
+
+    return image
+
+
+def run_pipeline(
+    pipe,
+    num_views,
+    text,
+    image,
+    height,
+    width,
+    num_inference_steps,
+    guidance_scale,
+    seed,
+    remove_bg_fn=None,
+    reference_conditioning_scale=1.0,
+    negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
+    lora_scale=1.0,
+    device="cuda",
+):
+    # Prepare cameras
+    cameras = get_orthogonal_camera(
+        elevation_deg=[0, 0, 0, 0, 0, 0],
+        distance=[1.8] * num_views,
+        left=-0.55,
+        right=0.55,
+        bottom=-0.55,
+        top=0.55,
+        azimuth_deg=[x - 90 for x in [0, 45, 90, 180, 270, 315]],
+        device=device,
+    )
+
+    plucker_embeds = get_plucker_embeds_from_cameras_ortho(
+        cameras.c2w, [1.1] * num_views, width
+    )
+    control_images = ((plucker_embeds + 1.0) / 2.0).clamp(0, 1)
+
+    # Prepare image
+    reference_image = Image.open(image) if isinstance(image, str) else image
+    if remove_bg_fn is not None:
+        reference_image = remove_bg_fn(reference_image)
+        reference_image = preprocess_image(reference_image, height, width)
+    elif reference_image.mode == "RGBA":
+        reference_image = preprocess_image(reference_image, height, width)
+
+    pipe_kwargs = {}
+    if seed != -1 and isinstance(seed, int):
+        pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)
+
+    images = pipe(
+        text,
+        height=height,
+        width=width,
+        num_inference_steps=num_inference_steps,
+        guidance_scale=guidance_scale,
+        num_images_per_prompt=num_views,
+        control_image=control_images,
+        control_conditioning_scale=1.0,
+        reference_image=reference_image,
+        reference_conditioning_scale=reference_conditioning_scale,
+        negative_prompt=negative_prompt,
+        cross_attention_kwargs={"scale": lora_scale},
+        **pipe_kwargs,
+    ).images
+
+    return images, reference_image
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    # Models
+    parser.add_argument(
+        "--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0"
+    )
+    parser.add_argument(
+        "--vae_model", type=str, default="madebyollin/sdxl-vae-fp16-fix"
+    )
+    parser.add_argument("--unet_model", type=str, default=None)
+    parser.add_argument("--scheduler", type=str, default=None)
+    parser.add_argument("--lora_model", type=str, default=None)
+    parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter")
+    parser.add_argument("--num_views", type=int, default=6)
+    # Device
+    parser.add_argument("--device", type=str, default="cuda")
+    # Inference
+    parser.add_argument("--image", type=str, required=True)
+    parser.add_argument("--text", type=str, default="high quality")
+    parser.add_argument("--num_inference_steps", type=int, default=50)
+    parser.add_argument("--guidance_scale", type=float, default=3.0)
+    parser.add_argument("--seed", type=int, default=-1)
+    parser.add_argument("--lora_scale", type=float, default=1.0)
+    parser.add_argument("--reference_conditioning_scale", type=float, default=1.0)
+    parser.add_argument(
+        "--negative_prompt",
+        type=str,
+        default="watermark, ugly, deformed, noisy, blurry, low contrast",
+    )
+    parser.add_argument("--output", type=str, default="output.png")
+    # Extra
+    parser.add_argument("--remove_bg", action="store_true", help="Remove background")
+    args = parser.parse_args()
+
+    pipe = prepare_pipeline(
+        base_model=args.base_model,
+        vae_model=args.vae_model,
+        unet_model=args.unet_model,
+        lora_model=args.lora_model,
+        adapter_path=args.adapter_path,
+        scheduler=args.scheduler,
+        num_views=args.num_views,
+        device=args.device,
+        dtype=torch.float16,
+    )
+
+    if args.remove_bg:
+        birefnet = AutoModelForImageSegmentation.from_pretrained(
+            "ZhengPeng7/BiRefNet", trust_remote_code=True
+        )
+        birefnet.to(args.device)
+        transform_image = transforms.Compose(
+            [
+                transforms.Resize((1024, 1024)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+            ]
+        )
+        remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, args.device)
+    else:
+        remove_bg_fn = None
+
+    images, reference_image = run_pipeline(
+        pipe,
+        num_views=args.num_views,
+        text=args.text,
+        image=args.image,
+        height=768,
+        width=768,
+        num_inference_steps=args.num_inference_steps,
+        guidance_scale=args.guidance_scale,
+        seed=args.seed,
+        lora_scale=args.lora_scale,
+        reference_conditioning_scale=args.reference_conditioning_scale,
+        negative_prompt=args.negative_prompt,
+        device=args.device,
+        remove_bg_fn=remove_bg_fn,
+    )
+    make_image_grid(images, rows=1).save(args.output)
+    reference_image.save(args.output.rsplit(".", 1)[0] + "_reference.png")