File size: 1,265 Bytes
db26c81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import time
import random
import argparse
import torch

from greedrl import Problem, Solution, Solver


def _parse_args():
    """Parse the runner's command-line options, ignoring any unknown ones."""
    parser = argparse.ArgumentParser(description="")
    parser.add_argument('--device', default=None, type=str)
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--agent_file', default=None, type=str)
    parser.add_argument('--valid_steps', default=5, type=int)
    parser.add_argument('--max_steps', default=10000000, type=int)
    # parse_known_args so that extra flags meant for other tools don't abort the run.
    args, _ = parser.parse_known_args()
    return args


def run(make_problem, mask_task_ratio=0.1):
    """Train a GreedRL agent on generated problems, then time solving them.

    Args:
        make_problem: callable invoked as ``make_problem(1)``. Its result is
            used both as the training and validation dataset and is iterated
            problem by problem for the solve phase. Presumably the argument
            is a problem count; confirm against callers.
        mask_task_ratio: currently unused; kept so existing callers that
            pass it keep working.

    Command-line options (parsed with ``parse_known_args``):
        --device, --batch_size, --agent_file, --valid_steps, --max_steps.
    """
    # Fixed seeds so problem generation and training are reproducible.
    random.seed(123)
    torch.manual_seed(123)
    problem_list = make_problem(1)

    args = _parse_args()
    for k, v in vars(args).items():
        print("arg: {} = {}".format(k, v))

    # RL training: the same problems serve as train and validation sets.
    solver = Solver(device=args.device)
    solver.train(args.agent_file, problem_list, problem_list,
                 batch_size=args.batch_size, valid_steps=args.valid_steps, max_steps=args.max_steps)

    # Prediction. When an agent file was given, reload it into a fresh solver
    # so we evaluate what was persisted. Without a file there is nothing to
    # reload, so keep the in-memory trained solver; a fresh, unloaded Solver
    # would otherwise be used for solving.
    if args.agent_file is not None:
        solver = Solver(device=args.device)
        solver.load_agent(args.agent_file)

    print("solve ...")
    # perf_counter is monotonic and meant for measuring elapsed intervals.
    start = time.perf_counter()
    for problem in problem_list:
        solver.solve(problem, batch_size=args.batch_size)
    print("time: {}s".format(time.perf_counter() - start))