Spaces:
Runtime error
Runtime error
File size: 1,184 Bytes
4409449 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import torch
def load_pretrained(cfg, model, logger, phase="train"):
    """Load a full model checkpoint into ``model`` and return it.

    The checkpoint path is selected by ``phase``: ``cfg.TRAIN.PRETRAINED``
    for ``"train"`` and ``cfg.TEST.CHECKPOINTS`` for ``"test"``. The file is
    read with ``torch.load`` onto the CPU, and its ``"state_dict"`` entry is
    loaded with ``strict=True``, so checkpoint keys must match the model's
    keys exactly.

    Args:
        cfg: Config object exposing ``TRAIN.PRETRAINED`` (train phase) and/or
            ``TEST.CHECKPOINTS`` (test phase).
        model: Module whose weights are loaded in place.
        logger: Object with an ``info`` method; used to report the path.
        phase: Either ``"train"`` or ``"test"``.

    Returns:
        The same ``model`` object, with the checkpoint weights loaded.

    Raises:
        ValueError: If ``phase`` is neither ``"train"`` nor ``"test"``.
    """
    if phase == "train":
        ckpt_path = cfg.TRAIN.PRETRAINED
    elif phase == "test":
        ckpt_path = cfg.TEST.CHECKPOINTS
    else:
        # Without this branch ckpt_path would be unbound and the failure
        # would surface as an opaque UnboundLocalError.
        raise ValueError(
            f"Unknown phase {phase!r}; expected 'train' or 'test'")
    # Report the path that is actually loaded; in the test phase this is
    # TEST.CHECKPOINTS, not TRAIN.PRETRAINED.
    logger.info(f"Loading pretrain model from {ckpt_path}")
    state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    model.load_state_dict(state_dict, strict=True)
    return model
def load_pretrained_vae(cfg, model, logger):
    """Load VAE weights from ``cfg.TRAIN.PRETRAINED_VAE`` into ``model``.

    The checkpoint is read onto the CPU and its ``"state_dict"`` entry is
    filtered:

    * keys containing ``"motion_vae"`` are kept, with every ``"motion_vae."``
      substring removed;
    * otherwise, keys containing ``"vae"`` are kept, with every ``"vae."``
      substring removed;
    * all other keys are dropped.

    The filtered dict is loaded with ``strict=True`` into ``model.vae`` when
    the model has that attribute, otherwise into ``model.motion_vae``.

    Args:
        cfg: Config object exposing ``TRAIN.PRETRAINED_VAE``.
        logger: Object with an ``info`` method; used to report the path.
        model: Object holding the VAE submodule to load into.

    Returns:
        The same ``model`` object, with VAE weights loaded.
    """
    from collections import OrderedDict

    def _rename(key):
        # Check "motion_vae" first: such keys also contain "vae".
        if "motion_vae" in key:
            return key.replace("motion_vae.", "")
        if "vae" in key:
            return key.replace("vae.", "")
        return None  # not a VAE parameter; caller drops it

    vae_path = cfg.TRAIN.PRETRAINED_VAE
    checkpoint = torch.load(vae_path, map_location="cpu")
    full_state = checkpoint['state_dict']
    logger.info(f"Loading pretrain vae from {vae_path}")

    vae_dict = OrderedDict()
    for key, tensor in full_state.items():
        new_key = _rename(key)
        if new_key is not None:
            vae_dict[new_key] = tensor

    target = model.vae if hasattr(model, 'vae') else model.motion_vae
    target.load_state_dict(vae_dict, strict=True)
    return model
|