import sys
import math
import argparse

import torch.distributed as dist
import torch.multiprocessing as mp

import utils
from greedrl import Solver

# Number of validation instances generated for each supported problem size;
# each is consumed in batches of the per-problem batch size used below.
VALID_INSTANCE_COUNT = {100: 10000, 1000: 200, 2000: 100, 5000: 10}


def do_train(args, rank):
    """Run training for one (possibly distributed) worker process.

    Args:
        args: parsed command-line namespace carrying ``world_size``,
            ``model_filename``, ``problem_size`` and ``batch_size``.
        rank: integer rank of this worker in ``[0, world_size)``.

    Raises:
        ValueError: if ``args.problem_size`` is not a supported size.
    """
    world_size = args.world_size
    model_filename = args.model_filename
    problem_size = args.problem_size
    batch_size = args.batch_size

    # Redirect this worker's stdout/stderr into a per-rank log file derived
    # from the model filename with its extension stripped. Guard the no-dot
    # case: rfind() returns -1 there, and slicing with [0:-1] would silently
    # drop the last character of the name (bug in the original).
    index = model_filename.rfind('.')
    basename = model_filename if index == -1 else model_filename[0:index]
    if world_size > 1:
        stdout_filename = '{}_r{}.log'.format(basename, rank)
    else:
        stdout_filename = '{}.log'.format(basename)
    # Deliberately kept open for the lifetime of the process; the OS closes
    # it on exit, and closing earlier would cut off training output.
    stdout = open(stdout_filename, 'a')
    sys.stdout = stdout
    sys.stderr = stdout

    print("args: {}".format(vars(args)))

    if world_size > 1:
        # NCCL backend; all ranks rendezvous on a fixed local TCP endpoint.
        dist.init_process_group('NCCL',
                                init_method='tcp://127.0.0.1:29500',
                                rank=rank,
                                world_size=world_size)

    problem_batch_size = 8
    if problem_size not in VALID_INSTANCE_COUNT:
        # ValueError (not bare Exception) so callers can catch precisely.
        raise ValueError("unsupported problem size: {}".format(problem_size))
    batch_count = math.ceil(VALID_INSTANCE_COUNT[problem_size] / problem_batch_size)

    nn_args = {
        'encode_norm': 'instance',
        'encode_layers': 6,
        'decode_rnn': 'LSTM'
    }

    # Single-process runs let the Solver choose its own device; distributed
    # runs pin each worker to the GPU matching its rank.
    device = None if world_size == 1 else 'cuda:{}'.format(rank)
    solver = Solver(device, nn_args)

    # Training data is open-ended (batch count None); validation is a fixed
    # number of batches computed above.
    train_dataset = utils.Dataset(None, problem_batch_size, problem_size)
    valid_dataset = utils.Dataset(batch_count, problem_batch_size, problem_size)

    solver.train(model_filename, train_dataset, valid_dataset,
                 train_dataset_workers=5, batch_size=batch_size, memopt=10,
                 topk_size=1, init_lr=1e-4, valid_steps=500, warmup_steps=0)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    # required: do_train derives the log file name from this path, so a
    # missing value would otherwise crash with AttributeError in the child.
    parser.add_argument('--model_filename', type=str, required=True,
                        help='model file name')
    parser.add_argument('--problem_size', default=100, type=int,
                        choices=[100, 1000, 2000, 5000],
                        help='problem size')
    parser.add_argument('--batch_size', default=128, type=int,
                        help='batch size for training')
    args = parser.parse_args()

    # Launch one training process per rank, then wait for all of them.
    processes = []
    for rank in range(args.world_size):
        p = mp.Process(target=do_train, args=(args, rank))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()