import os
import sys
import math
import copy
import time
import queue
import inspect
import torch
import numpy as np
import torch.nn.functional as F
import torch.distributed as dist

from .agent import Agent, parse_nn_args
from .utils import repeat, get_default_device, cutime_stats
from .variable import TaskDemandNow
from torch.nn.utils import clip_grad_norm_, parameters_to_vector, vector_to_parameters
from torch.utils.data import Dataset, IterableDataset, DataLoader
from torch.optim.lr_scheduler import MultiStepLR


class Problem(object):
    def __init__(self, isbatch=False):
        self.isbatch = isbatch
        self.features = []
        self.environment = None

    def pin_memory(self):
        for k, v in self.feats.items():
            self.feats[k] = v.pin_memory()
        return self

    def __getattr__(self, name):
        if name not in ('solution',):
            raise AttributeError(name)
        return self.feats.get(name)


class Solution(object):
    def __init__(self, cost=None):
        self.cost = cost
        self.worker_task_sequence = None


class WrapDataset(Dataset):
    def __init__(self, dataset, solver):
        self._dataset = [solver.to_batch(p) for p in dataset]

    def __getitem__(self, index):
        return self._dataset[index]

    def __len__(self):
        return len(self._dataset)


class WrapIterator:
    def __init__(self, iterator, solver):
        self._iterator = iterator
        self._solver = solver

    def __next__(self):
        p = next(self._iterator)
        p = self._solver.to_batch(p, False)
        return p


class WrapIterableDataset(IterableDataset):
    def __init__(self, dataset, solver):
        self._dataset = dataset
        self._solver = solver

    def __iter__(self):
        return WrapIterator(iter(self._dataset), self._solver)


class CyclicIterator:
    def __init__(self, iterable):
        self._iterable = iterable
        self._iterator = iter(iterable)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self._iterator)
        except StopIteration:
            self._iterator = iter(self._iterable)
            return next(self._iterator)


class BufferedIterator:
    def __init__(self, iterator, size, reuse):
        self._iterator = iterator
        self._reuse = reuse
        self._queue = queue.Queue(size)
        self._buffer = []
        self._iter_step = 0

    def __next__(self):
        # Fetch a new problem while the buffer is still filling up, or every
        # `reuse` steps; otherwise a previously buffered problem is reused.
        if not self._queue.full() or self._iter_step % self._reuse == 0:
            problem = next(self._iterator)
            if self._queue.full():
                index = self._queue.get()
                self._buffer[index] = problem
            else:
                index = len(self._buffer)
                self._buffer.append(problem)
            self._queue.put(index)

        self._iter_step += 1
        index = torch.randint(0, len(self._buffer), (1,)).item()
        return self._buffer[index]


class Solver(object):
    def __init__(self, device=None, nn_args=None):
        if device is None:
            self.device = get_default_device()
        elif device == 'cuda':
            self.device = get_default_device()
            assert self.device.type == 'cuda', 'no cuda device available!'
        else:
            self.device = torch.device(device)

        if nn_args is None:
            nn_args = {}
        self.nn_args = nn_args
        self.agent = None

    def parse_nn_args(self, problem):
        parse_nn_args(problem, self.nn_args)

    def new_agent(self):
        return Agent(self.nn_args)

    def train(self, agent_filename, train_dataset, valid_dataset, **kwargs):
        if dist.is_initialized():
            torch.manual_seed(torch.initial_seed() + dist.get_rank() * 20000)

        train_dataset_workers = kwargs.pop('train_dataset_workers', 1)
        train_dataset_buffers = kwargs.pop('train_dataset_buffers', 2)
        valid_dataset_workers = kwargs.pop('valid_dataset_workers', 1)
        valid_dataset_buffers = kwargs.pop('valid_dataset_buffers', 2)

        train_dataset = self.wrap_dataset(train_dataset, train_dataset_workers,
                                          train_dataset_buffers, torch.initial_seed() + 1)
        valid_dataset = self.wrap_dataset(valid_dataset, valid_dataset_workers,
                                          valid_dataset_buffers, torch.initial_seed() + 10001)

        if self.device.type == 'cuda':
            with torch.cuda.device(cuda_or_none(self.device)):
                self.do_train(agent_filename, train_dataset, valid_dataset, **kwargs)
        else:
            self.do_train(agent_filename, train_dataset, valid_dataset, **kwargs)

    def do_train(self, agent_filename, train_dataset, valid_dataset,
                 reuse_buffer=0, reuse_times=1, on_policy=True, advpow=1,
                 batch_size=512, topk_size=1, init_lr=0.0001, sched_lr=(int(1e10),),
                 gamma_lr=0.5, warmup_steps=100, log_steps=-1, optim_steps=1,
                 valid_steps=100, max_steps=int(1e10), memopt=1):

        for arg in inspect.getfullargspec(self.do_train)[0][1:]:
            if arg not in ('train_dataset', 'valid_dataset'):
                print("train_args: {} = {}".format(arg, locals()[arg]))

        if log_steps < 0:
            log_steps = valid_steps

        train_dataset = CyclicIterator(train_dataset)
        if reuse_buffer > 0:
            train_dataset = BufferedIterator(train_dataset, reuse_buffer, reuse_times)

        valid_dataset = list(valid_dataset)

        if dist.is_initialized() and dist.get_rank() != 0:
            dist.barrier()

        if agent_filename is not None and os.path.exists(agent_filename):
            saved_state = torch.load(agent_filename, map_location='cpu')
            self.nn_args = saved_state['nn_args']
        else:
            saved_state = None
            self.parse_nn_args(valid_dataset[0])

        step = 0
        start_step = 0
        self.agent = self.new_agent().train()
        self.agent.to(self.device)
        self.print_nn_args()

        best_agent = copy.deepcopy(self.agent).eval()
        min_valid_cost = math.inf

        optimizer = torch.optim.Adam(self.agent.parameters(), lr=init_lr)
        scheduler = MultiStepLR(optimizer, milestones=sched_lr, gamma=gamma_lr)

        def do_save_state(rng_state, cuda_rng_state):
            if agent_filename is not None:
                save_data = {'step': step, 'rng_state': rng_state}
                if cuda_rng_state is not None:
                    save_data['cuda_rng_state'] = cuda_rng_state
                save_data['nn_args'] = self.agent.nn_args_dict()
                save_data['agent_state'] = self.agent.state_dict()
                save_data['best_agent_state'] = best_agent.state_dict()
                save_data['optimizer_state'] = optimizer.state_dict()
                save_data['scheduler_state'] = scheduler.state_dict()
                torch.save(save_data, agent_filename)

        def valid_sched_save(step):
            if dist.is_initialized():
                params = parameters_to_vector(self.agent.parameters())
                params_clone = params.clone()
                dist.broadcast(params_clone, 0)
                assert torch.all(params == params_clone)

            rng_state = torch.get_rng_state()
            cuda_rng_state = None
            if self.device.type == 'cuda':
                cuda_rng_state = torch.cuda.get_rng_state(self.device)

            print("{} - step={}, validate...".format(time.strftime("%Y-%m-%d %H:%M:%S"), step))
            sys.stdout.flush()

            if self.device.type == 'cuda':
                torch.cuda.synchronize(self.device)

            start_time = time.time()
            valid_result = self.validate(valid_dataset, batch_size)
            avg_cost1, avg_cost2, avg_feasible = valid_result

            if self.device.type == 'cuda':
                torch.cuda.synchronize(self.device)
            duration = time.time() - start_time

            if step > 0:
                scheduler.step()

            if not dist.is_initialized() or dist.get_rank() == 0:
                do_save_state(rng_state, cuda_rng_state)
                strftime = time.strftime("%Y-%m-%d %H:%M:%S")
                print("{} - step={}, cost=[{:.6g}, {:.6g}], feasible={:.0%}".format(
                    strftime, step, avg_cost1, avg_cost2, avg_feasible))
                print("{} - step={}, min_valid_cost={:.6g}, time={:.3f}s".format(
                    strftime, step, min(min_valid_cost, avg_cost2), duration))
                print("---------------------------------------------------------------------------------------")
                sys.stdout.flush()

            return avg_cost2

        if saved_state is not None:
            start_step = saved_state['step']
            if not dist.is_initialized() or dist.get_rank() == 0:
                torch.set_rng_state(saved_state['rng_state'])
                if self.device.type == 'cuda' and 'cuda_rng_state' in saved_state:
                    torch.cuda.set_rng_state(saved_state['cuda_rng_state'], self.device)
            best_agent.load_state_dict(saved_state['best_agent_state'])
            self.agent.load_state_dict(saved_state['best_agent_state'])
            # if 'agent_state' in saved_state:
            #     self.agent.load_state_dict(saved_state['agent_state'])
            # else:
            #     self.agent.load_state_dict(saved_state['best_agent_state'])
            if 'optimizer_state' in saved_state:
                optimizer.load_state_dict(saved_state['optimizer_state'])
            if 'scheduler_state' in saved_state:
                scheduler.load_state_dict(saved_state['scheduler_state'])
        else:
            if dist.is_initialized() and dist.get_rank() == 0:
                rng_state = torch.get_rng_state()
                cuda_rng_state = None
                if self.device.type == 'cuda':
                    cuda_rng_state = torch.cuda.get_rng_state(self.device)
                do_save_state(rng_state, cuda_rng_state)

        if dist.is_initialized() and dist.get_rank() == 0:
            dist.barrier()

        for step in range(start_step, max_steps):
            if step % valid_steps == 0:
                valid_cost = valid_sched_save(step)
                if valid_cost < min_valid_cost:
                    best_agent.load_state_dict(self.agent.state_dict())
                    min_valid_cost = valid_cost

            start_time = time.time()

            # problem
            with torch.no_grad():
                problem = next(train_dataset)

            if step < warmup_steps:
                batch_size_now = batch_size // 2
            else:
                batch_size_now = batch_size

            problem = self.to_device(problem)

            if not on_policy:
                data_agent = best_agent
            else:
                data_agent = self.agent
            data_agent.eval()

            # solution
            if topk_size > 1:
                with torch.no_grad():
                    batch_size_topk = batch_size_now * topk_size
                    env, logp = data_agent(problem, batch_size_topk)
                    cost = env.cost().sum(1).float()
                    solution = env.worker_task_sequence()

                    NP = problem.batch_size
                    NK = batch_size_now // NP
                    NS = solution.size(1)

                    # keep the NK lowest-cost rollouts per problem
                    cost = cost.view(NP, -1)
                    cost, kidx = cost.topk(NK, 1, False, False)
                    cost = cost.view(-1)
                    kidx = kidx[:, :, None, None].expand(-1, -1, NS, 3)
                    solution = solution.view(NP, -1, NS, 3)
                    solution = solution.gather(1, kidx).view(-1, NS, 3)
            elif not on_policy:
                with torch.no_grad():
                    env, logp = data_agent(problem, batch_size_now)
                    cost = env.cost().sum(1).float()
                    solution = env.worker_task_sequence()
            else:
                self.agent.train()
                env, logp = self.agent(problem, batch_size_now, memopt=memopt)
                cost = env.cost().sum(1).float()
                solution = env.worker_task_sequence()

            self.agent.train()

            # advantage
            with torch.no_grad():
                NP = problem.batch_size
                if topk_size > 1:
                    baseline = cost.view(NP, -1).max(1)[0]
                else:
                    baseline = cost.view(NP, -1).mean(1)
                baseline = repeat(baseline, cost.size(0) // NP)
                adv = (cost - baseline)[:, None]
                adv_norm = adv.norm()
                if adv_norm > 0:
                    adv = adv / adv_norm * adv.size(0)
                adv = adv.sign() * adv.abs().pow(advpow)

            # backward
            if topk_size > 1 or not on_policy:
                env, logp = self.agent(problem, batch_size_now, solution=solution, memopt=memopt)

            loss = adv * logp
            loss = loss.mean()
            loss.backward()

            if step % optim_steps == 0:
                if dist.is_initialized():
                    params = filter(lambda a: a.grad is not None, self.agent.parameters())
                    grad_list = [param.grad for param in params]
                    grad_vector = parameters_to_vector(grad_list)
                    dist.all_reduce(grad_vector, op=dist.ReduceOp.SUM)
                    vector_to_parameters(grad_vector, grad_list)
                grad_norm = clip_grad_norm_(self.agent.parameters(), 1)
                optimizer.step()
                optimizer.zero_grad()

            if step % log_steps == 0:
                strftime = time.strftime("%Y-%m-%d %H:%M:%S")
                lr = optimizer.param_groups[0]['lr']
                duration = time.time() - start_time
                with torch.no_grad():
                    p = logp.to(torch.float64).sum(1).exp().mean()
                print("{} - step={}, grad={:.6g}, lr={:.6g}, p={:.6g}".format(
                    strftime, step, grad_norm, lr, p))
                print("{} - step={}, cost={:.6g}, time={:.3f}s".format(strftime, step, cost.mean(), duration))
                print("---------------------------------------------------------------------------------------")
                sys.stdout.flush()

        valid_sched_save(step)

    def solve(self, problem, greedy=False, batch_size=512):
        if self.device.type == 'cuda':
            with torch.cuda.device(cuda_or_none(self.device)):
                return self.do_solve(problem, greedy, batch_size)
        else:
            return self.do_solve(problem, greedy, batch_size)

    def do_solve(self, problem, greedy, batch_size):
        isbatch = problem.isbatch
        problem = self.to_batch(problem)
        problem = self.to_device(problem)

        if self.agent is None:
            self.parse_nn_args(problem)
            self.agent = self.new_agent()
            self.agent.to(self.device)

        self.agent.eval()
        with torch.no_grad():
            env, prob = self.agent(problem, batch_size, greedy, problem.solution)

        NP = problem.batch_size
        NR = prob.size(0) // NP
        prob = prob.view(NP, NR, -1)
        cost = env.cost().sum(1).view(NP, NR)
        feasible = env.feasible().view(NP, NR)

        size = list(env.worker_task_sequence().size())
        size = [NP, NR] + size[1:]
        worker_task_sequence = env.worker_task_sequence().view(size)

        # pick the lowest-cost rollout per problem; infeasible rollouts are
        # penalized so they are only selected when no feasible rollout exists
        p_index = torch.arange(NP)
        base_cost = cost.max() + 1
        cost[~feasible] += base_cost
        cost, s_index = cost.min(1)
        feasible = feasible[p_index, s_index]
        cost[~feasible] -= base_cost
        probability = prob[p_index, s_index].exp()
        worker_task_sequence = worker_task_sequence[p_index, s_index]

        if isbatch:
            solution = Solution(cost)
            solution.feasible = feasible
            solution.probability = probability
            solution.worker_task_sequence = worker_task_sequence
        else:
            solution = Solution(cost.item())
            solution.feasible = feasible.item()
            solution.probability = probability.squeeze(0)
            solution.worker_task_sequence = worker_task_sequence.squeeze(0)

        return solution

    def load_agent(self, filename, strict=True):
        if self.device.type == 'cuda':
            with torch.cuda.device(cuda_or_none(self.device)):
                self.do_load_agent(filename, strict)
        else:
            self.do_load_agent(filename, strict)

    def do_load_agent(self, filename, strict=True):
        saved_state = torch.load(filename, map_location='cpu')
        self.nn_args = saved_state['nn_args']
        self.agent = self.new_agent()
        self.agent.to(self.device)
        self.agent.load_state_dict(saved_state['best_agent_state'], strict)
        self.print_nn_args()

    def to_batch(self, problem, pin_memory=True):
        assert not hasattr(problem, 'feats')

        NW = 1
        NT = 1
        NP = 1
        isbatch = problem.isbatch
        for k, v in problem.__dict__.items():
            if k.startswith("worker_"):
                NW = len(v[0]) if isbatch else len(v)
            elif k.startswith("task_"):
                NP = len(v) if isbatch else 1
                NT = len(v[0]) if isbatch else len(v)
        NWW = NW * 2

        new_problem = Problem(True)
        new_problem.feats = {}
        new_problem.device = 'cpu'
        new_problem.batch_size = NP
        new_problem.worker_num = NW
        new_problem.task_num = NT
        new_problem.features = problem.features
        if type(self) == Solver:
            new_problem.variables = problem.variables
            new_problem.constraint = problem.constraint
            new_problem.objective = problem.objective
            new_problem.environment = problem.environment
        else:
            new_problem.variables = []
            new_problem.constraints = problem.constraints
            new_problem.oa_estimate_tasks = problem.oa_estimate_tasks
            new_problem.oa_multiple_steps = problem.oa_multiple_steps

        edge_size_list = ((NWW + NT, NWW + NT), (NW + NT, NW + NT))

        def check_size(f, k, v):
            assert f, "size error, feature: {}, size: {}".format(k, tuple(v.size()))

        for k, v in problem.__dict__.items():
            if k == 'solution' and v is not None:
                v = to_tensor(k, v, isbatch)
                check_size(v.dim() == 3 and v.size(-1) == 3, k, v)
            elif k.startswith("worker_task_"):
                v = to_tensor(k, v, isbatch)
                check_size(v.dim() in (3, 4) and v.size()[1:3] == (NW, NT), k, v)
            elif k.startswith("worker_"):
                v = to_tensor(k, v, isbatch)
                check_size(v.dim() in (2, 3) and v.size(1) == NW, k, v)
            elif k.startswith("task_"):
                v = to_tensor(k, v, isbatch)
                check_size(v.dim() in (2, 3) and v.size(1) == NT, k, v)
            elif k.endswith("_matrix"):
                v = to_tensor(k, v, isbatch)
                check_size(v.dim() in (3, 4) and v.size()[1:3] in edge_size_list, k, v)
                if v.size()[1:3] == (NW + NT, NW + NT):
                    # expand an (NW+NT)-square matrix to (2*NW+NT) by
                    # duplicating the worker rows and columns
                    worker_index = torch.arange(NW)
                    task_index = torch.arange(NT) + NW
                    index = torch.cat([worker_index, worker_index, task_index])
                    index1 = index[:, None]
                    index2 = index[None, :]
                    v = v[:, index1, index2]
            elif isinstance(v, np.ndarray):
                v = torch.tensor(v)

            if isinstance(v, torch.Tensor):
                new_problem.feats[k] = v

        if pin_memory and self.device.type == 'cuda':
            new_problem.pin_memory()

        return new_problem

    def to_device(self, problem):
        assert hasattr(problem, 'feats')

        new_problem = copy.copy(problem)
        new_problem.device = self.device
        new_problem.feats = {}

        non_blocking = self.device.type == 'cuda'
        for k, v in problem.feats.items():
            v = v.to(self.device, non_blocking=non_blocking)
            new_problem.feats[k] = v

        return new_problem

    def validate(self, problem_list, batch_size):
        self.agent.eval()
        with torch.no_grad():
            valid_result = self.do_validate(problem_list, batch_size)
        self.agent.train()
        return valid_result

    def do_validate(self, problem_list, batch_size):
        total_cost1 = 0
        total_cost2 = 0
        total_feasible = 0
        total_problem = 0

        start_time = time.time()
        for problem in problem_list:
            problem = self.to_device(problem)
            env, _ = self.agent(problem, batch_size)

            NP = problem.batch_size
            cost = env.cost().sum(1).view(NP, -1)
            cost1, _ = cost.min(1)
            cost2 = cost.mean(1)

            feasible = env.feasible().view(NP, -1)
            feasible = torch.any(feasible, 1)

            total_cost1 += cost1.sum().item()
            total_cost2 += cost2.sum().item()
            total_feasible += feasible.int().sum().item()
            total_problem += NP

        if dist.is_initialized():
            data = [total_cost1, total_cost2, total_feasible, total_problem]
            data = torch.tensor(data, device=self.device)
            dist.all_reduce(data, op=dist.ReduceOp.SUM)
            total_cost1, total_cost2, total_feasible, total_problem = data.tolist()

        avg_cost1 = total_cost1 / total_problem
        avg_cost2 = total_cost2 / total_problem
        avg_feasible = total_feasible / total_problem
        return avg_cost1, avg_cost2, avg_feasible

    def wrap_dataset(self, dataset, workers, buffers, seed):
        if isinstance(dataset, IterableDataset):
            dataset = WrapIterableDataset(dataset, self)
            dataset = DataLoader(dataset,
                                 batch_size=None,
                                 pin_memory=True,
                                 num_workers=workers,
                                 prefetch_factor=buffers,
                                 worker_init_fn=lambda worker_id: torch.manual_seed(seed + worker_id))
        else:
            if self.device.type == 'cuda':
                with torch.cuda.device(cuda_or_none(self.device)):
                    dataset = WrapDataset(dataset, self)
                    dataset = DataLoader(dataset, batch_size=None, pin_memory=True, shuffle=True)
            else:
                dataset = WrapDataset(dataset, self)
                dataset = DataLoader(dataset, batch_size=None, pin_memory=True, shuffle=True)
        return dataset

    def print_nn_args(self):
        for key, value in self.nn_args.items():
            if type(value) in [int, float, str, bool]:
                print("nn_args: {} = {}".format(key, value))
        sys.stdout.flush()


def to_tensor(key, value, isbatch):
    if isinstance(value, torch.Tensor):
        tensor = value.to('cpu')
    else:
        tensor = torch.tensor(value, device='cpu')
    if not isbatch:
        tensor = tensor[None]
    return tensor


def cuda_or_none(device):
    return device if device.type == 'cuda' else None
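
# Usage sketch (hypothetical): a minimal outline of how this Solver is driven,
# using only the methods defined above. The datasets, the problem object and the
# file name 'agent.pt' are placeholders, not part of this module.
#
#   solver = Solver(device='cuda')
#   solver.train('agent.pt', train_dataset, valid_dataset, batch_size=256, max_steps=100000)
#   solver.load_agent('agent.pt')
#   solution = solver.solve(problem, greedy=True)
#   print(solution.cost, solution.feasible, solution.worker_task_sequence)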