import safetensors.torch
import torch
import sys

# sys.argv[1] is the output file; its name selects the optional dtype cast.
cast_to = None
if "fp8_e4m3fn" in sys.argv[1]:
    cast_to = torch.float8_e4m3fn
elif "fp16" in sys.argv[1]:
    cast_to = torch.float16
elif "bf16" in sys.argv[1]:
    cast_to = torch.bfloat16

# Key renames applied after the qkv fusion below.
replace_keys = {"all_final_layer.2-1.": "final_layer.",
                "all_x_embedder.2-1.": "x_embedder.",
                ".attention.to_out.0.bias": ".attention.out.bias",
                ".attention.norm_k.weight": ".attention.k_norm.weight",
                ".attention.norm_q.weight": ".attention.q_norm.weight",
                ".attention.to_out.0.weight": ".attention.out.weight"
                }

out_sd = {}
for f in sys.argv[2:]:
    sd = safetensors.torch.load_file(f)
    cc = None
    for k in sd:
        w = sd[k]

        if cast_to is not None:
            w = w.to(cast_to)
        k_out = k
        # The out projection bias is dropped entirely.
        if k_out.endswith(".attention.to_out.0.bias"):
            continue
        # Collect the separate q, k and v projection weights and emit a single
        # fused qkv weight once all three have been seen (concatenated in
        # q, k, v order along dim 0).
        if k_out.endswith(".attention.to_k.weight"):
            cc = [w]
            continue
        if k_out.endswith(".attention.to_q.weight"):
            cc = [w] + cc
            continue
        if k_out.endswith(".attention.to_v.weight"):
            cc = cc + [w]
            w = torch.cat(cc, dim=0)
            k_out = k_out.replace(".attention.to_v.weight", ".attention.qkv.weight")

        for r, rr in replace_keys.items():
            k_out = k_out.replace(r, rr)
        out_sd[k_out] = w

safetensors.torch.save_file(out_sd, sys.argv[1])
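
# Example invocation (a hedged usage sketch, not part of the original script;
# the script filename "convert.py" and the shard filenames are assumptions):
#
#   python convert.py model_fp8_e4m3fn.safetensors shard-00001.safetensors shard-00002.safetensors
#
# sys.argv[1] is the output path, and the dtype cast is chosen from substrings
# of that name ("fp8_e4m3fn", "fp16" or "bf16"); sys.argv[2:] are the input
# files that get merged and converted.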