|
|
|
import argparse |
|
import torch |
|
|
|
|
|
def _cast_to_fp32(weights):
    """Replace every tensor in *weights* with its float32 copy, in place."""
    # Assign back into the same mapping so its type and key order are kept.
    for name in list(weights):
        weights[name] = weights[name].float()


def _fold_into_mean(running, incoming, seen):
    """Fold one more state dict into a running mean, in place.

    *running* holds the mean of the first *seen* checkpoints; after the call
    it holds the mean of ``seen + 1`` checkpoints. Every key of *running*
    must be present in *incoming* (``KeyError`` otherwise).
    """
    for name, mean in running.items():
        mean.mul_(seen).add_(incoming[name])
        if mean.is_floating_point() or mean.is_complex():
            mean.div_(seen + 1)
        else:
            # In-place true division raises on integer tensors (the float
            # result cannot be stored back), so integer entries such as
            # step counters are averaged with floor division instead.
            mean.floor_divide_(seen + 1)


def average_models(model_files, fp32=False):
    """Average the parameters of several saved checkpoints.

    Each file is loaded onto the CPU with ``torch.load`` and must contain
    ``"model"`` and ``"generator"`` state dicts. ``"vocab"`` and ``"opt"``
    are taken from the first file. The average is computed as a running
    mean that updates the first checkpoint's tensors in place.

    Args:
        model_files: Iterable of checkpoint paths (anything ``torch.load``
            accepts). Later checkpoints must provide every key found in
            the first checkpoint's ``"model"`` and ``"generator"`` dicts.
        fp32: If true, cast every tensor to float32 before averaging.

    Returns:
        A dict with the keys ``"vocab"``, ``"opt"``, ``"optim"`` (always
        ``None``), ``"generator"`` and ``"model"``.

    Raises:
        ValueError: If *model_files* yields no checkpoint at all.
        KeyError: If a later checkpoint lacks a key of the first one.
    """
    vocab = None
    opt = None
    avg_model = None
    avg_generator = None
    checkpoint = None

    for i, model_file in enumerate(model_files):
        # NOTE: torch.load unpickles the file; only feed it trusted checkpoints.
        checkpoint = torch.load(model_file, map_location="cpu")
        model_weights = checkpoint["model"]
        generator_weights = checkpoint["generator"]

        if fp32:
            _cast_to_fp32(model_weights)
            _cast_to_fp32(generator_weights)

        if i == 0:
            # The first checkpoint seeds the result and supplies the metadata.
            vocab, opt = checkpoint["vocab"], checkpoint["opt"]
            avg_model = model_weights
            avg_generator = generator_weights
        else:
            _fold_into_mean(avg_model, model_weights, i)
            _fold_into_mean(avg_generator, generator_weights, i)

    if checkpoint is None:
        # Without this guard an empty input would produce a checkpoint whose
        # fields are all None.
        raise ValueError("average_models needs at least one checkpoint file")

    return {
        "vocab": vocab,
        "opt": opt,
        "optim": None,
        "generator": avg_generator,
        "model": avg_model,
    }
|
|
|
|
|
def main():
    """Parse the command line, average the listed checkpoints and save them.

    Options: ``-models``/``-m`` (one or more checkpoint files, required),
    ``-output``/``-o`` (destination file, required) and ``-fp32``/``-f``
    (flag passed to ``average_models`` as its ``fp32`` argument).
    """
    cli = argparse.ArgumentParser(description="")
    cli.add_argument("-models", "-m", nargs="+", required=True, help="List of models")
    cli.add_argument("-output", "-o", required=True, help="Output file")
    cli.add_argument("-fp32", "-f", action="store_true", help="Cast params to float32")
    args = cli.parse_args()

    # Average first, then write the combined checkpoint to the output path.
    torch.save(average_models(args.models, args.fp32), args.output)
|
|
|
|
|
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
|
|