# File: LivePortrait/stf/convert.py
# Provenance: uploaded by yerang in commit 36cb39e ("Upload 636 files"), 376 bytes.
import torch
import numpy as np
def convert(src: str = "mnist_cnn.pt", dst: str = "mnist.npz") -> None:
    """Convert a saved PyTorch state dict into a NumPy ``.npz`` archive.

    Every entry of the state dict is moved to the CPU, converted to a NumPy
    array and stored in the archive under its original key. The key and shape
    of each array are printed as a quick sanity check.

    Args:
        src: Path of the checkpoint produced by ``torch.save``. It must hold a
            mapping from names to tensors.
        dst: Path of the archive to write. ``np.savez`` appends ``.npz`` when
            the name does not already end with it.

    Note:
        ``torch.load`` unpickles the file, so only run this on checkpoints
        from a trusted source.
    """
    # map_location="cpu" lets checkpoints saved from a GPU load on hosts
    # without CUDA; the tensors are needed on the CPU for NumPy anyway.
    state_dict = torch.load(src, map_location="cpu")

    # detach() is required because .numpy() refuses tensors that still track
    # gradients; it is a no-op for plain tensors.
    arrays = {
        name: weights.detach().cpu().numpy()
        for name, weights in state_dict.items()
    }

    for name, array in arrays.items():
        print(name, array.shape)

    np.savez(dst, **arrays)
def main():
    """Script entry point: run the checkpoint conversion once."""
    convert()
# Run the conversion only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()