File size: 950 Bytes
7900c16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import argparse
import sys
import os
import torch

# Append the parent of this script's directory to the module search path.
# NOTE(review): presumably this lets the project package be imported when the
# script is run directly from its own folder -- confirm against the repo layout.
tencentpretrain_dir = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
sys.path.append(tencentpretrain_dir)


def average(model_list_path):
    """Average several saved checkpoints element-wise.

    Every path is loaded with ``torch.load`` and must yield a dict-like
    mapping from names to values (typically a state dict). The first
    checkpoint serves as the accumulator and is updated in place as a running
    mean, so at most two checkpoints are held in memory at any time.

    Args:
        model_list_path: Iterable of checkpoint file paths.

    Returns:
        The mapping loaded from the first path, in which every floating-point
        (or complex) tensor has been replaced by the mean of that entry over
        all checkpoints. Entries that cannot be averaged in place (integer or
        bool tensors, non-tensor values) keep the first checkpoint's value.

    Raises:
        ValueError: If ``model_list_path`` yields no paths.
        KeyError: If a later checkpoint lacks an averaged key that is present
            in the first checkpoint.
    """
    avg_model = None
    for i, model_path in enumerate(model_list_path):
        # Load onto the CPU so that checkpoints saved from different devices
        # can be combined in place and no accelerator is needed to average.
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        model = torch.load(model_path, map_location="cpu")
        if avg_model is None:
            avg_model = model
            continue

        for k, avg_value in avg_model.items():
            # In-place true division is undefined for integer/bool tensors
            # (the float result cannot be cast back), so such entries -- and
            # any non-tensor values -- are left as in the first checkpoint.
            if not torch.is_tensor(avg_value):
                continue
            if not (avg_value.is_floating_point() or avg_value.is_complex()):
                continue
            if k not in model:
                raise KeyError(
                    "Key {!r} is missing from checkpoint {!r}.".format(k, model_path)
                )
            # Running mean: avg_i = (avg_{i-1} * i + x_i) / (i + 1).
            avg_value.mul_(i).add_(model[k]).div_(i + 1)

    if avg_model is None:
        raise ValueError("model_list_path must contain at least one checkpoint path.")

    return avg_model


if __name__ == '__main__':
    # Command-line entry point: average the checkpoints given through
    # --model_list_path and write the result to --output_model_path.
    cli = argparse.ArgumentParser(description="")
    cli.add_argument(
        "--model_list_path",
        nargs="+",
        required=True,
        help="Path of the input model list.",
    )
    cli.add_argument(
        "--output_model_path",
        required=True,
        help="Path of the output model.",
    )
    options = cli.parse_args()

    torch.save(average(options.model_list_path), options.output_model_path)