# GreedRL / examples/cvrp/train.py
# Author: 先坤 — commit db26c81 ("add greedrl")
# (source-view metadata: raw | history blame | 2.7 kB)
import argparse
import math
import os
import sys

import torch.distributed as dist
import torch.multiprocessing as mp
from greedrl import Solver

import utils
# Number of validation instances generated for each supported problem size.
_VALID_INSTANCES = {
    100: 10000,
    1000: 200,
    2000: 100,
    5000: 10,
}


def _log_filename(model_filename, world_size, rank):
    """Return the log path for this process, derived from *model_filename*.

    The file extension of *model_filename* is replaced by ``.log`` (single
    process) or ``_r<rank>.log`` (multiple processes). A name without an
    extension keeps its full stem, and dots in directory names are ignored.
    """
    stem, _ = os.path.splitext(model_filename)
    if world_size > 1:
        return '{}_r{}.log'.format(stem, rank)
    return '{}.log'.format(stem)


def do_train(args, rank):
    """Train a GreedRL solver on CVRP instances in one (possibly distributed) process.

    All stdout/stderr output of this process is redirected to a log file
    named after ``args.model_filename``.

    Args:
        args: parsed command-line namespace providing ``world_size``,
            ``model_filename``, ``problem_size`` and ``batch_size``.
        rank: index of this process among ``args.world_size`` processes.

    Raises:
        ValueError: if ``args.problem_size`` is not one of the supported sizes.
    """
    world_size = args.world_size
    model_filename = args.model_filename
    problem_size = args.problem_size
    batch_size = args.batch_size

    # Each process logs to its own file. Line buffering keeps the log current
    # while a long training run is in progress. The file stays open for the
    # lifetime of the process because it replaces sys.stdout/sys.stderr.
    stdout = open(_log_filename(model_filename, world_size, rank), 'a',
                  buffering=1, encoding='utf-8')
    sys.stdout = stdout
    sys.stderr = stdout
    print("args: {}".format(vars(args)))

    # Validate before joining the process group so a bad size fails fast
    # (all ranks share the same args, so every rank raises identically).
    if problem_size not in _VALID_INSTANCES:
        raise ValueError("unsupported problem size: {}".format(problem_size))

    problem_batch_size = 8
    batch_count = math.ceil(_VALID_INSTANCES[problem_size] / problem_batch_size)

    if world_size > 1:
        dist.init_process_group('NCCL', init_method='tcp://127.0.0.1:29500',
                                rank=rank, world_size=world_size)
    try:
        nn_args = {
            'encode_norm': 'instance',
            'encode_layers': 6,
            'decode_rnn': 'LSTM'
        }
        # Single process: device None (presumably Solver picks its default
        # device — confirm against greedrl.Solver); otherwise one GPU per rank.
        device = None if world_size == 1 else 'cuda:{}'.format(rank)
        solver = Solver(device, nn_args)

        # NOTE(review): a None batch count presumably means an unbounded
        # stream of generated training instances — confirm in utils.Dataset.
        train_dataset = utils.Dataset(None, problem_batch_size, problem_size)
        valid_dataset = utils.Dataset(batch_count, problem_batch_size, problem_size)
        solver.train(model_filename, train_dataset, valid_dataset,
                     train_dataset_workers=5,
                     batch_size=batch_size,
                     memopt=10,
                     topk_size=1,
                     init_lr=1e-4,
                     valid_steps=500,
                     warmup_steps=0)
    finally:
        if world_size > 1:
            dist.destroy_process_group()
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    # Required: do_train derives the log file name from it and would fail on None.
    parser.add_argument('--model_filename', type=str, required=True, help='model file name')
    parser.add_argument('--problem_size', default=100, type=int, choices=[100, 1000, 2000, 5000], help='problem size')
    parser.add_argument('--batch_size', default=128, type=int, help='batch size for training')
    args = parser.parse_args()
    if args.world_size < 1:
        # Otherwise no worker is started and the script silently does nothing.
        parser.error('--world_size must be at least 1')

    # Start one worker process per rank and wait for all of them.
    processes = []
    for rank in range(args.world_size):
        p = mp.Process(target=do_train, args=(args, rank))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

    # Workers write their output to per-rank log files; surface any failure
    # to the caller (shell, scheduler) through a non-zero exit status.
    if any(p.exitcode != 0 for p in processes):
        sys.exit(1)