import sys
import os.path as osp
import torch
import unittest
import basetest
from greedrl import Solver
from greedrl.const import *

# Add the parent of this file's directory to the import path; the
# ``examples`` package imported on the next line is resolved through it,
# so this statement must stay ahead of that import.
sys.path.append(osp.join(osp.dirname(osp.abspath(__file__)), "../"))
from examples.cvrp import cvrp


class TestSolver(basetest.TestCase):
    """Exercise ``Solver.train`` and ``Solver.solve`` on a CVRP example problem."""

    def test(self):
        """Train briefly in three configurations, then check two solves.

        The checks performed are:

        * every sequence returned by the first ``solve`` call has
          ``GRL_FINISH`` at position ``[:, -1, 0]``;
        * after storing that result (minus its last step) on the problem's
          ``solution`` attribute, a second ``solve`` call returns an identical
          ``worker_task_sequence``.
        """
        problem_list = cvrp.make_problem(1)
        nn_args = {'decode_rnn': 'GRU'}
        solver = Solver(None, nn_args)

        # The three training runs share these settings and differ only in the
        # single extra keyword listed for each run (none for the first).
        shared_kwargs = dict(batch_size=32, max_steps=5, memopt=10)
        for extra_kwargs in ({}, {'topk_size': 10}, {'on_policy': False}):
            solver.train(None, problem_list, problem_list,
                         **shared_kwargs, **extra_kwargs)

        # NOTE(review): ``self.assertTrue`` relies on ``basetest.TestCase``
        # deriving from ``unittest.TestCase`` -- ``unittest.main()`` below only
        # collects such subclasses. Unlike a bare ``assert``, these checks are
        # not compiled out when Python runs with ``-O``.
        solution = solver.solve(problem_list[0], batch_size=8)
        last_step = solution.worker_task_sequence[:, -1, 0]
        self.assertTrue(
            torch.all(last_step == GRL_FINISH),
            "last step of every sequence should be GRL_FINISH",
        )

        # Hand the first result, without its final step, back to the problem
        # before solving again.
        problem_list[0].solution = solution.worker_task_sequence[:, 0:-1, :]

        solution2 = solver.solve(problem_list[0], batch_size=1)
        self.assertTrue(
            torch.all(solution.worker_task_sequence == solution2.worker_task_sequence),
            "second solve should reproduce the first worker_task_sequence",
        )


if __name__ == '__main__':
    unittest.main()