|
import torch |
|
|
|
from torch import nn |
|
from collections import OrderedDict |
|
from torch.utils.checkpoint import checkpoint |
|
from .feature import * |
|
from .pyenv import PyEnv |
|
from .encode import Encode |
|
from .decode import Decode |
|
|
|
|
|
class Agent(nn.Module):
    """Encoder-decoder policy network.

    Encodes a problem instance once with ``Encode``, then autoregressively
    decodes a solution by repeatedly querying ``Decode`` and stepping an
    environment with each chosen action.
    """

    def __init__(self, nn_args):
        """Build the agent.

        Args:
            nn_args: dict of network hyper-parameters. Must contain
                'variable_dim' (mapping of variable-feature name -> width);
                'decode_steps_ratio' and 'decode_logit_clips' are read with
                defaults and written back via ``setdefault``.
        """
        super().__init__()

        self.nn_args = nn_args
        # Total width of per-step variable features; 0 disables them in decode.
        self.vars_dim = sum(nn_args['variable_dim'].values())
        # Fraction of the task count that must be decoded before the
        # early-termination check in interact() starts firing.
        self.steps_ratio = nn_args.setdefault('decode_steps_ratio', 1.0)

        # 'decode_logit_clips' is either a single number or a comma-separated
        # string giving one clip value per decode step; the last entry is
        # reused for all later steps (see the min() indexing in interact()).
        logit_clips = nn_args.setdefault('decode_logit_clips', 10.0)
        if isinstance(logit_clips, str):
            self.logit_clips = [float(v) for v in logit_clips.split(',')]
        else:
            self.logit_clips = [float(logit_clips)]

        self.nn_encode = Encode(nn_args)
        self.nn_decode = Decode(nn_args)

    def nn_args_dict(self):
        """Return the nn_args dict (including defaults filled in by __init__)."""
        return self.nn_args

    def forward(self, problem, batch_size, greedy=False, solution=None, memopt=0):
        """Encode `problem`, then roll out (or replay) a solution.

        Args:
            problem: problem instance exposing feats/batch_size/worker_num/
                task_num/device/environment.
            batch_size: total rollouts; must be a multiple of problem.batch_size.
            greedy: pick actions greedily instead of sampling (ignored when
                `solution` is given).
            solution: optional precomputed actions to replay (teacher forcing).
            memopt: memory-optimization level; > 3 enables checkpointing.

        Returns:
            (env, log_probs) where log_probs is stacked per decode step.
        """
        X, K, V = self.nn_encode(problem.feats, problem.batch_size,
                                 problem.worker_num, problem.task_num, memopt)

        return self.interact(problem, X, K, V, batch_size, greedy, solution, memopt)

    def interact(self, problem, X, K, V, batch_size, greedy, solution, memopt):
        """Run the decode loop against the environment.

        When `solution` is not None its (type, index) actions are replayed;
        otherwise actions are produced step by step until the environment
        reports completion.
        """
        NP = problem.batch_size
        NW = problem.worker_num
        NT = problem.task_num

        # Each problem instance is expanded into `sample_num` parallel rollouts.
        sample_num = batch_size // NP
        assert sample_num > 0 and batch_size % NP == 0

        MyEnv = problem.environment
        if MyEnv is None:
            env = PyEnv(problem, batch_size, sample_num, self.nn_args)
        else:
            env = MyEnv(str(problem.device), problem.feats, batch_size,
                        sample_num, problem.worker_num, problem.task_num)

        # Decoder recurrent inputs, zero-initialized with X's dtype/device.
        query = X.new_zeros(batch_size, X.size(-1))
        state1 = X.new_zeros(batch_size, X.size(-1))
        state2 = X.new_zeros(batch_size, X.size(-1))

        p_list = []
        NULL = X.new_ones(0)  # empty placeholder when variable features are disabled
        # Maps each rollout row back to its problem-instance row in X.
        p_index = torch.div(torch.arange(batch_size, device=X.device), sample_num, rounding_mode='trunc')
        if solution is not None:
            # Replay: (batch, steps, >=2) -> per-step (type, index) pairs.
            solution = solution[:, :, 0:2].to(torch.int64).permute(1, 0, 2)
            assert torch.all(solution >= 0) and solution.size(1) == batch_size
            # Flatten (type, index) into a single node id; offsets presumably
            # mirror the encoder's node layout (NW, NW, NT segments) — TODO confirm.
            offset = torch.tensor([0, NW, NW + NW, NW + NW + NT], device=X.device)
            chosen_list = solution[:, :, 1] + offset[solution[:, :, 0]]

            mode = 0  # replay the provided actions (teacher forcing)
            sample_p = torch.rand(batch_size, device=X.device)
            for chosen in chosen_list:
                env_time = env.time()
                # Per-step logit clip; the last configured value is reused.
                clip = self.logit_clips[min(env_time, len(self.logit_clips) - 1)]
                varfeat = env.make_feat() if self.vars_dim > 0 else NULL
                state1, state2, chosen_p = self.decode(X, K, V, query, state1, state2,
                                                       varfeat, env.mask(), chosen, sample_p, clip, mode, memopt)
                query = X[p_index, chosen]
                p_list.append(chosen_p)
                env.step(chosen)

            assert env.all_finished(), 'not all finished!'
        else:
            mode = 1 if greedy else 2  # 1: greedy; 2: stochastic (uses sample_p)
            min_env_time = int(self.steps_ratio * NT)
            # Pre-drawn randomness, reused cyclically across decode steps.
            R = torch.rand(NT * 2, batch_size, device=X.device)
            while True:
                env_time = env.time()
                # Termination is only checked every 3rd step past the minimum
                # step budget, so the (possibly costly) all_finished() query
                # is not issued every iteration.
                if env_time > min_env_time and env_time % 3 == 0 and env.all_finished():
                    break

                clip = self.logit_clips[min(env_time, len(self.logit_clips) - 1)]
                sample_p = R[env_time % R.size(0)]
                chosen = X.new_empty(batch_size, dtype=torch.int64)  # filled in-place by the decoder
                varfeat = env.make_feat() if self.vars_dim > 0 else NULL
                state1, state2, chosen_p = self.decode(X, K, V, query, state1, state2,
                                                       varfeat, env.mask(), chosen, sample_p, clip, mode, memopt)
                query = X[p_index, chosen]
                p_list.append(chosen_p)
                env.step(chosen)

        env.finalize()
        return env, torch.stack(p_list, 1)

    def decode(self, X, K, V, query, state1, state2, varfeat, mask, chosen, sample_p, clip, mode, memopt):
        """Run one decode step, under gradient checkpointing when training
        with memopt > 3 (recomputes activations on backward to save memory)."""
        run_fn = self.decode_fn(clip, mode, memopt)
        if self.training and memopt > 3:
            return checkpoint(run_fn, X, K, V, query, state1, state2, varfeat, mask, chosen, sample_p)
        else:
            return run_fn(X, K, V, query, state1, state2, varfeat, mask, chosen, sample_p)

    def decode_fn(self, clip, mode, memopt):
        """Bind the non-tensor decode arguments into a closure so only tensor
        arguments flow through torch.utils.checkpoint."""
        # memopt > 3 selects checkpointing at the decode() level; the inner
        # decoder then runs without its own memory optimization.
        memopt = 0 if memopt > 3 else memopt

        def run_fn(X, K, V, query, state1, state2, varfeat, mask, chosen, sample_p):
            return self.nn_decode(X, K, V, query, state1, state2,
                                  varfeat, mask, chosen, sample_p, clip, mode, memopt)

        return run_fn
|
|
|
|
|
def parse_nn_args(problem, nn_args):
    """Inspect the problem's features and variables, and record the input
    width the network needs for each one into `nn_args`.

    Fills 'worker_dim', 'task_dim', 'edge_dim', 'variable_dim',
    'embed_dict' and 'feature_dict' on `nn_args` and returns it.
    """
    worker_dim = OrderedDict()
    task_dim = OrderedDict()
    edge_dim = OrderedDict()
    variable_dim = OrderedDict()
    embed_dict = OrderedDict()

    def assign_dim(name, key, dim):
        # Route a feature width to the worker/task/edge table by name prefix.
        if name.startswith("worker_task_"):
            edge_dim[key] = dim
        elif name.startswith("worker_"):
            worker_dim[key] = dim
        elif name.startswith("task_"):
            task_dim[key] = dim
        elif name.endswith("_matrix"):
            edge_dim[key] = dim
        else:
            raise Exception("attribute can't be feature: {}".format(key))

    feature_dict = make_feat_dict(problem)
    # Instantiate each variable once (batch 1) just to probe feature shapes.
    variables = [var(problem, problem.batch_size, 1) for var in problem.variables]
    variable_dict = {var.name: var for var in variables}
    for key, feat in feature_dict.items():
        if isinstance(feat, VariableFeature):
            var = variable_dict[feat.name]
            assert hasattr(var, 'make_feat'), \
                "{} cann't be variable feature, name:{}".format(type(var).__name__, key)
            probe = var.make_feat()
            # 2-D probe means a scalar per entry; otherwise take the last dim.
            variable_dim[key] = 1 if probe.dim() == 2 else probe.size(-1)
        elif isinstance(feat, SparseLocalFeature):
            edge_dim[key] = 1
            assign_dim(feat.value, key, 1)
        elif isinstance(feat, LocalFeature):
            edge_dim[key] = 1
            assign_dim(feat.name, key, 1)
        elif isinstance(feat, LocalCategory):
            edge_dim[key] = 1
        elif isinstance(feat, GlobalCategory):
            # Categories are embedded to the encoder's hidden width.
            assign_dim(feat.name, key, nn_args.setdefault('encode_hidden_dim', 128))
            embed_dict[key] = feat.size
        elif isinstance(feat, ContinuousFeature):
            values = problem.feats[key]
            # Edge-style tensors carry one more leading dim than node tensors.
            if key.startswith("worker_task_") or key.endswith("_matrix"):
                simple_dim = 3
            else:
                simple_dim = 2

            assign_dim(feat.name, key,
                       1 if values.dim() == simple_dim else values.size(-1))
        else:
            raise Exception("unsupported feature type: {}".format(type(feat)))

    nn_args['worker_dim'] = worker_dim
    nn_args['task_dim'] = task_dim
    nn_args['edge_dim'] = edge_dim
    nn_args['variable_dim'] = variable_dim
    nn_args['embed_dict'] = embed_dict
    nn_args['feature_dict'] = feature_dict
    return nn_args
|
|
|
|
|
def make_feat_dict(problem):
    """Build an ordered mapping of feature key -> feature for `problem`.

    Keys: 'var:<name>' for variable features, '<index>:<value>' for sparse
    local features, and the feature's own name otherwise.

    Raises:
        Exception: if two distinct (non-equal) features claim the same key.
    """
    feature_dict = OrderedDict()

    def add(k, f):
        _f = feature_dict.get(k)
        if _f is None or _f == f:
            feature_dict[k] = f
        else:
            # Bug fix: this message was previously built as a bare expression
            # and never raised, so conflicting features were silently dropped.
            raise Exception(
                "duplicated feature, name: {}, feature1: {}, feature2: {}".format(k, _f, f))

    for f in problem.features:
        if isinstance(f, VariableFeature):
            add(':'.join(['var', f.name]), f)
        elif isinstance(f, SparseLocalFeature):
            add(':'.join([f.index, f.value]), f)
        else:
            add(f.name, f)

    return feature_dict
|
|