# Source: GreedRL / greedrl/solver.py
# Last commit: "add greedrl" (db26c81), author: 先坤
# Hosting-page residue kept for reference: raw | history | blame | 23.3 kB
import os
import sys
import math
import copy
import time
import queue
import inspect
import torch
import numpy as np
import torch.nn.functional as F
import torch.distributed as dist
from .agent import Agent, parse_nn_args
from .utils import repeat, get_default_device, cutime_stats
from .variable import TaskDemandNow
from torch.nn.utils import clip_grad_norm_, parameters_to_vector, vector_to_parameters
from torch.utils.data import Dataset, IterableDataset, DataLoader
from torch.optim.lr_scheduler import MultiStepLR
class Problem(object):
    """Container for a problem instance (or a batch of instances).

    ``__init__`` only sets ``isbatch``, ``features`` and ``environment``.
    Other attributes are attached by the code that builds the problem; in
    particular ``Solver.to_batch`` attaches ``feats``, a dict mapping feature
    names to tensors.
    """

    def __init__(self, isbatch=False):
        """
        Args:
            isbatch: stored as-is; ``Solver.to_batch`` reads it to decide
                whether values already carry a leading batch dimension.
        """
        self.isbatch = isbatch
        self.features = []
        self.environment = None

    def pin_memory(self):
        """Replace every tensor in ``self.feats`` by its ``pin_memory()`` result.

        Returns:
            ``self``, so the call can be chained.
        """
        # Only existing keys are reassigned, so the dict size never changes
        # while it is being iterated.
        for k, v in self.feats.items():
            self.feats[k] = v.pin_memory()
        return self

    def __getattr__(self, name):
        """Fallback lookup: expose ``feats['solution']`` as ``self.solution``.

        Python calls this only after normal attribute lookup has failed.

        Returns:
            ``self.feats.get('solution')`` (``None`` when no solution is stored).

        Raises:
            AttributeError: for any name other than ``'solution'``, or when
                ``feats`` has not been attached to this instance.
        """
        # Exact comparison: the previous ``name not in ('solution')`` tested
        # against a *string*, so every substring ('sol', 'o', '') slipped through.
        if name != 'solution':
            raise AttributeError(
                "'{}' object has no attribute '{}'".format(type(self).__name__, name))

        # Read __dict__ directly so a missing ``feats`` does not re-enter __getattr__.
        feats = self.__dict__.get('feats')
        if feats is None:
            raise AttributeError(
                "'{}' object has no attribute '{}'".format(type(self).__name__, name))
        return feats.get(name)
class Solution(object):
    """Plain result record.

    Holds a ``cost`` and a ``worker_task_sequence``; further fields may be
    attached by whoever fills the record in.
    """

    def __init__(self, cost=None):
        """
        Args:
            cost: stored unchanged on the instance (defaults to ``None``).
        """
        # The sequence always starts out unset; only the cost is caller-supplied.
        self.cost, self.worker_task_sequence = cost, None
class WrapDataset(Dataset):
    """Map-style dataset whose items are converted once, up front.

    Every element of ``dataset`` is passed through ``solver.to_batch`` at
    construction time and the results are kept in a list.
    """

    def __init__(self, dataset, solver):
        """
        Args:
            dataset: iterable of raw problems.
            solver: object providing ``to_batch(problem)``.
        """
        self._dataset = list(map(solver.to_batch, dataset))

    def __len__(self):
        """Number of converted problems."""
        return len(self._dataset)

    def __getitem__(self, index):
        """Return the already-converted problem stored at ``index``."""
        return self._dataset[index]
class WrapIterator:
    """Iterator adapter that converts each raw problem as it is drawn.

    Every item pulled from the wrapped iterator is passed through
    ``solver.to_batch(item, False)``; the ``False`` is ``to_batch``'s
    ``pin_memory`` argument, so no pinning happens here.
    """

    def __init__(self, iterator, solver):
        """
        Args:
            iterator: iterator yielding raw problems.
            solver: object providing ``to_batch(problem, pin_memory)``.
        """
        self._iterator = iterator
        self._solver = solver

    def __iter__(self):
        # Required by the iterator protocol; without it ``for``/``list()``
        # reject this object even though ``next()`` works.
        return self

    def __next__(self):
        """Return the next converted problem; propagates ``StopIteration``."""
        p = next(self._iterator)
        return self._solver.to_batch(p, False)
class WrapIterableDataset(IterableDataset):
    """Iterable-style dataset that converts raw problems lazily.

    Nothing is converted at construction time; conversion happens item by
    item inside the ``WrapIterator`` returned from ``__iter__``.
    """

    def __init__(self, dataset, solver):
        """
        Args:
            dataset: iterable of raw problems.
            solver: solver handed on to ``WrapIterator``.
        """
        self._dataset, self._solver = dataset, solver

    def __iter__(self):
        # Each call opens a fresh iterator over the raw dataset.
        source = iter(self._dataset)
        return WrapIterator(source, self._solver)
class CyclicIterator:
    """Endless iterator: restarts the underlying iterable whenever it runs dry."""

    # Private marker returned by ``next`` when the current pass is exhausted.
    _EXHAUSTED = object()

    def __init__(self, iterable):
        """
        Args:
            iterable: re-iterable object; ``iter()`` is called on it once now
                and again each time a pass ends.
        """
        self._iterable = iterable
        self._iterator = iter(iterable)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next item, starting a new pass if the current one ended.

        If the freshly restarted pass is itself empty, ``StopIteration``
        propagates to the caller.
        """
        item = next(self._iterator, self._EXHAUSTED)
        if item is not self._EXHAUSTED:
            return item

        # Current pass is over: open a new one and take its first element.
        self._iterator = iter(self._iterable)
        return next(self._iterator)
class BufferedIterator:
    """Iterator that serves random items from a small replay buffer.

    Items drawn from the wrapped iterator are kept in ``_buffer``. A bounded
    FIFO queue of buffer indices decides which slot is overwritten next
    (oldest first). Every call to ``next`` returns a uniformly random
    buffered item, chosen with ``torch.randint``.
    """

    def __init__(self, iterator, size, reuse):
        """
        Args:
            iterator: source iterator of problems.
            size: capacity passed to ``queue.Queue``; per the standard library
                a value <= 0 means "unbounded", in which case the buffer just
                keeps growing.
            reuse: once the buffer is full, a new item is fetched only on
                calls whose running step count is a multiple of ``reuse``.

        Raises:
            ValueError: if ``reuse`` is 0 while the buffer is bounded; the
                modulo in ``__next__`` would otherwise fail with
                ``ZeroDivisionError`` as soon as the buffer filled up.
        """
        if size > 0 and reuse == 0:
            raise ValueError("reuse must be non-zero when the buffer is bounded")
        self._iterator = iterator
        self._reuse = reuse
        self._queue = queue.Queue(size)
        self._buffer = []
        self._iter_step = 0

    def __iter__(self):
        # Complete the iterator protocol (matches CyclicIterator).
        return self

    def __next__(self):
        """Optionally refresh one buffer slot, then return a random buffered item."""
        if not self._queue.full() or self._iter_step % self._reuse == 0:
            problem = next(self._iterator)
            if self._queue.full():
                # Buffer is at capacity: overwrite the oldest slot.
                index = self._queue.get()
                self._buffer[index] = problem
            else:
                # Still filling up: append a new slot.
                index = len(self._buffer)
                self._buffer.append(problem)
            self._queue.put(index)

        # Counted on every call, whether or not a new item was fetched.
        self._iter_step += 1
        index = torch.randint(0, len(self._buffer), (1,)).item()
        return self._buffer[index]
class Solver(object):
    """Creates, trains, loads and runs an ``Agent`` on one torch device.

    Attributes set by ``__init__``:
        device:  ``torch.device`` on which the agent and problem tensors live.
        nn_args: dict of network arguments; filled by ``parse_nn_args`` or
            replaced by the values stored in a checkpoint.
        agent:   current ``Agent``; ``None`` until trained, loaded or first used.
    """

    def __init__(self, device=None, nn_args=None):
        """
        Args:
            device: ``None`` (use ``get_default_device()``), the string
                ``'cuda'`` (default device, which must then be a CUDA device),
                or anything accepted by ``torch.device``.
            nn_args: optional dict of network arguments; a fresh dict is
                created when omitted.
        """
        if device is None:
            self.device = get_default_device()
        elif device == 'cuda':
            self.device = get_default_device()
            assert self.device.type == 'cuda', 'no cuda device available!'
        else:
            self.device = torch.device(device)

        if nn_args is None:
            nn_args = {}
        self.nn_args = nn_args
        self.agent = None

    def parse_nn_args(self, problem):
        """Delegate to the module-level ``parse_nn_args(problem, self.nn_args)``."""
        parse_nn_args(problem, self.nn_args)

    def new_agent(self):
        """Return a new ``Agent`` built from ``self.nn_args``."""
        return Agent(self.nn_args)

    def train(self, agent_filename, train_dataset, valid_dataset, **kwargs):
        """Wrap both datasets in DataLoaders and run ``do_train``.

        Args:
            agent_filename: checkpoint path (may be ``None``); see ``do_train``.
            train_dataset, valid_dataset: raw datasets, see ``wrap_dataset``.
            **kwargs: ``train_dataset_workers`` / ``train_dataset_buffers`` /
                ``valid_dataset_workers`` / ``valid_dataset_buffers`` are
                consumed here (defaults 1 / 2 / 1 / 2); everything else is
                forwarded to ``do_train``.
        """
        if dist.is_initialized():
            # Offset the seed per rank so ranks do not share a random stream.
            torch.manual_seed(torch.initial_seed() + dist.get_rank() * 20000)

        train_dataset_workers = kwargs.pop('train_dataset_workers', 1)
        train_dataset_buffers = kwargs.pop('train_dataset_buffers', 2)
        valid_dataset_workers = kwargs.pop('valid_dataset_workers', 1)
        valid_dataset_buffers = kwargs.pop('valid_dataset_buffers', 2)

        train_dataset = self.wrap_dataset(train_dataset, train_dataset_workers,
                                          train_dataset_buffers, torch.initial_seed() + 1)
        valid_dataset = self.wrap_dataset(valid_dataset, valid_dataset_workers,
                                          valid_dataset_buffers, torch.initial_seed() + 10001)

        if self.device.type == 'cuda':
            with torch.cuda.device(cuda_or_none(self.device)):
                self.do_train(agent_filename, train_dataset, valid_dataset, **kwargs)
        else:
            self.do_train(agent_filename, train_dataset, valid_dataset, **kwargs)

    def do_train(self, agent_filename, train_dataset, valid_dataset, reuse_buffer=0, reuse_times=1, on_policy=True,
                 advpow=1, batch_size=512, topk_size=1, init_lr=0.0001, sched_lr=(int(1e10),), gamma_lr=0.5,
                 warmup_steps=100, log_steps=-1, optim_steps=1, valid_steps=100, max_steps=int(1e10), memopt=1):
        """Training loop: sample solutions, weight log-probs by advantage, step Adam.

        Args:
            agent_filename: checkpoint path. If the file exists, training
                resumes from it; the checkpoint is rewritten at every
                validation. ``None`` disables loading and saving.
            train_dataset: iterable of batched problems; cycled endlessly.
            valid_dataset: iterable of batched problems; materialised as a list.
            reuse_buffer: if > 0, training problems are served through a
                ``BufferedIterator`` of this size.
            reuse_times: ``reuse`` argument of that ``BufferedIterator``.
            on_policy: if false, solutions are sampled with the best agent so
                far and then re-scored by the trained agent.
            advpow: exponent applied to the absolute normalised advantage.
            batch_size: second positional argument passed to the agent
                (halved while ``step < warmup_steps``).
            topk_size: if > 1, ``batch_size * topk_size`` solutions are
                sampled and only the lowest-cost ones per problem are kept.
            init_lr: initial Adam learning rate.
            sched_lr: ``MultiStepLR`` milestones. The scheduler is stepped
                once per validation (not per training step).
            gamma_lr: ``MultiStepLR`` decay factor.
            warmup_steps: number of initial steps run with the halved batch.
            log_steps: logging period; negative means "same as valid_steps".
            optim_steps: the optimizer steps only when
                ``step % optim_steps == 0``; gradients accumulate in between.
            valid_steps: validation / checkpoint period.
            max_steps: exclusive upper bound of the step counter.
            memopt: forwarded to the agent as ``memopt=``.

        NOTE(review): this file's indentation was lost; the nesting below was
        reconstructed from control flow and should be checked against the
        upstream source where marked.
        """
        for arg in inspect.getfullargspec(self.do_train)[0][1:]:
            if arg not in ('train_dataset', 'valid_dataset'):
                print("train_args: {} = {}".format(arg, locals()[arg]))

        if log_steps < 0:
            log_steps = valid_steps

        train_dataset = CyclicIterator(train_dataset)
        if reuse_buffer > 0:
            train_dataset = BufferedIterator(train_dataset, reuse_buffer, reuse_times)

        valid_dataset = list(valid_dataset)

        # Non-zero ranks wait here until rank 0 has loaded or written the
        # checkpoint (rank 0 reaches the matching barrier further down).
        if dist.is_initialized() and dist.get_rank() != 0:
            dist.barrier()

        if agent_filename is not None and os.path.exists(agent_filename):
            # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
            saved_state = torch.load(agent_filename, map_location='cpu')
            self.nn_args = saved_state['nn_args']
        else:
            saved_state = None
            self.parse_nn_args(valid_dataset[0])

        step = 0
        start_step = 0
        self.agent = self.new_agent().train()
        self.agent.to(self.device)
        self.print_nn_args()

        best_agent = copy.deepcopy(self.agent).eval()
        min_valid_cost = math.inf

        optimizer = torch.optim.Adam(self.agent.parameters(), lr=init_lr)
        scheduler = MultiStepLR(optimizer, milestones=sched_lr, gamma=gamma_lr)

        def do_save_state(rng_state, cuda_rng_state):
            # Writes the full training state; reads ``step`` from the enclosing scope.
            if agent_filename is not None:
                save_data = {'step': step, 'rng_state': rng_state}
                if cuda_rng_state is not None:
                    save_data['cuda_rng_state'] = cuda_rng_state
                save_data['nn_args'] = self.agent.nn_args_dict()
                save_data['agent_state'] = self.agent.state_dict()
                save_data['best_agent_state'] = best_agent.state_dict()
                save_data['optimizer_state'] = optimizer.state_dict()
                save_data['scheduler_state'] = scheduler.state_dict()
                torch.save(save_data, agent_filename)

        def valid_sched_save(step):
            # Validate, step the LR scheduler, save a checkpoint (rank 0 only)
            # and return the mean validation cost.
            if dist.is_initialized():
                # Sanity check: every rank must hold the same parameters as rank 0.
                params = parameters_to_vector(self.agent.parameters())
                params_clone = params.clone()
                dist.broadcast(params_clone, 0)
                assert torch.all(params == params_clone)

            # Capture RNG state *before* validation so a resumed run replays it.
            rng_state = torch.get_rng_state()
            cuda_rng_state = None
            if self.device.type == 'cuda':
                cuda_rng_state = torch.cuda.get_rng_state(self.device)

            print("{} - step={}, validate...".format(time.strftime("%Y-%m-%d %H:%M:%S"), step))
            sys.stdout.flush()

            if self.device.type == 'cuda':
                torch.cuda.synchronize(self.device)
            start_time = time.time()
            valid_result = self.validate(valid_dataset, batch_size)
            avg_cost1, avg_cost2, avg_feasible = valid_result
            if self.device.type == 'cuda':
                torch.cuda.synchronize(self.device)
            duration = time.time() - start_time

            if step > 0:
                scheduler.step()

            if not dist.is_initialized() or dist.get_rank() == 0:
                do_save_state(rng_state, cuda_rng_state)

            strftime = time.strftime("%Y-%m-%d %H:%M:%S")
            print("{} - step={}, cost=[{:.6g}, {:.6g}], feasible={:.0%}".format(
                strftime, step, avg_cost1, avg_cost2, avg_feasible))
            print("{} - step={}, min_valid_cost={:.6g}, time={:.3f}s".format(
                strftime, step, min(min_valid_cost, avg_cost2), duration))
            print("---------------------------------------------------------------------------------------")
            sys.stdout.flush()
            return avg_cost2

        if saved_state is not None:
            start_step = saved_state['step']
            if not dist.is_initialized() or dist.get_rank() == 0:
                torch.set_rng_state(saved_state['rng_state'])
                # Restore CUDA RNG only when this solver runs on CUDA *and* the
                # checkpoint has that state (it is saved only for CUDA runs).
                # The old ``torch.cuda.is_available()`` test failed on
                # checkpoints written by a CPU run.
                if self.device.type == 'cuda' and 'cuda_rng_state' in saved_state:
                    torch.cuda.set_rng_state(saved_state['cuda_rng_state'], self.device)
            # NOTE: training resumes from the *best* agent weights, not from
            # 'agent_state', although both are stored in the checkpoint.
            best_agent.load_state_dict(saved_state['best_agent_state'])
            self.agent.load_state_dict(saved_state['best_agent_state'])
            if 'optimizer_state' in saved_state:
                optimizer.load_state_dict(saved_state['optimizer_state'])
            if 'scheduler_state' in saved_state:
                scheduler.load_state_dict(saved_state['scheduler_state'])
        elif dist.is_initialized() and dist.get_rank() == 0:
            # No checkpoint yet: rank 0 writes one so the waiting ranks can load it.
            rng_state = torch.get_rng_state()
            cuda_rng_state = None
            if self.device.type == 'cuda':
                cuda_rng_state = torch.cuda.get_rng_state(self.device)
            do_save_state(rng_state, cuda_rng_state)

        if dist.is_initialized() and dist.get_rank() == 0:
            dist.barrier()

        # Defined up front so the log line cannot hit an unbound name when a
        # resumed run logs before its first optimizer step.
        grad_norm = math.nan

        for step in range(start_step, max_steps):
            if step % valid_steps == 0:
                valid_cost = valid_sched_save(step)
                if valid_cost < min_valid_cost:
                    best_agent.load_state_dict(self.agent.state_dict())
                    min_valid_cost = valid_cost

            start_time = time.time()

            # problem
            with torch.no_grad():
                problem = next(train_dataset)
                if step < warmup_steps:
                    batch_size_now = batch_size // 2
                else:
                    batch_size_now = batch_size
                problem = self.to_device(problem)

            if not on_policy:
                data_agent = best_agent
            else:
                data_agent = self.agent
            # NOTE(review): original nesting of this call is ambiguous; it is
            # harmless either way because best_agent is already in eval mode.
            data_agent.eval()

            # solution
            if topk_size > 1:
                with torch.no_grad():
                    batch_size_topk = batch_size_now * topk_size
                    env, logp = data_agent(problem, batch_size_topk)

                    cost = env.cost().sum(1).float()
                    solution = env.worker_task_sequence()

                    NP = problem.batch_size
                    NK = batch_size_now // NP
                    NS = solution.size(1)

                    # Keep the NK lowest-cost samples of every problem.
                    cost = cost.view(NP, -1)
                    cost, kidx = cost.topk(NK, 1, False, False)
                    cost = cost.view(-1)
                    kidx = kidx[:, :, None, None].expand(-1, -1, NS, 3)
                    solution = solution.view(NP, -1, NS, 3)
                    solution = solution.gather(1, kidx).view(-1, NS, 3)
            elif not on_policy:
                with torch.no_grad():
                    env, logp = data_agent(problem, batch_size_now)
                    cost = env.cost().sum(1).float()
                    solution = env.worker_task_sequence()
            else:
                # On-policy: sample with gradients enabled, in training mode.
                self.agent.train()
                env, logp = self.agent(problem, batch_size_now, memopt=memopt)
                cost = env.cost().sum(1).float()
                solution = env.worker_task_sequence()

            self.agent.train()

            # advantage
            with torch.no_grad():
                NP = problem.batch_size
                if topk_size > 1:
                    baseline = cost.view(NP, -1).max(1)[0]
                else:
                    baseline = cost.view(NP, -1).mean(1)
                baseline = repeat(baseline, cost.size(0) // NP)
                adv = (cost - baseline)[:, None]
                adv_norm = adv.norm()
                if adv_norm > 0:
                    adv = adv / adv_norm * adv.size(0)
                adv = adv.sign() * adv.abs().pow(advpow)

            # backward
            if topk_size > 1 or not on_policy:
                # Re-score the sampled solutions with the trained agent to get
                # log-probabilities that carry gradients.
                env, logp = self.agent(problem, batch_size_now, solution=solution, memopt=memopt)

            loss = adv * logp
            loss = loss.mean()
            loss.backward()

            if step % optim_steps == 0:
                if dist.is_initialized():
                    # Sum gradients across ranks through one flat vector.
                    params = filter(lambda a: a.grad is not None, self.agent.parameters())
                    grad_list = [param.grad for param in params]
                    grad_vector = parameters_to_vector(grad_list)
                    dist.all_reduce(grad_vector, op=dist.ReduceOp.SUM)
                    vector_to_parameters(grad_vector, grad_list)

                grad_norm = clip_grad_norm_(self.agent.parameters(), 1)
                optimizer.step()
                optimizer.zero_grad()

            if step % log_steps == 0:
                strftime = time.strftime("%Y-%m-%d %H:%M:%S")
                lr = optimizer.param_groups[0]['lr']
                duration = time.time() - start_time
                with torch.no_grad():
                    p = logp.to(torch.float64).sum(1).exp().mean()
                print("{} - step={}, grad={:.6g}, lr={:.6g}, p={:.6g}".format(
                    strftime, step, grad_norm, lr, p))
                print("{} - step={}, cost={:.6g}, time={:.3f}s".format(strftime, step, cost.mean(), duration))
                print("---------------------------------------------------------------------------------------")
                sys.stdout.flush()

        valid_sched_save(step)

    def solve(self, problem, greedy=False, batch_size=512):
        """Solve ``problem`` and return a ``Solution``; see ``do_solve``."""
        if self.device.type == 'cuda':
            with torch.cuda.device(cuda_or_none(self.device)):
                return self.do_solve(problem, greedy, batch_size)
        else:
            return self.do_solve(problem, greedy, batch_size)

    def do_solve(self, problem, greedy, batch_size):
        """Run the agent and pick the best sampled solution per problem.

        A feasible sample always beats an infeasible one; among equals the
        lowest cost wins. If ``self.agent`` is ``None`` a new (untrained)
        agent is created first.

        Args:
            problem: raw problem (must not yet have ``feats``).
            greedy: forwarded to the agent as third positional argument.
            batch_size: forwarded to the agent as second positional argument.

        Returns:
            ``Solution`` with ``cost``, ``feasible``, ``probability`` and
            ``worker_task_sequence``. For a batched problem these are tensors
            with one entry per problem; otherwise ``cost``/``feasible`` are
            Python scalars and the leading dimension is squeezed away.
        """
        isbatch = problem.isbatch
        problem = self.to_batch(problem)
        problem = self.to_device(problem)

        if self.agent is None:
            self.parse_nn_args(problem)
            self.agent = self.new_agent()
            self.agent.to(self.device)
        self.agent.eval()

        with torch.no_grad():
            env, prob = self.agent(problem, batch_size, greedy, problem.solution)

        NP = problem.batch_size
        NR = prob.size(0) // NP
        prob = prob.view(NP, NR, -1)
        cost = env.cost().sum(1).view(NP, NR)
        feasible = env.feasible().view(NP, NR)

        sequence = env.worker_task_sequence()
        worker_task_sequence = sequence.view([NP, NR] + list(sequence.size())[1:])

        p_index = torch.arange(NP, device=cost.device)

        # Penalise infeasible samples by more than the full cost range so that
        # every feasible sample ranks first, even when costs are negative
        # (the previous ``max + 1`` offset only worked for costs >= 0).
        base_cost = cost.max() - cost.min() + 1
        penalized = torch.where(feasible, cost, cost + base_cost)
        _, s_index = penalized.min(1)

        # Gather from the untouched tensors: reported costs carry no penalty.
        cost = cost[p_index, s_index]
        feasible = feasible[p_index, s_index]
        # ``prob`` is exponentiated, i.e. treated as log-probabilities.
        probability = prob[p_index, s_index].exp()
        worker_task_sequence = worker_task_sequence[p_index, s_index]

        if isbatch:
            solution = Solution(cost)
            solution.feasible = feasible
            solution.probability = probability
            solution.worker_task_sequence = worker_task_sequence
        else:
            solution = Solution(cost.item())
            solution.feasible = feasible.item()
            solution.probability = probability.squeeze(0)
            solution.worker_task_sequence = worker_task_sequence.squeeze(0)

        return solution

    def load_agent(self, filename, strict=True):
        """Load an agent checkpoint into ``self.agent``; see ``do_load_agent``."""
        if self.device.type == 'cuda':
            with torch.cuda.device(cuda_or_none(self.device)):
                self.do_load_agent(filename, strict)
        else:
            self.do_load_agent(filename, strict)

    def do_load_agent(self, filename, strict=True):
        """Replace ``nn_args`` and ``agent`` with the contents of a checkpoint.

        The weights stored under ``'best_agent_state'`` are loaded.

        Args:
            filename: checkpoint path written by training.
            strict: forwarded to ``load_state_dict``.
        """
        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        saved_state = torch.load(filename, map_location='cpu')
        self.nn_args = saved_state['nn_args']

        self.agent = self.new_agent()
        self.agent.to(self.device)
        self.agent.load_state_dict(saved_state['best_agent_state'], strict)
        self.print_nn_args()

    def to_batch(self, problem, pin_memory=True):
        """Convert a raw problem into a batched ``Problem`` of CPU tensors.

        Features are recognised by attribute name (``solution``,
        ``worker_task_*``, ``worker_*``, ``task_*``, ``*_matrix``) and
        size-checked; any other numpy array is converted as-is. Non-batched
        input gets a leading batch dimension of size 1.

        Args:
            problem: raw problem without ``feats``.
            pin_memory: pin the resulting tensors when the solver device is CUDA.

        Returns:
            A new ``Problem`` with ``feats``, ``batch_size``, ``worker_num``,
            ``task_num`` and the copied definition attributes.

        Raises:
            AssertionError: if ``problem`` already has ``feats`` or a feature
                has an unexpected size.
        """
        assert not hasattr(problem, 'feats')

        NW = 1
        NT = 1
        NP = 1
        isbatch = problem.isbatch
        # Infer worker / task / batch counts from the feature lengths.
        for k, v in problem.__dict__.items():
            if k.startswith("worker_"):
                NW = len(v[0]) if isbatch else len(v)
            elif k.startswith("task_"):
                NP = len(v) if isbatch else 1
                NT = len(v[0]) if isbatch else len(v)
        NWW = NW * 2

        new_problem = Problem(True)
        new_problem.feats = {}
        new_problem.device = 'cpu'

        new_problem.batch_size = NP
        new_problem.worker_num = NW
        new_problem.task_num = NT
        new_problem.features = problem.features

        # Exact type test on purpose: subclasses take the second branch.
        if type(self) == Solver:
            new_problem.variables = problem.variables
            new_problem.constraint = problem.constraint
            new_problem.objective = problem.objective
            new_problem.environment = problem.environment
        else:
            new_problem.variables = []
            new_problem.constraints = problem.constraints
            new_problem.oa_estimate_tasks = problem.oa_estimate_tasks
            new_problem.oa_multiple_steps = problem.oa_multiple_steps

        # A matrix may be given either per (2*NW + NT) node or per (NW + NT) node.
        edge_size_list = ((NWW + NT, NWW + NT), (NW + NT, NW + NT))

        def check_size(f, k, v):
            assert f, "size error, feature: {}, size: {}".format(k, tuple(v.size()))

        for k, v in problem.__dict__.items():
            if k == 'solution' and v is not None:
                v = to_tensor(k, v, isbatch)
                check_size(v.dim() == 3 and v.size(-1) == 3, k, v)
            elif k.startswith("worker_task_"):
                v = to_tensor(k, v, isbatch)
                check_size(v.dim() in (3, 4) and v.size()[1:3] == (NW, NT), k, v)
            elif k.startswith("worker_"):
                v = to_tensor(k, v, isbatch)
                check_size(v.dim() in (2, 3) and v.size(1) == NW, k, v)
            elif k.startswith("task_"):
                v = to_tensor(k, v, isbatch)
                check_size(v.dim() in (2, 3) and v.size(1) == NT, k, v)
            elif k.endswith("_matrix"):
                v = to_tensor(k, v, isbatch)
                check_size(v.dim() in (3, 4) and v.size()[1:3] in edge_size_list, k, v)
                if v.size()[1:3] == (NW + NT, NW + NT):
                    # Expand to the (2*NW + NT) layout by listing the worker
                    # rows/columns twice, followed by the task rows/columns.
                    worker_index = torch.arange(NW)
                    task_index = torch.arange(NT) + NW
                    index = torch.cat([worker_index, worker_index, task_index])
                    index1 = index[:, None]
                    index2 = index[None, :]
                    v = v[:, index1, index2]
            elif isinstance(v, np.ndarray):
                v = torch.tensor(v)

            if isinstance(v, torch.Tensor):
                new_problem.feats[k] = v

        if pin_memory and self.device.type == 'cuda':
            new_problem.pin_memory()

        return new_problem

    def to_device(self, problem):
        """Return a shallow copy of ``problem`` whose ``feats`` live on ``self.device``.

        The input problem and its ``feats`` dict are left untouched.
        """
        assert hasattr(problem, 'feats')

        new_problem = copy.copy(problem)
        new_problem.device = self.device
        new_problem.feats = {}

        non_blocking = self.device.type == 'cuda'
        for k, v in problem.feats.items():
            new_problem.feats[k] = v.to(self.device, non_blocking=non_blocking)

        return new_problem

    def validate(self, problem_list, batch_size):
        """Run ``do_validate`` in eval mode without gradients.

        The agent is switched back to training mode afterwards, even if
        validation raises.

        Returns:
            ``(avg_cost1, avg_cost2, avg_feasible)`` as produced by ``do_validate``.
        """
        self.agent.eval()
        try:
            with torch.no_grad():
                return self.do_validate(problem_list, batch_size)
        finally:
            self.agent.train()

    def do_validate(self, problem_list, batch_size):
        """Average validation metrics over all problems (and all ranks).

        Returns:
            Tuple ``(avg_cost1, avg_cost2, avg_feasible)``:
            mean over problems of the minimum sampled cost, mean over problems
            of the mean sampled cost, and the fraction of problems with at
            least one feasible sample.
        """
        total_cost1 = 0
        total_cost2 = 0
        total_feasible = 0
        total_problem = 0
        for problem in problem_list:
            problem = self.to_device(problem)
            env, _ = self.agent(problem, batch_size)
            NP = problem.batch_size
            cost = env.cost().sum(1).view(NP, -1)
            cost1, _ = cost.min(1)
            cost2 = cost.mean(1)
            feasible = env.feasible().view(NP, -1)
            feasible = torch.any(feasible, 1)

            total_cost1 += cost1.sum().item()
            total_cost2 += cost2.sum().item()
            total_feasible += feasible.int().sum().item()
            total_problem += NP

        if dist.is_initialized():
            # float64 keeps the summed counts and costs exact enough; the
            # default float32 tensor loses precision on large totals.
            data = [total_cost1, total_cost2, total_feasible, total_problem]
            data = torch.tensor(data, device=self.device, dtype=torch.float64)
            dist.all_reduce(data, op=dist.ReduceOp.SUM)
            total_cost1, total_cost2, total_feasible, total_problem = data.tolist()

        avg_cost1 = total_cost1 / total_problem
        avg_cost2 = total_cost2 / total_problem
        avg_feasible = total_feasible / total_problem

        return avg_cost1, avg_cost2, avg_feasible

    def wrap_dataset(self, dataset, workers, buffers, seed):
        """Wrap a raw dataset in a ``DataLoader`` that yields batched problems.

        Args:
            dataset: an ``IterableDataset`` (converted lazily) or any other
                iterable of problems (converted eagerly, then shuffled).
            workers: ``num_workers`` for the iterable case.
            buffers: ``prefetch_factor`` for the iterable case; only applied
                when ``workers > 0`` because ``DataLoader`` rejects a prefetch
                factor without worker processes.
            seed: base seed; worker ``i`` is seeded with ``seed + i``.

        Returns:
            ``DataLoader`` with automatic batching disabled (``batch_size=None``).
        """
        # Pin only when the data is headed for a CUDA device, matching to_batch.
        pin_memory = self.device.type == 'cuda'

        if isinstance(dataset, IterableDataset):
            dataset = WrapIterableDataset(dataset, self)
            loader_args = dict(batch_size=None, pin_memory=pin_memory, num_workers=workers,
                               worker_init_fn=lambda worker_id: torch.manual_seed(seed + worker_id))
            if workers > 0:
                loader_args['prefetch_factor'] = buffers
            return DataLoader(dataset, **loader_args)

        def build_loader():
            wrapped = WrapDataset(dataset, self)
            return DataLoader(wrapped, batch_size=None, pin_memory=pin_memory, shuffle=True)

        if self.device.type == 'cuda':
            # Eager conversion may pin memory, so run it with the solver's
            # device selected as the current CUDA device.
            with torch.cuda.device(cuda_or_none(self.device)):
                return build_loader()
        return build_loader()

    def print_nn_args(self):
        """Print every ``nn_args`` entry whose value is exactly int, float, str or bool."""
        for key, value in self.nn_args.items():
            # Exact type match on purpose: subclasses of these types are skipped.
            if type(value) in [int, float, str, bool]:
                print("nn_args: {} = {}".format(key, value))
        sys.stdout.flush()
def to_tensor(key, value, isbatch):
    """Return ``value`` as a CPU tensor that has a leading batch dimension.

    Args:
        key: feature name; accepted for the caller's convenience, not used here.
        value: a tensor (moved to CPU) or anything ``torch.tensor`` accepts.
        isbatch: if false, a batch dimension of size 1 is prepended.
    """
    on_cpu = value.to('cpu') if isinstance(value, torch.Tensor) else torch.tensor(value, device='cpu')
    return on_cpu if isbatch else on_cpu[None]
def cuda_or_none(device):
    """Return ``device`` itself if it is a CUDA device, otherwise ``None``."""
    if device.type == 'cuda':
        return device
    return None