# Inference utilities for T2I-Adapter with Stable Diffusion:
# argument parsing, model/sampler construction, adapter loading,
# and the diffusion sampling entry point.
import argparse | |
import torch | |
from omegaconf import OmegaConf | |
from ldm.models.diffusion.ddim import DDIMSampler | |
from ldm.models.diffusion.plms import PLMSSampler | |
from ldm.modules.encoders.adapter import Adapter, StyleAdapter, Adapter_light | |
from ldm.modules.extra_condition.api import ExtraCondition | |
from ldm.util import fix_cond_shapes, load_model_from_config, read_state_dict | |
# Negative prompt used when the user does not pass --neg_prompt: a list of
# quality/anatomy failure terms commonly used with SD 1.x checkpoints.
DEFAULT_NEGATIVE_PROMPT = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
                          'fewer digits, cropped, worst quality, low quality'
def get_base_argument_parser() -> argparse.ArgumentParser:
    """Build the base CLI argument parser shared by the inference scripts.

    Returns:
        argparse.ArgumentParser: parser pre-populated with prompt, checkpoint,
        sampler, resolution, and guidance options. Scripts may add their own
        condition-specific flags on top of it.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--outdir',
        type=str,
        help='dir to write results to',
        default=None,
    )
    parser.add_argument(
        '--prompt',
        type=str,
        nargs='?',
        default=None,
        help='positive prompt',
    )
    parser.add_argument(
        '--neg_prompt',
        type=str,
        default=DEFAULT_NEGATIVE_PROMPT,
        help='negative prompt',
    )
    parser.add_argument(
        '--cond_path',
        type=str,
        default=None,
        help='condition image path',
    )
    parser.add_argument(
        '--cond_inp_type',
        type=str,
        default='image',
        help='the type of the input condition image, take depth T2I as example, the input can be raw image, '
        'which depth will be calculated, or the input can be a directly a depth map image',
    )
    parser.add_argument(
        '--sampler',
        type=str,
        default='ddim',
        choices=['ddim', 'plms'],
        help='sampling algorithm, currently, only ddim and plms are supported, more are on the way',
    )
    parser.add_argument(
        '--steps',
        type=int,
        default=50,
        help='number of sampling steps',
    )
    parser.add_argument(
        '--sd_ckpt',
        type=str,
        default='models/sd-v1-4.ckpt',
        help='path to checkpoint of stable diffusion model, both .ckpt and .safetensor are supported',
    )
    parser.add_argument(
        '--vae_ckpt',
        type=str,
        default=None,
        # fixed typo in user-facing help: "seperate" -> "separate"
        help='vae checkpoint, anime SD models usually have separate vae ckpt that need to be loaded',
    )
    parser.add_argument(
        '--adapter_ckpt',
        type=str,
        default=None,
        help='path to checkpoint of adapter',
    )
    parser.add_argument(
        '--config',
        type=str,
        default='configs/stable-diffusion/sd-v1-inference.yaml',
        help='path to config which constructs SD model',
    )
    parser.add_argument(
        '--max_resolution',
        type=float,
        default=512 * 512,
        help='max image height * width, only for computer with limited vram',
    )
    parser.add_argument(
        '--resize_short_edge',
        type=int,
        default=None,
        help='resize short edge of the input image, if this arg is set, max_resolution will not be used',
    )
    parser.add_argument(
        '--C',
        type=int,
        default=4,
        help='latent channels',
    )
    parser.add_argument(
        '--f',
        type=int,
        default=8,
        help='downsampling factor',
    )
    parser.add_argument(
        '--scale',
        type=float,
        default=7.5,
        help='unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))',
    )
    parser.add_argument(
        '--cond_tau',
        type=float,
        default=1.0,
        help='timestamp parameter that determines until which step the adapter is applied, '
        'similar as Prompt-to-Prompt tau',
    )
    parser.add_argument(
        '--style_cond_tau',
        type=float,
        default=1.0,
        help='timestamp parameter that determines until which step the adapter is applied, '
        'similar as Prompt-to-Prompt tau',
    )
    parser.add_argument(
        '--cond_weight',
        type=float,
        default=1.0,
        help='the adapter features are multiplied by the cond_weight. The larger the cond_weight, the more aligned '
        'the generated image and condition will be, but the generated quality may be reduced',
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=42,
    )
    parser.add_argument(
        '--n_samples',
        type=int,
        default=4,
        help='# of samples to generate',
    )
    return parser
def get_sd_models(opt):
    """Construct the Stable Diffusion model and a matching sampler.

    Loads the model described by ``opt.config`` from ``opt.sd_ckpt`` (with an
    optional separate VAE from ``opt.vae_ckpt``), moves it to ``opt.device``,
    and wraps it in the sampler selected by ``opt.sampler``.

    Returns:
        tuple: (sd_model, sampler)
    """
    config = OmegaConf.load(str(opt.config))
    model = load_model_from_config(config, opt.sd_ckpt, opt.vae_ckpt)
    sd_model = model.to(opt.device)

    # Sampler selection via dispatch table; unknown names are rejected.
    sampler_classes = {'plms': PLMSSampler, 'ddim': DDIMSampler}
    if opt.sampler not in sampler_classes:
        raise NotImplementedError
    sampler = sampler_classes[opt.sampler](model)
    return sd_model, sampler
def get_t2i_adapter_models(opt):
    """Construct a Stable Diffusion model with fused adapter weights, plus a sampler.

    The adapter checkpoint path is taken from the condition-specific option
    ``<which_cond>_adapter_ckpt`` when present, otherwise from the generic
    ``adapter_ckpt``. Its keys are namespaced under ``adapter.`` before being
    loaded (non-strictly) into the combined model.

    Returns:
        tuple: (model, sampler), with the model on ``opt.device``.
    """
    config = OmegaConf.load(str(opt.config))
    model = load_model_from_config(config, opt.sd_ckpt, opt.vae_ckpt)

    # Condition-specific checkpoint takes precedence over the generic one.
    adapter_ckpt_path = getattr(opt, f'{opt.which_cond}_adapter_ckpt', None)
    if adapter_ckpt_path is None:
        adapter_ckpt_path = getattr(opt, 'adapter_ckpt')
    adapter_ckpt = read_state_dict(adapter_ckpt_path)

    # Ensure every key carries the 'adapter.' prefix expected by the model.
    prefixed_state = {
        (key if key.startswith('adapter.') else f'adapter.{key}'): value
        for key, value in adapter_ckpt.items()
    }
    _, unexpected = model.load_state_dict(prefixed_state, strict=False)
    if unexpected:
        print(f"unexpected keys in loading adapter ckpt {adapter_ckpt_path}:")
        print(unexpected)

    model = model.to(opt.device)

    # sampler
    if opt.sampler == 'plms':
        sampler = PLMSSampler(model)
    elif opt.sampler == 'ddim':
        sampler = DDIMSampler(model)
    else:
        raise NotImplementedError
    return model, sampler
def get_cond_ch(cond_type: ExtraCondition):
    """Return the channel count of a condition image: 1 for sketch/canny, else 3."""
    single_channel_conds = (ExtraCondition.sketch, ExtraCondition.canny)
    return 1 if cond_type in single_channel_conds else 3
def get_adapters(opt, cond_type: ExtraCondition):
    """Build the adapter network for ``cond_type`` and load its checkpoint.

    Args:
        opt: parsed options; may carry condition-specific overrides
            ``<cond>_weight`` and ``<cond>_adapter_ckpt`` that take precedence
            over the generic ``cond_weight`` / ``adapter_ckpt``.
        cond_type: which extra condition the adapter handles.

    Returns:
        dict with keys:
            'cond_weight': multiplier applied to the adapter features.
            'model': adapter module on ``opt.device`` with weights loaded.
    """
    adapter = {}
    # Per-condition weight override wins over the generic --cond_weight.
    cond_weight = getattr(opt, f'{cond_type.name}_weight', None)
    if cond_weight is None:
        cond_weight = getattr(opt, 'cond_weight')
    adapter['cond_weight'] = cond_weight

    if cond_type == ExtraCondition.style:
        adapter['model'] = StyleAdapter(
            width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8).to(opt.device)
    elif cond_type == ExtraCondition.color:
        adapter['model'] = Adapter_light(
            cin=64 * get_cond_ch(cond_type),
            channels=[320, 640, 1280, 1280],
            nums_rb=4).to(opt.device)
    else:
        adapter['model'] = Adapter(
            cin=64 * get_cond_ch(cond_type),
            channels=[320, 640, 1280, 1280],  # was [..][:4], a no-op slice
            nums_rb=2,
            ksize=1,
            sk=True,
            use_conv=False).to(opt.device)

    # Per-condition checkpoint override wins over the generic --adapter_ckpt.
    ckpt_path = getattr(opt, f'{cond_type.name}_adapter_ckpt', None)
    if ckpt_path is None:
        ckpt_path = getattr(opt, 'adapter_ckpt')
    # map_location='cpu' so GPU-saved checkpoints also load on CPU-only hosts;
    # load_state_dict then copies the tensors into the module already on opt.device.
    adapter['model'].load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    return adapter
def diffusion_inference(opt, model, sampler, adapter_features, append_to_context=None):
    """Run one conditioned sampling pass and decode the result to image space.

    Args:
        opt: options namespace (prompt, guidance scale, steps, latent shape, taus).
        model: the Stable Diffusion model.
        sampler: DDIM/PLMS sampler wrapping ``model``.
        adapter_features: features produced by the adapter, injected during sampling.
        append_to_context: optional extra context appended to the text conditioning.

    Returns:
        Decoded samples as a tensor clamped to [0, 1].
    """
    # Text conditioning: negative prompt is only needed under CFG (scale != 1).
    cond = model.get_learned_conditioning([opt.prompt])
    uncond = model.get_learned_conditioning([opt.neg_prompt]) if opt.scale != 1.0 else None
    cond, uncond = fix_cond_shapes(model, cond, uncond)

    # Fall back to a 512x512 output when the caller did not set a resolution.
    if not hasattr(opt, 'H'):
        opt.H = 512
        opt.W = 512
    latent_shape = [opt.C, opt.H // opt.f, opt.W // opt.f]

    latents, _ = sampler.sample(
        S=opt.steps,
        conditioning=cond,
        batch_size=1,
        shape=latent_shape,
        verbose=False,
        unconditional_guidance_scale=opt.scale,
        unconditional_conditioning=uncond,
        x_T=None,
        features_adapter=adapter_features,
        append_to_context=append_to_context,
        cond_tau=opt.cond_tau,
        style_cond_tau=opt.style_cond_tau,
    )

    decoded = model.decode_first_stage(latents)
    # Map from [-1, 1] model output range into [0, 1] image range.
    return torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)