from pathlib import Path
import argparse

import yaml
import numpy as np
import torch

from ncsn.ncsnv2 import NCSNv2, NCSNv2Deeper, NCSNv2Deepest, get_sigmas
from ncsn.ema import EMAHelper
from adapt import ScoreAdapter

device = torch.device("cuda")

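
# Note: get_model builds the NCSNv2 variant that matches the dataset named in
# the config. The bedroom.yml config loaded further below is an LSUN config,
# so the NCSNv2Deeper branch is the one this adapter actually exercises.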
def get_model(config):
    if config.data.dataset == 'CIFAR10' or config.data.dataset == 'CELEBA':
        return NCSNv2(config).to(config.device)
    elif config.data.dataset == "FFHQ":
        return NCSNv2Deepest(config).to(config.device)
    elif config.data.dataset == 'LSUN':
        return NCSNv2Deeper(config).to(config.device)
    else:
        raise NotImplementedError(f"unknown dataset: {config.data.dataset}")

def dict2namespace(config):
    namespace = argparse.Namespace()
    for key, value in config.items():
        if isinstance(value, dict):
            new_value = dict2namespace(value)
        else:
            new_value = value
        setattr(namespace, key, new_value)
    return namespace
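
# A quick illustration of dict2namespace (hypothetical values): a nested dict
#   {"model": {"ema": True, "ema_rate": 0.999}}
# becomes a Namespace tree accessed as config.model.ema / config.model.ema_rate,
# which is how the loaded YAML config is consumed below.
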

class NCSN(ScoreAdapter):
    def __init__(self):
        config_fname = Path(__file__).resolve().parent / "ncsn" / "bedroom.yml"
        with config_fname.open("r") as f:
            config = yaml.safe_load(f)
        config = dict2namespace(config)
        config.device = device

        states = torch.load(
            self.checkpoint_root() / "ncsn/exp/logs/bedroom/checkpoint_150000.pth"
        )
        model = get_model(config)
        model = torch.nn.DataParallel(model)
        model.load_state_dict(states[0], strict=True)

        if config.model.ema:
            ema_helper = EMAHelper(mu=config.model.ema_rate)
            ema_helper.register(model)
            ema_helper.load_state_dict(states[-1])
            # HC: update the model parameters with the historical EMA weights;
            # if we don't do this, the colors of the generated images become
            # strangely saturated. This is reported in the paper.
            ema_helper.ema(model)

        model = model.module  # remove DataParallel
        model.eval()

        self.model = model
        self._data_shape = (3, config.data.image_size, config.data.image_size)
        self.σs = model.sigmas.cpu().numpy()
        self._device = device

    def data_shape(self):
        return self._data_shape

    def samps_centered(self):
        return False

    def σ_max(self):
        return self.σs[0]

    def σ_min(self):
        return self.σs[-1]
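
    # denoise applies the single-step denoising identity for a score model:
    # with the score s(x, σ) ≈ ∇_x log p_σ(x), the denoised estimate is
    # D(x, σ) = x + σ² · s(x, σ) (Tweedie's formula for Gaussian noise).
    # σ is first snapped to the nearest of the model's discrete noise levels so
    # the corresponding integer index can be fed to the network as conditioning.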
    def denoise(self, xs, σ):
        σ, j = self.snap_t_to_nearest_tick(σ)
        N = xs.shape[0]
        cond_t = torch.tensor([j] * N, dtype=torch.long, device=self.device)
        score = self.model(xs, cond_t)
        Ds = xs + score * (σ ** 2)
        return Ds

    def unet_is_cond(self):
        return False

    def use_cls_guidance(self):
        return False

    def snap_t_to_nearest_tick(self, t):
        j = np.abs(t - self.σs).argmin()
        return self.σs[j], j
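
# A minimal usage sketch (hedged: assumes the checkpoint referenced in
# __init__ exists under checkpoint_root(), a CUDA device is available, and
# `noisy` is a hypothetical (N, 3, H, W) tensor already on that device):
#
#     adapter = NCSN()
#     sigma = adapter.σ_max()
#     with torch.no_grad():
#         denoised = adapter.denoise(noisy, sigma)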