booxel / BOOXEL /util.py
yanranxiaoxi's picture
First commit
00fc29f verified
raw
history blame
5.8 kB
import os
import torch
import numpy as np
import cv2
from PIL import Image
from torch.nn.functional import interpolate
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config
def get_state_dict(d):
    '''
    Unwrap a checkpoint mapping.

    Returns the value stored under the 'state_dict' key when that key is
    present; otherwise returns the mapping itself unchanged.
    '''
    return d['state_dict'] if 'state_dict' in d else d
def load_state_dict(ckpt_path, location='cpu'):
    '''
    Load a state dict from a checkpoint file.

    Files whose extension is ".safetensors" (case-insensitive) are read with
    safetensors.torch.load_file; everything else goes through torch.load.
    The result is passed through get_state_dict so that a wrapping
    {'state_dict': ...} layer is removed.

    ckpt_path: path of the checkpoint file.
    location: device string handed to the loader (default 'cpu').
    '''
    suffix = os.path.splitext(ckpt_path)[1].lower()
    if suffix != ".safetensors":
        # NOTE(review): torch.load is pickle-based; only feed it checkpoints
        # from a trusted source.
        target_device = torch.device(location)
        weights = get_state_dict(torch.load(ckpt_path, map_location=target_device))
    else:
        # Imported lazily so safetensors is only needed for this file type.
        import safetensors.torch
        weights = safetensors.torch.load_file(ckpt_path, device=location)
    weights = get_state_dict(weights)
    print(f'从 [{ckpt_path}] 加载 state_dict')
    return weights
def create_model(config_path):
    '''
    Instantiate the model described by the `model` section of a config file.

    The config is read with OmegaConf, the model is built through
    instantiate_from_config and moved to the CPU before being returned.
    No checkpoint weights are loaded here.
    '''
    cfg = OmegaConf.load(config_path)
    net = instantiate_from_config(cfg.model)
    net = net.cpu()
    print(f'从 [{config_path}] 加载模型配置')
    return net
def create_BOOXEL_model(config_path, BOOXEL_sign=None, load_default_setting=False):
    '''
    Build the model from a config file and load its checkpoints.

    The model is instantiated from `config.model` on the CPU. Checkpoints are
    then applied in this order, each with strict=False so that later ones
    override only the keys they contain:
      1. config.SDXL_CKPT   (skipped when it is None)
      2. config.BOOXEL_CKPT (skipped when it is None)
      3. config.BOOXEL_CKPT_F or config.BOOXEL_CKPT_Q, chosen by BOOXEL_sign

    config_path: path of the OmegaConf config file.
    BOOXEL_sign: None (no variant checkpoint), 'F' or 'Q'.
    load_default_setting: when true, also return config.default_setting.

    Returns the model, or the tuple (model, default_setting) when
    load_default_setting is true.

    Raises ValueError if BOOXEL_sign is not None, 'F' or 'Q'.
    '''
    # Validate with a real exception and before any expensive work: an
    # `assert` is removed under `python -O`, which would let an invalid sign
    # silently skip the variant checkpoint, and failing only after the model
    # and base checkpoints are loaded wastes time and memory.
    if BOOXEL_sign is not None and BOOXEL_sign not in ('F', 'Q'):
        raise ValueError(
            f"BOOXEL_sign must be None, 'F' or 'Q', got {BOOXEL_sign!r}"
        )

    config = OmegaConf.load(config_path)
    model = instantiate_from_config(config.model).cpu()
    print(f'从 [{config_path}] 加载模型配置')

    if config.SDXL_CKPT is not None:
        model.load_state_dict(load_state_dict(config.SDXL_CKPT), strict=False)
    if config.BOOXEL_CKPT is not None:
        model.load_state_dict(load_state_dict(config.BOOXEL_CKPT), strict=False)

    if BOOXEL_sign == 'F':
        model.load_state_dict(load_state_dict(config.BOOXEL_CKPT_F), strict=False)
    elif BOOXEL_sign == 'Q':
        model.load_state_dict(load_state_dict(config.BOOXEL_CKPT_Q), strict=False)

    if load_default_setting:
        return model, config.default_setting
    return model
def load_QF_ckpt(config_path):
    '''
    Load the raw 'Q' and 'F' checkpoints named in a config file.

    Both files (config.BOOXEL_CKPT_F and config.BOOXEL_CKPT_Q) are read with
    torch.load onto the CPU and returned as loaded, without unwrapping.

    Returns the tuple (ckpt_Q, ckpt_F) -- note that Q comes first.
    '''
    cfg = OmegaConf.load(config_path)
    # NOTE(review): torch.load is pickle-based; use trusted checkpoints only.
    f_weights = torch.load(cfg.BOOXEL_CKPT_F, map_location='cpu')
    q_weights = torch.load(cfg.BOOXEL_CKPT_Q, map_location='cpu')
    return q_weights, f_weights
def PIL2Tensor(img, upsacle=1, min_size=1024, fix_resize=None):
    '''
    PIL.Image -> (Tensor[C, H, W] scaled to [-1, 1], h0, w0)

    img: source image; its array form must be H x W x C for the final permute.
    upsacle: factor applied to both sides first.
    min_size: if the shorter side is still below this, scale up to reach it.
    fix_resize: if given, rescale so the shorter side equals this value.

    The working size is snapped to the nearest multiple of 64 before the
    bicubic resize. h0 / w0 are the rounded height / width recorded right
    after `upsacle` (min_size scaling does not update them), or right after
    the `fix_resize` scaling when that option is used.
    '''
    # Target size
    width, height = img.size
    width *= upsacle
    height *= upsacle
    w0, h0 = round(width), round(height)

    short_side = min(width, height)
    if short_side < min_size:
        factor = min_size / short_side
        width *= factor
        height *= factor

    if fix_resize is not None:
        factor = fix_resize / min(width, height)
        width *= factor
        height *= factor
        w0, h0 = round(width), round(height)

    # Snap both sides to a multiple of 64.
    width = int(np.round(width / 64.0)) * 64
    height = int(np.round(height / 64.0)) * 64

    resized = img.resize((width, height), Image.BICUBIC)
    pixels = np.array(resized).round().clip(0, 255).astype(np.uint8)
    pixels = pixels / 255 * 2 - 1
    tensor = torch.tensor(pixels, dtype=torch.float32).permute(2, 0, 1)
    return tensor, h0, w0
def Tensor2PIL(x, h0, w0):
    '''
    Tensor[C, H, W] in [-1, 1] -> PIL.Image of size (w0, h0)

    The tensor is resized to (h0, w0) with bicubic interpolation, mapped to
    [0, 255] via x * 127.5 + 127.5, clipped, and cast to uint8.
    '''
    batched = x.unsqueeze(0)  # interpolate needs a leading batch dimension
    resized = interpolate(batched, size=(h0, w0), mode='bicubic')
    hwc = resized.squeeze(0).permute(1, 2, 0)
    pixels = (hwc * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
def HWC3(x):
    '''
    Convert a uint8 image array to 3 channels, shape (H, W, 3).

    - 2-D input gets a trailing channel axis first.
    - 3 channels: returned unchanged (same object).
    - 1 channel: the channel is repeated three times.
    - 4 channels: the first three are alpha-blended over white using the
      fourth channel as alpha.

    Asserts that the dtype is uint8 and that the channel count is 1, 3 or 4.
    '''
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    _, _, channels = x.shape
    assert channels in (1, 3, 4)

    if channels == 3:
        return x
    if channels == 1:
        return np.concatenate((x,) * 3, axis=2)

    # channels == 4: composite onto a white background.
    rgb = x[:, :, 0:3].astype(np.float32)
    opacity = x[:, :, 3:4].astype(np.float32) / 255.0
    blended = rgb * opacity + 255.0 * (1.0 - opacity)
    return blended.clip(0, 255).astype(np.uint8)
def upscale_image(input_image, upscale, min_size=None, unit_resolution=64):
    '''
    Resize an H x W x C image array by `upscale`.

    If `min_size` is given and the shorter scaled side is below it, the size
    is scaled up further so that the shorter side reaches `min_size`. Both
    sides are then snapped to the nearest multiple of `unit_resolution`.

    Returns the resized image as a uint8 array.
    '''
    rows, cols, _ = input_image.shape
    target_h = float(rows) * upscale
    target_w = float(cols) * upscale

    if min_size is not None and min(target_h, target_w) < min_size:
        boost = min_size / min(target_w, target_h)
        target_w *= boost
        target_h *= boost

    target_h = int(np.round(target_h / unit_resolution)) * unit_resolution
    target_w = int(np.round(target_w / unit_resolution)) * unit_resolution

    # NOTE(review): the filter is chosen from `upscale` alone, so an
    # enlargement forced only by `min_size` still uses INTER_AREA.
    method = cv2.INTER_LANCZOS4 if upscale > 1 else cv2.INTER_AREA
    resized = cv2.resize(input_image, (target_w, target_h), interpolation=method)
    return resized.round().clip(0, 255).astype(np.uint8)
def fix_resize(input_image, size=512, unit_resolution=64):
    '''
    Resize an H x W x C image array so that its shorter side becomes `size`.

    Both sides are snapped to the nearest multiple of `unit_resolution`
    before resizing. Lanczos interpolation is used when enlarging (scale
    factor > 1), area interpolation otherwise.

    Returns the resized image as a uint8 array.
    '''
    rows, cols, _ = input_image.shape
    src_h = float(rows)
    src_w = float(cols)
    upscale = size / min(src_h, src_w)

    target_h = int(np.round(src_h * upscale / unit_resolution)) * unit_resolution
    target_w = int(np.round(src_w * upscale / unit_resolution)) * unit_resolution

    method = cv2.INTER_LANCZOS4 if upscale > 1 else cv2.INTER_AREA
    resized = cv2.resize(input_image, (target_w, target_h), interpolation=method)
    return resized.round().clip(0, 255).astype(np.uint8)
def Numpy2Tensor(img):
    '''
    np.array[H, W, C] in [0, 255] -> float32 Tensor[C, H, W] in [-1, 1]

    Values are mapped with v / 255 * 2 - 1; the channel axis is moved first.
    '''
    scaled = np.array(img) / 255 * 2 - 1
    hwc = torch.tensor(scaled, dtype=torch.float32)
    return hwc.permute(2, 0, 1)
def Tensor2Numpy(x, h0=None, w0=None):
    '''
    Tensor[C, H, W] in [-1, 1] -> uint8 np.array[H, W, C] in [0, 255]

    When both h0 and w0 are given, the tensor is first resized to (h0, w0)
    with bicubic interpolation. Values are mapped with x * 127.5 + 127.5,
    clipped to [0, 255] and cast to uint8.
    '''
    if not (h0 is None or w0 is None):
        # interpolate works on batched input, so add and remove a batch axis.
        batched = x.unsqueeze(0)
        x = interpolate(batched, size=(h0, w0), mode='bicubic').squeeze(0)
    hwc = x.permute(1, 2, 0)
    pixels = (hwc * 127.5 + 127.5).cpu().numpy()
    return pixels.clip(0, 255).astype(np.uint8)
def convert_dtype(dtype_str):
    '''
    Map a precision tag to the matching torch dtype.

    'fp32' -> torch.float32, 'fp16' -> torch.float16, 'bf16' -> torch.bfloat16.
    Any other value raises NotImplementedError.
    '''
    known = (
        ('fp32', torch.float32),
        ('fp16', torch.float16),
        ('bf16', torch.bfloat16),
    )
    for tag, dtype in known:
        if dtype_str == tag:
            return dtype
    raise NotImplementedError