|
from pathlib import Path |
|
from omegaconf import OmegaConf |
|
import torch |
|
from ldm.util import instantiate_from_config |
|
import logging |
|
from contextlib import contextmanager |
|
|
|
from contextlib import contextmanager |
|
import logging |
|
|
|
@contextmanager
def all_logging_disabled(highest_level=logging.CRITICAL):
    """Suppress logging records at or below ``highest_level`` inside a ``with`` body.

    The current process-wide disable threshold is remembered, replaced with
    ``highest_level`` while the body runs, and restored on exit -- also when
    the body raises.

    :param highest_level: threshold passed to :func:`logging.disable`. It only
        needs changing if a custom level greater than ``CRITICAL`` is defined.

    Adapted from https://gist.github.com/simon-weber/7853144
    """
    # The root logger's manager stores the global threshold set by logging.disable.
    saved_threshold = logging.getLogger().manager.disable
    logging.disable(highest_level)
    try:
        yield
    finally:
        # Restore whatever threshold was active before entering the block.
        logging.disable(saved_threshold)
|
|
|
def load_training_dir(train_dir, device, epoch="last"):
    """Load a checkpoint and config from a training directory and build the model.

    Searches ``train_dir`` recursively for exactly one file matching
    ``*{epoch}.ckpt`` and for at least one file matching ``*-project.yaml``.
    If several configs match, the last one in sorted path order is used.

    :param train_dir: training output directory (``str`` or ``Path``).
    :param device: device passed on to :func:`load_model_from_config`.
    :param epoch: filename suffix used to pick the checkpoint; the default
        ``"last"`` matches ``*last.ckpt``.
    :return: the model returned by :func:`load_model_from_config`.
    :raises AssertionError: if the number of matching checkpoints is not
        exactly one, or if no config file is found. (These are ``assert``
        statements, so they are skipped when Python runs with ``-O``.)
    """
    train_dir = Path(train_dir)

    ckpt = list(train_dir.rglob(f"*{epoch}.ckpt"))
    assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files"

    config = list(train_dir.rglob("*-project.yaml"))
    # Check ``config`` (not ``ckpt``) so a missing config fails here with a
    # clear message rather than as an IndexError at ``config[0]`` below.
    assert len(config) > 0, f"didn't find any config in {train_dir}"
    if len(config) > 1:
        print(f"found {len(config)} matching config files")
        # Presumably config names carry a sortable timestamp prefix so the
        # last one is the newest -- TODO confirm against the writer.
        config = sorted(config)[-1]
        print(f"selecting {config}")
    else:
        config = config[0]

    config = OmegaConf.load(config)
    return load_model_from_config(config, ckpt[0], device)
|
|
|
def load_model_from_config(config, ckpt, device="cpu", verbose=False):
    """Instantiate a model from ``config`` and load weights from ``ckpt``.

    :param config: config object with a ``model`` entry for
        ``instantiate_from_config``; if a ``str`` or ``Path`` is given it is
        loaded with ``OmegaConf.load`` first.
    :param ckpt: checkpoint path readable by ``torch.load``; the loaded dict
        must contain a ``"state_dict"`` entry.
    :param device: device the model is moved to after loading.
    :param verbose: when true, print the checkpoint's ``global_step`` (if
        present) and any missing / unexpected state-dict keys.
    :return: the model, switched to eval mode and moved to ``device``.
    """
    if isinstance(config, (str, Path)):
        config = OmegaConf.load(config)

    # Logging is silenced while loading; print() output is not affected.
    with all_logging_disabled():
        print(f"Loading model from {ckpt}")
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        pl_sd = torch.load(ckpt, map_location="cpu")
        sd = pl_sd["state_dict"]
        model = instantiate_from_config(config.model)
        # strict=False returns the key mismatches instead of raising.
        missing, unexpected = model.load_state_dict(sd, strict=False)
        if verbose:
            global_step = pl_sd.get("global_step")
            if global_step is not None:
                print(f"Global step: {global_step}")
            if missing:
                print("missing keys:")
                print(missing)
            if unexpected:
                print("unexpected keys:")
                print(unexpected)
        model.to(device)
        model.eval()
        # Tolerate models whose cond_stage_model is absent or None instead of
        # raising AttributeError after the weights are already loaded.
        cond_model = getattr(model, "cond_stage_model", None)
        if cond_model is not None:
            cond_model.device = device
        return model