import os
import torch
from collections import OrderedDict


def extract(ckpt, *, exclude=("enc_q",)):
    """Build a stripped-down weight dict from a checkpoint.

    Copies every entry of ``ckpt["model"]`` into a new dict, skipping any
    entry whose key contains one of the ``exclude`` substrings (plain
    substring match, not a prefix match).

    NOTE(review): "enc_q" presumably names a sub-module that is only needed
    during training; confirm before changing the default.

    Args:
        ckpt: Mapping with a ``"model"`` entry that is itself a mapping of
            parameter name -> value (presumably a state dict of tensors;
            TODO confirm against callers).
        exclude: Substrings; an entry whose key contains any of them is
            dropped. Defaults to ``("enc_q",)``, matching the original
            behaviour. A bare string is treated as a single substring.

    Returns:
        An ``OrderedDict`` with a single ``"weight"`` key mapping to a new
        dict of the kept entries, in their original iteration order. Values
        are not copied; they are the same objects found in ``ckpt["model"]``.

    Raises:
        KeyError: if ``ckpt`` has no ``"model"`` entry.
    """
    if isinstance(exclude, str):
        # Avoid iterating a lone string character by character.
        exclude = (exclude,)
    state = ckpt["model"]
    opt = OrderedDict()
    opt["weight"] = {
        key: value
        for key, value in state.items()
        if not any(pattern in key for pattern in exclude)
    }
    return opt