import os import sys sys.path.append(os.getcwd()) sys.path.append(os.path.join(os.getcwd(), "annotator/entityseg")) import cv2 import spaces import einops import torch import gradio as gr import numpy as np from pytorch_lightning import seed_everything from PIL import Image from annotator.util import resize_image, HWC3 from annotator.canny import CannyDetector from annotator.midas import MidasDetector from annotator.entityseg import EntitysegDetector from annotator.openpose import OpenposeDetector from annotator.content import ContentDetector from annotator.cielab import CIELabDetector from models.util import create_model, load_state_dict from models.ddim_hacked import DDIMSampler ''' define conditions ''' max_conditions = 8 condition_types = ["edge", "depth", "seg", "pose", "content", "color"] apply_canny = CannyDetector() apply_midas = MidasDetector() apply_seg = EntitysegDetector() apply_openpose = OpenposeDetector() apply_content = ContentDetector() apply_color = CIELabDetector() processors = { "edge": apply_canny, "depth": apply_midas, "seg": apply_seg, "pose": apply_openpose, "content": apply_content, "color": apply_color, } descriptors = { "edge": "canny", "depth": "depth", "seg": "segmentation", "pose": "openpose", } @torch.no_grad() def get_unconditional_global(c_global): if isinstance(c_global, dict): return {k:torch.zeros_like(v) for k,v in c_global.items()} elif isinstance(c_global, list): return [torch.zeros_like(c) for c in c_global] else: return torch.zeros_like(c_global) @spaces.GPU def process(prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, strength, scale, seed, eta, global_strength, color_strength, local_strength, *args): seed_everything(seed) conds_and_types = args conds = conds_and_types[0::2] types = conds_and_types[1::2] conds = [c for c in conds if c is not None] types = [t for t in types if t is not None] assert len(conds) == len(types) detected_maps = [] other_maps = [] tasks = [] # initialize global control global_conditions = dict(clipembedding=np.zeros((1, 768), dtype=np.float32), color=np.zeros((1, 180), dtype=np.float32)) global_control = {} for key in global_conditions.keys(): global_cond = torch.from_numpy(global_conditions[key]).unsqueeze(0).repeat(num_samples, 1, 1) global_cond = global_cond.cuda().to(memory_format=torch.contiguous_format).float() global_control[key] = global_cond # initialize local control anchor_image = HWC3(np.zeros((image_resolution, image_resolution, 3)).astype(np.uint8)) oH, oW = anchor_image.shape[:2] H, W, C = resize_image(anchor_image, image_resolution).shape anchor_tensor = ddim_sampler.model.qformer_vis_processor['eval'](Image.fromarray(anchor_image)) local_control = torch.tensor(anchor_tensor).cuda().to(memory_format=torch.contiguous_format).half() task_prompt = '' with torch.no_grad(): # set up local control for cond, typ in zip(conds, types): if typ in ['edge', 'depth', 'seg', 'pose']: oH, oW = cond.shape[:2] cond_image = HWC3(cv2.resize(cond, (W, H))) cond_detected_map = processors[typ](cond_image) cond_detected_map = HWC3(cond_detected_map) detected_maps.append(cond_detected_map) tasks.append(descriptors[typ]) elif typ in ['content']: other_maps.append(cond) content_image = cv2.cvtColor(cond, cv2.COLOR_RGB2BGR) content_emb = apply_content(content_image) global_conditions['clipembedding'] = content_emb elif typ in ['color']: color_hist = apply_color(cond) global_conditions['color'] = color_hist color_palette = apply_color.hist_to_palette(color_hist) # (50, 189, 3) color_palette = cv2.resize(color_palette, (W, H), cv2.INTER_NEAREST) other_maps.append(color_palette) if len(detected_maps) > 0: local_control = torch.cat([ddim_sampler.model.qformer_vis_processor['eval'](Image.fromarray(img)).cuda().unsqueeze(0) for img in detected_maps], dim=1) task_prompt = ' conditioned on ' + ' and '.join(tasks) local_control = local_control.repeat(num_samples, 1, 1, 1) # set up global control for key in global_conditions.keys(): global_cond = torch.from_numpy(global_conditions[key]).unsqueeze(0).repeat(num_samples, 1, 1) global_cond = global_cond.cuda().to(memory_format=torch.contiguous_format).float() global_control[key] = global_cond # set up prompt input_prompt = (prompt + ' ' + task_prompt).strip() # set up cfg uc_local_control = local_control uc_global_control = {k:torch.zeros_like(v) for k,v in global_control.items()} cond = { "local_control": [local_control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)], "global_control": [global_control], "text": [[input_prompt] * num_samples], } un_cond = { "local_control": [uc_local_control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)], 'global_control': [uc_global_control], "text": [[input_prompt] * num_samples], } shape = (4, H // 8, W // 8) model.control_scales = [strength] * 13 samples, _ = ddim_sampler.sample(ddim_steps, num_samples, shape, cond, verbose=False, eta=eta, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond, global_strength=global_strength, color_strength=color_strength, local_strength=local_strength) x_samples = model.decode_first_stage(samples) x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) results = [x_samples[i] for i in range(num_samples)] results = [cv2.resize(res, (oW, oH)) for res in results] detected_maps = [cv2.resize(maps, (oW, oH)) for maps in detected_maps] return [results, detected_maps+other_maps] def variable_image_outputs(k): if k is None: k = 1 k = int(k) imageboxes = [] for i in range(max_conditions): if i