APISR / tools /clean_weight_info.py
HikariDawn's picture
feat: initial push
561c629
raw
history blame contribute delete
871 Bytes
'''
Clean uncessary information in the weight (*.pth)
'''
import torch
if __name__ == "__main__":
weight_path = "saved_models/esrgan_best_generator.pth"
store_path = "1x_APISR_RRDB_GAN_generator.pth"
# Load the checkpoint
checkpoint_g = torch.load(weight_path)
keys = []
for key in checkpoint_g:
keys.append(key)
print(key)
for key in keys:
if key != "model_state_dict":
del checkpoint_g[key]
# Access the weight
old_keys = [key for key in checkpoint_g['model_state_dict']]
for old_key in old_keys:
if old_key[:10] == "_orig_mod.":
new_key = old_key[10:]
checkpoint_g['model_state_dict'][new_key] = checkpoint_g['model_state_dict'][old_key]
del checkpoint_g['model_state_dict'][old_key]
torch.save(checkpoint_g, store_path)