|
import os |
|
import sys |
|
import math |
|
import copy |
|
import time |
|
import queue |
|
import inspect |
|
import torch |
|
import numpy as np |
|
import torch.nn.functional as F |
|
import torch.distributed as dist |
|
|
|
from .agent import Agent, parse_nn_args |
|
from .utils import repeat, get_default_device, cutime_stats |
|
from .variable import TaskDemandNow |
|
|
|
from torch.nn.utils import clip_grad_norm_, parameters_to_vector, vector_to_parameters |
|
from torch.utils.data import Dataset, IterableDataset, DataLoader |
|
from torch.optim.lr_scheduler import MultiStepLR |
|
|
|
|
|
class Problem(object):
    """Container for a (possibly batched) problem instance.

    Feature tensors are stored in the ``feats`` dict (attached by
    ``Solver.to_batch``); ``__getattr__`` exposes a whitelisted subset of
    those features as plain attributes.
    """

    def __init__(self, isbatch=False):
        # True when this instance holds a batch of problems (leading batch dim).
        self.isbatch = isbatch
        self.features = []
        self.environment = None

    def pin_memory(self):
        """Pin every feature tensor for faster async host-to-GPU copies."""
        for k, v in self.feats.items():
            self.feats[k] = v.pin_memory()
        return self

    def __getattr__(self, name):
        # BUG FIX: the original test was `name not in ('solution')`, which is
        # substring membership in the *string* 'solution' — so e.g. `p.ion` or
        # `p.sol` silently resolved to feats.get(...) instead of raising.
        # A one-element tuple restricts the fallthrough to exactly 'solution'.
        if name not in ('solution',):
            raise AttributeError(name)
        return self.feats.get(name)
|
|
|
|
|
class Solution(object):
    """Result of solving a problem: objective value plus decoded assignment.

    ``cost`` is a scalar (or batched tensor) objective; ``worker_task_sequence``
    is filled in by the solver after decoding.
    """

    def __init__(self, cost=None):
        self.cost = cost
        self.worker_task_sequence = None
|
|
|
|
|
class WrapDataset(Dataset):
    """Map-style dataset whose items are pre-converted via ``solver.to_batch``.

    Conversion happens eagerly in the constructor so that ``__getitem__`` is a
    cheap list lookup.
    """

    def __init__(self, dataset, solver):
        self._dataset = [solver.to_batch(item) for item in dataset]

    def __getitem__(self, index):
        return self._dataset[index]

    def __len__(self):
        return len(self._dataset)
|
|
|
|
|
class WrapIterator:
    """Iterator adapter that converts each yielded item via ``solver.to_batch``."""

    def __init__(self, iterator, solver):
        self._iterator = iterator
        self._solver = solver

    def __next__(self):
        raw = next(self._iterator)
        # pin_memory=False here; the DataLoader built around this stream is
        # created with pin_memory=True and pins the batches itself.
        return self._solver.to_batch(raw, False)


class WrapIterableDataset(IterableDataset):
    """Iterable-style dataset that lazily converts items via ``solver.to_batch``."""

    def __init__(self, dataset, solver):
        self._dataset = dataset
        self._solver = solver

    def __iter__(self):
        return WrapIterator(iter(self._dataset), self._solver)
|
|
|
|
|
class CyclicIterator:
    """Endless iterator that restarts its source each time it is exhausted."""

    def __init__(self, iterable):
        self._source = iterable
        self._cursor = iter(iterable)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self._cursor)
        except StopIteration:
            # Exhausted: rewind by requesting a fresh iterator from the source.
            # An empty source lets this second StopIteration propagate.
            self._cursor = iter(self._source)
            return next(self._cursor)
|
|
|
|
|
class BufferedIterator:
    """Replay buffer over an iterator: refresh periodically, sample uniformly.

    A fresh item is pulled whenever the buffer is not yet full, or on every
    ``reuse``-th call once it is; each call returns a uniformly random item
    from the buffer (torch RNG, so sampling follows the global seed).
    """

    def __init__(self, iterator, size, reuse):
        self._iterator = iterator
        self._reuse = reuse
        # FIFO of buffer slot indices; its capacity bounds the buffer size.
        self._queue = queue.Queue(size)
        self._buffer = []
        self._iter_step = 0

    def __next__(self):
        refresh = self._iter_step % self._reuse == 0
        if refresh or not self._queue.full():
            fresh = next(self._iterator)
            if self._queue.full():
                # Buffer at capacity: overwrite the oldest slot.
                slot = self._queue.get()
                self._buffer[slot] = fresh
            else:
                # Still growing: append into a new slot.
                slot = len(self._buffer)
                self._buffer.append(fresh)
            self._queue.put(slot)
        self._iter_step += 1
        pick = torch.randint(0, len(self._buffer), (1,)).item()
        return self._buffer[pick]
|
|
|
|
|
class Solver(object):
    """Neural solver: trains an ``Agent`` with policy-gradient RL and decodes solutions.

    Owns the torch device, the network hyper-parameters (``nn_args``) and the
    current ``Agent``.  ``to_batch`` has a subclass hook (see the
    ``type(self) == Solver`` branch).
    """

    def __init__(self, device=None, nn_args=None):
        # Resolve the compute device:
        #   None   -> best available (get_default_device)
        #   'cuda' -> best available, but it must actually be CUDA
        #   other  -> passed to torch.device verbatim
        if device is None:
            self.device = get_default_device()
        elif device == 'cuda':
            self.device = get_default_device()
            assert self.device.type == 'cuda', 'no cuda device available!'
        else:
            self.device = torch.device(device)

        if nn_args is None:
            nn_args = {}
        self.nn_args = nn_args

        # Constructed lazily by do_train()/do_solve()/do_load_agent().
        self.agent = None

    def parse_nn_args(self, problem):
        """Fill ``self.nn_args`` in place from a sample problem (module helper)."""
        parse_nn_args(problem, self.nn_args)

    def new_agent(self):
        """Build a fresh Agent from the current ``nn_args``."""
        return Agent(self.nn_args)

    def train(self, agent_filename, train_dataset, valid_dataset, **kwargs):
        """Wrap the datasets in DataLoaders and dispatch to :meth:`do_train`.

        ``agent_filename``: checkpoint path (resumed from when it exists,
        written to during training) or None to disable checkpointing.
        Dataset-loader options are popped off ``kwargs``; the rest is
        forwarded to ``do_train``.
        """
        if dist.is_initialized():
            # Decorrelate data sampling across distributed ranks.
            torch.manual_seed(torch.initial_seed() + dist.get_rank() * 20000)

        train_dataset_workers = kwargs.pop('train_dataset_workers', 1)
        train_dataset_buffers = kwargs.pop('train_dataset_buffers', 2)
        valid_dataset_workers = kwargs.pop('valid_dataset_workers', 1)
        valid_dataset_buffers = kwargs.pop('valid_dataset_buffers', 2)

        # Distinct seeds keep train/valid loader workers on different streams.
        train_dataset = self.wrap_dataset(train_dataset, train_dataset_workers,
                                          train_dataset_buffers, torch.initial_seed() + 1)
        valid_dataset = self.wrap_dataset(valid_dataset, valid_dataset_workers,
                                          valid_dataset_buffers, torch.initial_seed() + 10001)

        if self.device.type == 'cuda':
            with torch.cuda.device(cuda_or_none(self.device)):
                self.do_train(agent_filename, train_dataset, valid_dataset, **kwargs)
        else:
            self.do_train(agent_filename, train_dataset, valid_dataset, **kwargs)

    def do_train(self, agent_filename, train_dataset, valid_dataset, reuse_buffer=0, reuse_times=1, on_policy=True,
                 advpow=1, batch_size=512, topk_size=1, init_lr=0.0001, sched_lr=(int(1e10),), gamma_lr=0.5,
                 warmup_steps=100, log_steps=-1, optim_steps=1, valid_steps=100, max_steps=int(1e10), memopt=1):
        """Main policy-gradient training loop.

        Key knobs: ``reuse_buffer``/``reuse_times`` enable a replay buffer;
        ``on_policy`` selects whether rollouts come from the current agent or
        the best-so-far agent; ``topk_size`` over-samples rollouts and keeps
        the k worst per problem (see below); ``advpow`` is an exponent applied
        to the normalized advantage; ``memopt`` is forwarded to the agent.
        """
        # Log every hyper-parameter (reads defaults/overrides via locals()).
        for arg in inspect.getfullargspec(self.do_train)[0][1:]:
            if arg not in ('train_dataset', 'valid_dataset'):
                print("train_args: {} = {}".format(arg, locals()[arg]))

        if log_steps < 0:
            log_steps = valid_steps

        # The train loader is consumed as an endless stream.
        train_dataset = CyclicIterator(train_dataset)
        if reuse_buffer > 0:
            train_dataset = BufferedIterator(train_dataset, reuse_buffer, reuse_times)

        valid_dataset = list(valid_dataset)

        # Rendezvous: non-zero ranks wait here until rank 0 has written the
        # initial checkpoint (rank 0 hits the matching barrier further down).
        if dist.is_initialized() and dist.get_rank() != 0:
            dist.barrier()

        if agent_filename is not None and os.path.exists(agent_filename):
            saved_state = torch.load(agent_filename, map_location='cpu')
            # Resume with the checkpointed architecture, not the constructor's.
            self.nn_args = saved_state['nn_args']
        else:
            saved_state = None
            self.parse_nn_args(valid_dataset[0])

        step = 0
        start_step = 0
        self.agent = self.new_agent().train()
        self.agent.to(self.device)
        self.print_nn_args()

        # Best-so-far snapshot used as checkpoint payload and (off-policy) as
        # the rollout agent.
        best_agent = copy.deepcopy(self.agent).eval()
        min_valid_cost = math.inf

        optimizer = torch.optim.Adam(self.agent.parameters(), lr=init_lr)
        scheduler = MultiStepLR(optimizer, milestones=sched_lr, gamma=gamma_lr)

        def do_save_state(rng_state, cuda_rng_state):
            # Persist everything needed to resume deterministically.
            if agent_filename is not None:
                save_data = {'step': step, 'rng_state': rng_state}
                if cuda_rng_state is not None:
                    save_data['cuda_rng_state'] = cuda_rng_state
                save_data['nn_args'] = self.agent.nn_args_dict()
                save_data['agent_state'] = self.agent.state_dict()
                save_data['best_agent_state'] = best_agent.state_dict()
                save_data['optimizer_state'] = optimizer.state_dict()
                save_data['scheduler_state'] = scheduler.state_dict()
                torch.save(save_data, agent_filename)

        def valid_sched_save(step):
            # Validate, step the LR scheduler, and checkpoint (rank 0 only).
            if dist.is_initialized():
                # Sanity check: all ranks must hold identical parameters.
                params = parameters_to_vector(self.agent.parameters())
                params_clone = params.clone()
                dist.broadcast(params_clone, 0)
                assert torch.all(params == params_clone)

            # Capture RNG state *before* validation so a resumed run replays
            # the same stream the checkpointed run would have continued with.
            rng_state = torch.get_rng_state()
            cuda_rng_state = None
            if self.device.type == 'cuda':
                cuda_rng_state = torch.cuda.get_rng_state(self.device)

            print("{} - step={}, validate...".format(time.strftime("%Y-%m-%d %H:%M:%S"), step))
            sys.stdout.flush()

            # Synchronize around validation so the reported time is accurate.
            if self.device.type == 'cuda':
                torch.cuda.synchronize(self.device)
            start_time = time.time()
            valid_result = self.validate(valid_dataset, batch_size)
            avg_cost1, avg_cost2, avg_feasible = valid_result
            if self.device.type == 'cuda':
                torch.cuda.synchronize(self.device)

            duration = time.time() - start_time

            if step > 0:
                scheduler.step()

            if not dist.is_initialized() or dist.get_rank() == 0:
                do_save_state(rng_state, cuda_rng_state)

            strftime = time.strftime("%Y-%m-%d %H:%M:%S")
            print("{} - step={}, cost=[{:.6g}, {:.6g}], feasible={:.0%}".format(
                strftime, step, avg_cost1, avg_cost2, avg_feasible))
            print("{} - step={}, min_valid_cost={:.6g}, time={:.3f}s".format(
                strftime, step, min(min_valid_cost, avg_cost2), duration))
            print("---------------------------------------------------------------------------------------")
            sys.stdout.flush()
            # avg_cost2 (mean-over-samples cost) drives best-agent selection.
            return avg_cost2

        if saved_state is not None:
            start_step = saved_state['step']

            if not dist.is_initialized() or dist.get_rank() == 0:
                torch.set_rng_state(saved_state['rng_state'])
                if torch.cuda.is_available():
                    # NOTE(review): assumes the checkpoint contains
                    # 'cuda_rng_state' (i.e. was saved on a CUDA run) — a
                    # CPU-saved checkpoint would KeyError here; confirm.
                    torch.cuda.set_rng_state(saved_state['cuda_rng_state'], self.device)

            best_agent.load_state_dict(saved_state['best_agent_state'])
            # NOTE(review): the current agent is also resumed from
            # 'best_agent_state' rather than 'agent_state' — confirm this is
            # intentional (it discards post-best progress on resume).
            self.agent.load_state_dict(saved_state['best_agent_state'])

            if 'optimizer_state' in saved_state:
                optimizer.load_state_dict(saved_state['optimizer_state'])
            if 'scheduler_state' in saved_state:
                scheduler.load_state_dict(saved_state['scheduler_state'])
        else:
            # Fresh run under DDP: rank 0 writes the initial checkpoint so all
            # ranks start from identical weights after the barrier below.
            if dist.is_initialized() and dist.get_rank() == 0:
                rng_state = torch.get_rng_state()
                cuda_rng_state = None
                if self.device.type == 'cuda':
                    cuda_rng_state = torch.cuda.get_rng_state(self.device)
                do_save_state(rng_state, cuda_rng_state)

        # Matching barrier for the rank!=0 wait above.
        if dist.is_initialized() and dist.get_rank() == 0:
            dist.barrier()

        for step in range(start_step, max_steps):
            if step % valid_steps == 0:
                valid_cost = valid_sched_save(step)
                if valid_cost < min_valid_cost:
                    best_agent.load_state_dict(self.agent.state_dict())
                    min_valid_cost = valid_cost

            start_time = time.time()

            with torch.no_grad():
                problem = next(train_dataset)
                # Smaller batches during warmup.
                if step < warmup_steps:
                    batch_size_now = batch_size // 2
                else:
                    batch_size_now = batch_size
                problem = self.to_device(problem)

            # Off-policy: roll out with the best-so-far agent instead.
            if not on_policy:
                data_agent = best_agent
            else:
                data_agent = self.agent

            data_agent.eval()

            if topk_size > 1:
                # Over-sample rollouts, then keep the NK *worst* (largest=False
                # with sorted=False) per problem; gradients come from a second
                # pass that re-scores those solutions below.
                with torch.no_grad():
                    batch_size_topk = batch_size_now * topk_size
                    env, logp = data_agent(problem, batch_size_topk)
                    cost = env.cost().sum(1).float()
                    solution = env.worker_task_sequence()

                    NP = problem.batch_size
                    NK = batch_size_now // NP
                    NS = solution.size(1)

                    cost = cost.view(NP, -1)
                    cost, kidx = cost.topk(NK, 1, False, False)
                    cost = cost.view(-1)
                    # Expand the kept-rollout indices to gather full
                    # (steps, 3) solution rows per problem.
                    kidx = kidx[:, :, None, None].expand(-1, -1, NS, 3)
                    solution = solution.view(NP, -1, NS, 3)
                    solution = solution.gather(1, kidx).view(-1, NS, 3)

            elif not on_policy:
                with torch.no_grad():
                    env, logp = data_agent(problem, batch_size_now)
                    cost = env.cost().sum(1).float()
                    solution = env.worker_task_sequence()
            else:
                # On-policy: this pass already carries gradients.
                self.agent.train()
                env, logp = self.agent(problem, batch_size_now, memopt=memopt)
                cost = env.cost().sum(1).float()
                solution = env.worker_task_sequence()

            self.agent.train()

            # REINFORCE advantage: cost minus a per-problem baseline
            # (max over kept rollouts for topk, mean over samples otherwise),
            # L2-normalized then shaped by the advpow exponent.
            with torch.no_grad():
                NP = problem.batch_size
                if topk_size > 1:
                    baseline = cost.view(NP, -1).max(1)[0]
                else:
                    baseline = cost.view(NP, -1).mean(1)
                baseline = repeat(baseline, cost.size(0) // NP)
                adv = (cost - baseline)[:, None]
                adv_norm = adv.norm()
                if adv_norm > 0:
                    adv = adv / adv.norm() * adv.size(0)
                adv = adv.sign() * adv.abs().pow(advpow)

            # For topk/off-policy rollouts, re-score the chosen solutions with
            # the current agent to obtain differentiable log-probs.
            if topk_size > 1 or not on_policy:
                env, logp = self.agent(problem, batch_size_now, solution=solution, memopt=memopt)

            loss = adv * logp
            loss = loss.mean()
            loss.backward()

            # Gradient accumulation: apply the optimizer every optim_steps.
            if step % optim_steps == 0:
                if dist.is_initialized():
                    # Manual gradient all-reduce (SUM) across ranks.
                    params = filter(lambda a: a.grad is not None, self.agent.parameters())
                    grad_list = [param.grad for param in params]
                    grad_vector = parameters_to_vector(grad_list)
                    dist.all_reduce(grad_vector, op=dist.ReduceOp.SUM)
                    vector_to_parameters(grad_vector, grad_list)

                grad_norm = clip_grad_norm_(self.agent.parameters(), 1)
                optimizer.step()
                optimizer.zero_grad()

            if step % log_steps == 0:
                strftime = time.strftime("%Y-%m-%d %H:%M:%S")
                lr = optimizer.param_groups[0]['lr']
                duration = time.time() - start_time
                with torch.no_grad():
                    # Mean total sequence probability of the sampled rollouts.
                    p = logp.to(torch.float64).sum(1).exp().mean()
                print("{} - step={}, grad={:.6g}, lr={:.6g}, p={:.6g}".format(
                    strftime, step, grad_norm, lr, p))

                print("{} - step={}, cost={:.6g}, time={:.3f}s".format(strftime, step, cost.mean(), duration))
                print("---------------------------------------------------------------------------------------")
                sys.stdout.flush()

        # Final validation + checkpoint after the loop.
        valid_sched_save(step)

    def solve(self, problem, greedy=False, batch_size=512):
        """Solve one problem (or batch); runs :meth:`do_solve` under the
        solver's CUDA device context when applicable."""
        if self.device.type == 'cuda':
            with torch.cuda.device(cuda_or_none(self.device)):
                return self.do_solve(problem, greedy, batch_size)
        else:
            return self.do_solve(problem, greedy, batch_size)

    def do_solve(self, problem, greedy, batch_size):
        """Decode ``batch_size`` rollouts and return the best one per problem.

        Returns a :class:`Solution` with cost, feasibility flag, sequence
        probability and the decoded worker/task sequence; scalars are
        unwrapped when the input was a single (non-batched) problem.
        """
        isbatch = problem.isbatch
        problem = self.to_batch(problem)
        problem = self.to_device(problem)

        # Build an agent on demand when solve() is used without train/load.
        if self.agent is None:
            self.parse_nn_args(problem)
            self.agent = self.new_agent()
            self.agent.to(self.device)

        self.agent.eval()

        with torch.no_grad():
            env, prob = self.agent(problem, batch_size, greedy, problem.solution)

        # NR = rollouts per problem.
        NP = problem.batch_size
        NR = prob.size(0) // NP

        prob = prob.view(NP, NR, -1)
        cost = env.cost().sum(1).view(NP, NR)
        feasible = env.feasible().view(NP, NR)
        size = list(env.worker_task_sequence().size())
        size = [NP, NR] + size[1:]
        worker_task_sequence = env.worker_task_sequence().view(size)

        # Select the cheapest rollout per problem, penalizing infeasible ones
        # by +base_cost so any feasible rollout wins; the penalty is undone
        # for problems where every rollout was infeasible.
        p_index = torch.arange(NP)
        base_cost = cost.max() + 1
        cost[~feasible] += base_cost
        cost, s_index = cost.min(1)
        feasible = feasible[p_index, s_index]
        cost[~feasible] -= base_cost
        probability = prob[p_index, s_index].exp()
        worker_task_sequence = worker_task_sequence[p_index, s_index]

        if isbatch:
            solution = Solution(cost)
            solution.feasible = feasible
            solution.probability = probability
            solution.worker_task_sequence = worker_task_sequence
        else:
            # Single problem: strip the batch dimension / unwrap scalars.
            solution = Solution(cost.item())
            solution.feasible = feasible.item()
            solution.probability = probability.squeeze(0)
            solution.worker_task_sequence = worker_task_sequence.squeeze(0)

        return solution

    def load_agent(self, filename, strict=True):
        """Load a checkpointed agent; see :meth:`do_load_agent`."""
        if self.device.type == 'cuda':
            with torch.cuda.device(cuda_or_none(self.device)):
                self.do_load_agent(filename, strict)
        else:
            self.do_load_agent(filename, strict)

    def do_load_agent(self, filename, strict=True):
        """Rebuild ``self.agent`` from a checkpoint's best agent weights."""
        saved_state = torch.load(filename, map_location='cpu')
        self.nn_args = saved_state['nn_args']

        self.agent = self.new_agent()
        self.agent.to(self.device)
        # Loads 'best_agent_state' (the best validated weights), not the
        # last training state.
        self.agent.load_state_dict(saved_state['best_agent_state'], strict)
        self.print_nn_args()

    def to_batch(self, problem, pin_memory=True):
        """Convert a user-facing Problem into a batched, tensor-backed Problem.

        Infers worker/task counts from attribute names ('worker_*', 'task_*'),
        validates feature tensor shapes, and collects all tensors into the
        ``feats`` dict.  A non-batched problem gets a leading batch dim of 1.
        """
        assert not hasattr(problem, 'feats')

        NW = 1
        NT = 1
        NP = 1
        isbatch = problem.isbatch
        # Infer batch/worker/task counts from feature shapes.
        for k, v in problem.__dict__.items():
            if k.startswith("worker_"):
                NW = len(v[0]) if isbatch else len(v)
            elif k.startswith("task_"):
                NP = len(v) if isbatch else 1
                NT = len(v[0]) if isbatch else len(v)
        # Doubled worker count; used as an accepted matrix layout below.
        NWW = NW * 2

        new_problem = Problem(True)
        new_problem.feats = {}
        new_problem.device = 'cpu'

        new_problem.batch_size = NP
        new_problem.worker_num = NW
        new_problem.task_num = NT

        new_problem.features = problem.features

        # Base Solver copies the modeling attributes verbatim; subclasses use
        # a different attribute set (constraints / oa_* flags).
        if type(self) == Solver:
            new_problem.variables = problem.variables
            new_problem.constraint = problem.constraint
            new_problem.objective = problem.objective
            new_problem.environment = problem.environment
        else:
            new_problem.variables = []
            new_problem.constraints = problem.constraints
            new_problem.oa_estimate_tasks = problem.oa_estimate_tasks
            new_problem.oa_multiple_steps = problem.oa_multiple_steps

        # Accepted square-matrix layouts: doubled-worker or plain worker+task.
        edge_size_list = ((NWW + NT, NWW + NT), (NW + NT, NW + NT))

        def check_size(f, k, v):
            assert f, "size error, feature: {}, size: {}".format(k, tuple(v.size()))

        for k, v in problem.__dict__.items():
            if k == 'solution' and v is not None:
                # Solutions are (batch, steps, 3) integer tuples.
                v = to_tensor(k, v, isbatch)
                check_size(v.dim() == 3 and v.size(-1) == 3, k, v)
            elif k.startswith("worker_task_"):
                v = to_tensor(k, v, isbatch)
                check_size(v.dim() in (3, 4) and v.size()[1:3] == (NW, NT), k, v)
            elif k.startswith("worker_"):
                v = to_tensor(k, v, isbatch)
                check_size(v.dim() in (2, 3) and v.size(1) == NW, k, v)
            elif k.startswith("task_"):
                v = to_tensor(k, v, isbatch)
                check_size(v.dim() in (2, 3) and v.size(1) == NT, k, v)
            elif k.endswith("_matrix"):
                v = to_tensor(k, v, isbatch)
                check_size(v.dim() in (3, 4) and v.size()[1:3] in edge_size_list, k, v)
                if v.size()[1:3] == (NW + NT, NW + NT):
                    # Expand an (NW+NT)-sized matrix to the doubled-worker
                    # layout by duplicating each worker row/column.
                    worker_index = torch.arange(NW)
                    task_index = torch.arange(NT) + NW
                    index = torch.cat([worker_index, worker_index, task_index])
                    index1 = index[:, None]
                    index2 = index[None, :]
                    v = v[:, index1, index2]
            elif isinstance(v, np.ndarray):
                v = torch.tensor(v)

            # Only tensor-valued attributes become features.
            if isinstance(v, torch.Tensor):
                new_problem.feats[k] = v

        if pin_memory and self.device.type == 'cuda':
            new_problem.pin_memory()

        return new_problem

    def to_device(self, problem):
        """Return a shallow copy of *problem* with feats moved to self.device."""
        assert hasattr(problem, 'feats')

        new_problem = copy.copy(problem)
        new_problem.device = self.device
        new_problem.feats = {}

        # Async copies are only meaningful (and safe) for CUDA targets.
        non_blocking = self.device.type == 'cuda'
        for k, v in problem.feats.items():
            v = v.to(self.device, non_blocking=non_blocking)
            new_problem.feats[k] = v

        return new_problem

    def validate(self, problem_list, batch_size):
        """Run do_validate with the agent in eval mode and grads disabled."""
        self.agent.eval()
        with torch.no_grad():
            valid_result = self.do_validate(problem_list, batch_size)

        self.agent.train()
        return valid_result

    def do_validate(self, problem_list, batch_size):
        """Evaluate the agent over *problem_list*.

        Returns (avg best-rollout cost, avg mean-rollout cost, feasible
        fraction), all-reduced across ranks when distributed.
        """
        total_cost1 = 0
        total_cost2 = 0
        total_feasible = 0
        total_problem = 0
        start_time = time.time()
        for problem in problem_list:
            problem = self.to_device(problem)
            env, _, = self.agent(problem, batch_size)

            NP = problem.batch_size
            cost = env.cost().sum(1).view(NP, -1)
            cost1, _ = cost.min(1)
            cost2 = cost.mean(1)
            # A problem counts as feasible if any rollout is feasible.
            feasible = env.feasible().view(NP, -1)
            feasible = torch.any(feasible, 1)

            total_cost1 += cost1.sum().item()
            total_cost2 += cost2.sum().item()
            total_feasible += feasible.int().sum().item()
            total_problem += NP

        if dist.is_initialized():
            # Aggregate the raw sums across ranks before averaging.
            data = [total_cost1, total_cost2, total_feasible, total_problem]
            data = torch.tensor(data, device=self.device)
            dist.all_reduce(data, op=dist.ReduceOp.SUM)
            total_cost1, total_cost2, total_feasible, total_problem = data.tolist()

        avg_cost1 = total_cost1 / total_problem
        avg_cost2 = total_cost2 / total_problem
        avg_feasible = total_feasible / total_problem

        return avg_cost1, avg_cost2, avg_feasible

    def wrap_dataset(self, dataset, workers, buffers, seed):
        """Wrap *dataset* in a DataLoader; streaming for IterableDataset,
        shuffled map-style otherwise."""
        if isinstance(dataset, IterableDataset):
            dataset = WrapIterableDataset(dataset, self)
            dataset = DataLoader(dataset, batch_size=None, pin_memory=True,
                                 num_workers=workers, prefetch_factor=buffers,
                                 worker_init_fn=lambda worker_id: torch.manual_seed(seed + worker_id))
        else:
            # Map-style: to_batch runs eagerly, so do it under the right
            # CUDA device context when applicable.
            if self.device.type == 'cuda':
                with torch.cuda.device(cuda_or_none(self.device)):
                    dataset = WrapDataset(dataset, self)
                    dataset = DataLoader(dataset, batch_size=None, pin_memory=True, shuffle=True)
            else:
                dataset = WrapDataset(dataset, self)
                dataset = DataLoader(dataset, batch_size=None, pin_memory=True, shuffle=True)

        return dataset

    def print_nn_args(self):
        """Print scalar-valued nn_args entries (skips nested structures)."""
        for key, value in self.nn_args.items():
            if type(value) in [int, float, str, bool]:
                print("nn_args: {} = {}".format(key, value))
        sys.stdout.flush()
|
|
|
|
|
def to_tensor(key, value, isbatch):
    """Coerce *value* to a CPU tensor, prepending a batch dim when not batched.

    ``key`` names the feature (kept for interface compatibility; unused here).
    """
    if isinstance(value, torch.Tensor):
        result = value.to('cpu')
    else:
        result = torch.tensor(value, device='cpu')

    return result if isbatch else result[None]
|
|
|
|
|
def cuda_or_none(device):
    """Return *device* if it is a CUDA device, else None.

    Suitable as the argument to ``torch.cuda.device`` (None is a no-op there).
    """
    if device.type == 'cuda':
        return device
    return None
|
|