import gc | |
from map_from_diffusers import convert_diffusers_to_flux_checkpoint | |
from safetensors.torch import load_file, save_file | |
import sys | |
import torch | |
### | |
# Code from huggingface/twodgirl | |
# License: apache-2.0 | |
if __name__ == '__main__':
    # Command-line entry point: load a safetensors checkpoint, remap it with
    # convert_diffusers_to_flux_checkpoint (presumably Diffusers layout ->
    # original Flux layout, per the helper's name), sanity-check the dtype of
    # one tensor, and save the result.
    #
    # Usage: python <this script> INPUT.safetensors OUTPUT.safetensors
    # Arguments beyond the first two are ignored, as in the original script.
    if len(sys.argv) < 3:
        sys.exit(f'usage: {sys.argv[0]} INPUT.safetensors OUTPUT.safetensors')
    src_path, dst_path = sys.argv[1], sys.argv[2]

    sd = convert_diffusers_to_flux_checkpoint(load_file(src_path))

    # Sanity check: the converted 'time_in.in_layer.weight' must be
    # float8_e4m3fn (presumably a proxy for "the whole checkpoint is fp8").
    # An explicit check replaces `assert`, which is stripped under `python -O`
    # and would let a wrong-dtype checkpoint be written silently. Failing here
    # exits with status 1 before anything is saved.
    probe_key = 'time_in.in_layer.weight'
    if probe_key not in sd:
        sys.exit(f'{probe_key!r} not found in converted state dict')
    probe_dtype = sd[probe_key].dtype
    if probe_dtype != torch.float8_e4m3fn:
        sys.exit(
            f'expected {probe_key!r} to be torch.float8_e4m3fn, '
            f'got {probe_dtype}'
        )

    # Report how many entries the converted state dict contains.
    print(len(sd))
    # Force a full collection before saving so that objects which are no
    # longer referenced are reclaimed. This includes the loaded source dict,
    # which is never bound to a name, unless the converter keeps references to it.
    gc.collect()
    save_file(sd, dst_path)