Spaces:
Running
on
T4
Running
on
T4
'''
Clean unnecessary information from the weight file (*.pth)
'''
import torch | |
def strip_key_prefix(state_dict, prefix="_orig_mod."):
    """Rename, in place, every key of ``state_dict`` that starts with ``prefix``.

    The prefix is removed from the front of each matching key; keys that do
    not start with it are left untouched. The mapping is mutated in place
    (so its concrete type is kept) and also returned for convenience.

    NOTE(review): "_orig_mod." is presumably the prefix that wrapping a model
    with ``torch.compile`` adds to parameter names -- confirm against the
    training code that produced the checkpoint.

    Args:
        state_dict: Mutable mapping from parameter name to value.
        prefix: Leading text to remove from matching keys.

    Returns:
        The same ``state_dict`` object, with the matching keys renamed.
    """
    # Snapshot the keys first: the mapping is modified while we walk it.
    for old_key in list(state_dict):
        if old_key.startswith(prefix):
            state_dict[old_key[len(prefix):]] = state_dict.pop(old_key)
    return state_dict


def clean_checkpoint(checkpoint, keep_key="model_state_dict", prefix="_orig_mod."):
    """Reduce ``checkpoint`` to its model weights only, in place.

    Every top-level entry except ``keep_key`` is deleted, then ``prefix`` is
    stripped from the parameter names stored under ``keep_key``.

    Args:
        checkpoint: Mutable mapping as loaded from the checkpoint file.
        keep_key: The only top-level entry to keep.
        prefix: Leading text to remove from the parameter names.

    Returns:
        The same ``checkpoint`` object, now holding only ``keep_key``.

    Raises:
        KeyError: If ``checkpoint`` has no ``keep_key`` entry. This is checked
            before anything is deleted, so the input is left intact.
    """
    if keep_key not in checkpoint:
        raise KeyError(
            f"checkpoint has no {keep_key!r} entry; found keys: {list(checkpoint)}"
        )

    # Snapshot the keys first: entries are deleted while we walk the mapping.
    for key in list(checkpoint):
        if key != keep_key:
            del checkpoint[key]

    strip_key_prefix(checkpoint[keep_key], prefix)
    return checkpoint


def main(weight_path="saved_models/esrgan_best_generator.pth",
         store_path="1x_APISR_RRDB_GAN_generator.pth"):
    """Load the checkpoint at ``weight_path``, clean it, and save it to ``store_path``.

    Args:
        weight_path: Path of the full training checkpoint to read.
        store_path: Path the cleaned, weights-only checkpoint is written to.
    """
    # Load the checkpoint. map_location="cpu" lets the tool run on a machine
    # without the device the checkpoint was saved from; the cleaned file is
    # therefore stored with CPU tensors.
    # NOTE: torch.load unpickles the file -- only use it on trusted checkpoints.
    checkpoint_g = torch.load(weight_path, map_location="cpu")

    # Show what the checkpoint contained before it is trimmed.
    for key in checkpoint_g:
        print(key)

    torch.save(clean_checkpoint(checkpoint_g), store_path)


if __name__ == "__main__":
    main()