import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF

from imagedream.camera_utils import get_camera, convert_opengl_to_blender, normalize_camera
from imagedream.model_zoo import build_model
from imagedream.ldm.models.diffusion.ddim import DDIMSampler

from diffusers import DDIMScheduler


class ImageDream(nn.Module):
    """Frozen ImageDream multi-view diffusion model used as guidance.

    Inputs are processed in groups of 4 views per sample; an extra (5th) view
    with an all-zero camera and zero latents is appended before calling the
    model and dropped again from its output.

    Conditioning must be cached with ``get_image_text_embeds`` before calling
    ``refine`` or ``train_step``.
    """

    def __init__(
        self,
        device,
        model_name='sd-v2.1-base-4view-ipmv',
        ckpt_path=None,
        t_range=(0.02, 0.98),
    ):
        """Build the frozen diffusion model and a DDIM scheduler.

        Args:
            device: torch device the model is moved to.
            model_name: model identifier passed to ``build_model``.
            ckpt_path: optional checkpoint path passed to ``build_model``.
            t_range: (low, high) fractions of the 1000 training timesteps
                bounding the noise level sampled in ``train_step``. A tuple
                default avoids a shared mutable default; any indexable pair
                (e.g. a list) is still accepted.
        """
        super().__init__()

        self.device = device
        self.model_name = model_name
        self.ckpt_path = ckpt_path

        self.model = build_model(self.model_name, ckpt_path=self.ckpt_path).eval().to(self.device)
        self.model.device = device
        # The diffusion model only provides guidance; it is never optimized.
        for p in self.model.parameters():
            p.requires_grad_(False)

        self.dtype = torch.float32

        self.num_train_timesteps = 1000
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])

        # Conditioning caches, populated by get_image_text_embeds().
        self.image_embeddings = {}
        self.embeddings = {}

        self.scheduler = DDIMScheduler.from_pretrained(
            "stabilityai/stable-diffusion-2-1-base", subfolder="scheduler", torch_dtype=self.dtype
        )

    @torch.no_grad()
    def get_image_text_embeds(self, image, prompts, negative_prompts):
        """Compute and cache image + text conditioning.

        Args:
            image: image tensor batch; only ``image[0]`` is used for the
                learned image conditioning (assumed [N, 3, H, W] in [0, 1] —
                TODO confirm against callers).
            prompts: prompt list passed to ``encode_text``.
            negative_prompts: negative prompt list passed to ``encode_text``.

        Returns:
            Tuple ``(ip_pos, ip_neg, ip_img, neg_ip_img, text_pos, text_neg)``
            of the cached tensors. The image and text embeddings are repeated
            5 times along dim 0, once per view including the extra view.
        """
        image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)
        image_pil = TF.to_pil_image(image[0])
        image_embeddings = self.model.get_learned_image_conditioning(image_pil).repeat(5, 1, 1)  # [5, 257, 1280]
        self.image_embeddings['pos'] = image_embeddings
        self.image_embeddings['neg'] = torch.zeros_like(image_embeddings)

        self.image_embeddings['ip_img'] = self.encode_imgs(image)
        self.image_embeddings['neg_ip_img'] = torch.zeros_like(self.image_embeddings['ip_img'])

        pos_embeds = self.encode_text(prompts).repeat(5, 1, 1)
        neg_embeds = self.encode_text(negative_prompts).repeat(5, 1, 1)
        self.embeddings['pos'] = pos_embeds
        self.embeddings['neg'] = neg_embeds

        return (
            self.image_embeddings['pos'],
            self.image_embeddings['neg'],
            self.image_embeddings['ip_img'],
            self.image_embeddings['neg_ip_img'],
            self.embeddings['pos'],
            self.embeddings['neg'],
        )

    def encode_text(self, prompt):
        """Encode a list of prompt strings with the model's text encoder, on ``self.device``."""
        # prompt: [str]
        embeddings = self.model.get_learned_conditioning(prompt).to(self.device)
        return embeddings

    def _build_context(self, camera, batch_size):
        """Build the classifier-free-guidance conditioning dict for ``apply_model``.

        Converts the cameras to the Blender convention, normalizes them, and
        appends an all-zero camera as the extra 5th view of every sample. The
        result is duplicated as ``[uncond; cond]``. The cached embeddings are
        stacked in the same order: negative first, then positive.

        Args:
            camera: [B, 4, 4] camera matrices, B a multiple of 4. Not
                modified: the column reindex produces a copy.
            batch_size: B.
        """
        real_batch_size = batch_size // 4

        camera = camera[:, [0, 2, 1, 3]]  # to blender convention (flip y & z axis)
        camera[:, 1] *= -1
        camera = normalize_camera(camera).view(batch_size, 16)

        # extra view: all-zero camera appended after each sample's 4 views
        camera = camera.view(real_batch_size, 4, 16)
        camera = torch.cat([camera, torch.zeros_like(camera[:, :1])], dim=1)  # [rB, 5, 16]
        camera = camera.view(real_batch_size * 5, 16)
        camera = camera.repeat(2, 1)  # [uncond; cond]

        embeddings = torch.cat([
            self.embeddings['neg'].repeat(real_batch_size, 1, 1),
            self.embeddings['pos'].repeat(real_batch_size, 1, 1),
        ], dim=0)
        image_embeddings = torch.cat([
            self.image_embeddings['neg'].repeat(real_batch_size, 1, 1),
            self.image_embeddings['pos'].repeat(real_batch_size, 1, 1),
        ], dim=0)
        ip_img_embeddings = torch.cat([
            self.image_embeddings['neg_ip_img'].repeat(real_batch_size, 1, 1, 1),
            self.image_embeddings['ip_img'].repeat(real_batch_size, 1, 1, 1),
        ], dim=0)

        return {
            "context": embeddings,
            "ip": image_embeddings,
            "ip_img": ip_img_embeddings,
            "camera": camera,
            "num_frames": 4 + 1,
        }

    @staticmethod
    def _append_extra_view(latents, real_batch_size):
        """Append a zero latent as the 5th view of each sample: [rB*4, 4, 32, 32] -> [rB*5, 4, 32, 32]."""
        latents = latents.view(real_batch_size, 4, 4, 32, 32)
        return torch.cat([latents, torch.zeros_like(latents[:, :1])], dim=1).view(-1, 4, 32, 32)

    @staticmethod
    def _drop_extra_view(x, real_batch_size):
        """Remove the 5th view of each sample: [rB*5, 4, 32, 32] -> [rB*4, 4, 32, 32]."""
        return x.reshape(real_batch_size, 5, 4, 32, 32)[:, :-1].reshape(-1, 4, 32, 32)

    @torch.no_grad()
    def refine(self, pred_rgb, camera,
               guidance_scale=5, steps=50, strength=0.8,
               ):
        """Partially noise ``pred_rgb`` and denoise it with DDIM under the cached conditioning.

        Args:
            pred_rgb: [B, 3, H, W] renders, B a multiple of 4 (4 views per sample).
            camera: [B, 4, 4] camera matrices.
            guidance_scale: classifier-free guidance scale.
            steps: number of DDIM scheduler steps.
            strength: denoising starts at ``timesteps[int(steps * strength)]``,
                and only the remaining timesteps are run. The scheduler's
                timesteps go from high to low noise, so a larger ``strength``
                adds less noise and runs fewer steps.

        Returns:
            Decoded images in [0, 1].
        """
        batch_size = pred_rgb.shape[0]
        real_batch_size = batch_size // 4
        pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
        latents = self.encode_imgs(pred_rgb_256.to(self.dtype))

        self.scheduler.set_timesteps(steps)
        init_step = int(steps * strength)
        latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step])

        context = self._build_context(camera, batch_size)

        for t in self.scheduler.timesteps[init_step:]:
            latents = self._append_extra_view(latents, real_batch_size)

            latent_model_input = torch.cat([latents] * 2)
            tt = torch.cat([t.unsqueeze(0).repeat(real_batch_size * 5)] * 2).to(self.device)

            noise_pred = self.model.apply_model(latent_model_input, tt, context)
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)

            noise_pred_uncond = self._drop_extra_view(noise_pred_uncond, real_batch_size)
            noise_pred_cond = self._drop_extra_view(noise_pred_cond, real_batch_size)
            latents = self._drop_extra_view(latents, real_batch_size)

            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        imgs = self.decode_latents(latents)  # presumably [B, 3, 256, 256] for 32x32 latents
        return imgs

    def train_step(
        self,
        pred_rgb,  # [B, C, H, W]
        camera,  # [B, 4, 4]
        step_ratio=None,
        guidance_scale=5,
        as_latent=False,
    ):
        """Return a score-distillation loss whose gradient w.r.t. ``latents`` is ``noise_pred - noise``.

        Args:
            pred_rgb: [B, C, H, W] renders, B a multiple of 4 (4 consecutive
                views per sample).
            camera: [B, 4, 4] camera matrices.
            step_ratio: if given, a deterministic timestep
                ``round((1 - step_ratio) * 1000)`` clipped to [min_step, max_step];
                otherwise one random timestep per sample, shared by its 4 views.
            guidance_scale: classifier-free guidance scale.
            as_latent: if True, treat ``pred_rgb`` directly as latents
                (resized to 32x32 and mapped to [-1, 1]) instead of VAE-encoding it.

        Returns:
            Scalar loss: ``0.5 * sum((latents - target)^2) / B``.
        """
        batch_size = pred_rgb.shape[0]
        real_batch_size = batch_size // 4
        pred_rgb = pred_rgb.to(self.dtype)

        if as_latent:
            latents = F.interpolate(pred_rgb, (32, 32), mode="bilinear", align_corners=False) * 2 - 1
        else:
            # interp to 256x256 to be fed into vae.
            pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode="bilinear", align_corners=False)
            # encode image into latents with vae, requires grad!
            latents = self.encode_imgs(pred_rgb_256)

        if step_ratio is not None:
            # dreamtime-like
            t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
            t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
        else:
            # repeat_interleave keeps each sample's 4 consecutive views on the
            # same timestep ([a, a, a, a, b, b, b, b]). This matches the
            # view(real_batch_size, 4) grouping below; .repeat(4) would
            # interleave samples when real_batch_size > 1.
            t = torch.randint(
                self.min_step, self.max_step + 1, (real_batch_size,), dtype=torch.long, device=self.device
            ).repeat_interleave(4)

        context = self._build_context(camera, batch_size)

        # predict the noise residual with unet, NO grad!
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(latents)
            latents_noisy = self.model.q_sample(latents, t, noise)  # [B, 4, 32, 32]

            # extra view: reuses each sample's first-view timestep
            t = t.view(real_batch_size, 4)
            t = torch.cat([t, t[:, :1]], dim=1).view(-1)
            latents_noisy = self._append_extra_view(latents_noisy, real_batch_size)

            # pred noise
            latent_model_input = torch.cat([latents_noisy] * 2)
            tt = torch.cat([t] * 2)
            noise_pred = self.model.apply_model(latent_model_input, tt, context)

            # perform guidance (high scale from paper!)
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred_uncond = self._drop_extra_view(noise_pred_uncond, real_batch_size)
            noise_pred_cond = self._drop_extra_view(noise_pred_cond, real_batch_size)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

        grad = noise_pred - noise
        grad = torch.nan_to_num(grad)

        # d(loss)/d(latents) == grad / B, routed through the detached target.
        target = (latents - grad).detach()
        loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0]

        return loss

    def decode_latents(self, latents):
        """Decode latents with the first-stage model and map the output from [-1, 1] to [0, 1]."""
        imgs = self.model.decode_first_stage(latents)
        imgs = ((imgs + 1) / 2).clamp(0, 1)
        return imgs

    def encode_imgs(self, imgs):
        """Encode images in [0, 1] to first-stage latents (images are rescaled to [-1, 1] first)."""
        # imgs: [B, 3, 256, 256]
        imgs = 2 * imgs - 1
        latents = self.model.get_first_stage_encoding(self.model.encode_first_stage(imgs))
        return latents  # [B, 4, 32, 32]

    @torch.no_grad()
    def prompt_to_img(
        self,
        image,
        prompts,
        negative_prompts="",
        height=256,
        width=256,
        num_inference_steps=50,
        guidance_scale=5.0,
        latents=None,
        elevation=0,
        azimuth_start=0,
    ):
        """Sample 4 views + 1 extra view per prompt with the DDIM sampler.

        Args:
            image: conditioning image tensor; only ``image[0]`` feeds the
                learned image conditioning.
            prompts: a prompt string or list of prompts.
            negative_prompts: a negative prompt string or list.
            height, width: output size; the latent shape is ``[4, height // 8, width // 8]``.
            num_inference_steps: DDIM steps.
            guidance_scale: classifier-free guidance scale.
            latents: optional initial noise forwarded to the sampler as ``x_T``
                (None lets the sampler draw its own).
            elevation, azimuth_start: camera parameters passed to ``get_camera``.

        Returns:
            uint8 numpy array of images in [N, H, W, 3] layout. N presumably
            equals ``5 * len(prompts)``.

        NOTE(review): the image embeddings are repeated only 5 times and the
        text embeddings are tiled (not interleaved). Only a single prompt
        looks supported — confirm before passing several prompts.
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        real_batch_size = len(prompts)
        batch_size = len(prompts) * 5

        # Text embeds -> img latents
        sampler = DDIMSampler(self.model)
        shape = [4, height // 8, width // 8]
        c_ = {"context": self.encode_text(prompts).repeat(5, 1, 1)}
        uc_ = {"context": self.encode_text(negative_prompts).repeat(5, 1, 1)}

        # image embeddings
        image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)
        image_pil = TF.to_pil_image(image[0])
        image_embeddings = self.model.get_learned_image_conditioning(image_pil).repeat(5, 1, 1).to(self.device)
        c_["ip"] = image_embeddings
        uc_["ip"] = torch.zeros_like(image_embeddings)

        ip_img = self.encode_imgs(image)
        c_["ip_img"] = ip_img
        uc_["ip_img"] = torch.zeros_like(ip_img)

        camera = get_camera(4, elevation=elevation, azimuth_start=azimuth_start, extra_view=True)
        camera = camera.repeat(real_batch_size, 1).to(self.device)

        c_["camera"] = uc_["camera"] = camera
        c_["num_frames"] = uc_["num_frames"] = 5

        latents, _ = sampler.sample(S=num_inference_steps, conditioning=c_,
                                    batch_size=batch_size, shape=shape,
                                    verbose=False,
                                    unconditional_guidance_scale=guidance_scale,
                                    unconditional_conditioning=uc_,
                                    eta=0, x_T=latents)

        # Img latents -> imgs
        imgs = self.decode_latents(latents)

        # Img to Numpy
        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
        imgs = (imgs * 255).round().astype("uint8")

        return imgs


if __name__ == "__main__":
    import argparse
    import matplotlib.pyplot as plt
    import kiui

    parser = argparse.ArgumentParser()
    parser.add_argument("image", type=str)
    parser.add_argument("prompt", type=str)
    parser.add_argument("--negative", default="", type=str)
    parser.add_argument("--steps", type=int, default=30)
    opt = parser.parse_args()

    device = torch.device("cuda")

    sd = ImageDream(device)

    image = kiui.read_image(opt.image, mode='tensor')
    image = image.permute(2, 0, 1).unsqueeze(0).to(device)

    # Regenerate and show a new 2x2 grid of the 4 views each time the window is closed.
    while True:
        imgs = sd.prompt_to_img(image, opt.prompt, opt.negative, num_inference_steps=opt.steps)

        grid = np.concatenate([
            np.concatenate([imgs[0], imgs[1]], axis=1),
            np.concatenate([imgs[2], imgs[3]], axis=1),
        ], axis=0)

        # visualize image
        plt.imshow(grid)
        plt.show()