|
import sys |
|
import os.path as osp |
|
import torch |
|
import unittest |
|
|
|
import basetest |
|
from greedrl import Solver |
|
from greedrl.const import * |
|
|
|
sys.path.append(osp.join(osp.dirname(osp.abspath(__file__)), "../")) |
|
from examples.cvrp import cvrp |
|
|
|
|
|
class TestSolver(basetest.TestCase): |
|
def test(self): |
|
problem_list = cvrp.make_problem(1) |
|
|
|
nn_args = {} |
|
nn_args['decode_rnn'] = 'GRU' |
|
solver = Solver(None, nn_args) |
|
|
|
solver.train(None, problem_list, problem_list, |
|
batch_size=32, max_steps=5, memopt=10) |
|
|
|
solver.train(None, problem_list, problem_list, |
|
batch_size=32, max_steps=5, memopt=10, topk_size=10) |
|
|
|
solver.train(None, problem_list, problem_list, |
|
batch_size=32, max_steps=5, memopt=10, on_policy=False) |
|
|
|
solution = solver.solve(problem_list[0], batch_size=8) |
|
assert torch.all(solution.worker_task_sequence[:, -1, 0] == GRL_FINISH) |
|
problem_list[0].solution = solution.worker_task_sequence[:, 0:-1, :] |
|
|
|
solution2 = solver.solve(problem_list[0], batch_size=1) |
|
assert torch.all(solution.worker_task_sequence == solution2.worker_task_sequence) |
|
|
|
|
|
if __name__ == '__main__': |
|
unittest.main() |
|
|