|
import math |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from torch import nn |
|
from torch.utils.checkpoint import checkpoint |
|
from .norm import Norm1D, Norm2D |
|
from .dense import Dense |
|
from .utils import repeat |
|
from .feature import * |
|
|
|
|
|
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with an optional additive per-head bias.

    The Q/K/V projections are stored as stacked per-head weight tensors of
    shape (heads, hidden_dim, hidden_dim // heads), so one broadcasted
    ``matmul`` projects the input for every head at once.

    Args:
        heads: number of attention heads.
        hidden_dim: model width; must be divisible by ``heads``.
    """

    def __init__(self, heads, hidden_dim):
        super(MultiHeadAttention, self).__init__()
        assert hidden_dim % heads == 0

        per_head = hidden_dim // heads
        self.heads = heads
        # Scaled dot-product factor 1 / sqrt(d_head).
        self.alpha = 1 / math.sqrt(per_head)

        # Registration order (Q, K, V, O) fixes both the state_dict layout and
        # the order in which the initialiser below draws random numbers.
        self.nn_Q = nn.Parameter(torch.Tensor(heads, hidden_dim, per_head))
        self.nn_K = nn.Parameter(torch.Tensor(heads, hidden_dim, per_head))
        self.nn_V = nn.Parameter(torch.Tensor(heads, hidden_dim, per_head))
        self.nn_O = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))

        # Uniform init in +-1/sqrt(fan) where fan is the last dimension.
        for weight in self.parameters():
            bound = 1. / math.sqrt(weight.size(-1))
            weight.data.uniform_(-bound, bound)

    def forward(self, x, edge):
        """Attend over the item axis of ``x``.

        Args:
            x: tensor of shape (batch, items, hidden_dim).
            edge: optional additive attention bias; it is viewed as
                (heads * batch, items, items) and added to the scaled scores
                before the softmax. ``None`` means no bias.

        Returns:
            Tensor of shape (batch, items, hidden_dim).
        """
        batch, items, width = x.size()
        flat = x.reshape(-1, width)

        def project(weight):
            # (rows, width) @ (heads, width, d) -> (heads, rows, d),
            # then split rows back into (batch, items).
            return torch.matmul(flat, weight).view(self.heads, batch, items, -1)

        query = project(self.nn_Q)
        key = project(self.nn_K)
        value = project(self.nn_V)

        # Fold heads into the batch axis so bmm-style kernels can be used.
        stacked = self.heads * batch
        query = query.view(stacked, items, -1)
        key_t = key.view(stacked, items, -1).transpose(1, 2)

        if edge is None:
            bias = query.new_zeros(stacked, items, items)
            scores = bias.baddbmm_(query, key_t, alpha=self.alpha)
        else:
            bias = edge.view(stacked, items, items)
            # scores = bias + alpha * (query @ key_t)
            scores = bias.baddbmm(query, key_t, alpha=self.alpha)

        weights = F.softmax(scores.view(self.heads, batch, items, items), dim=-1)

        # (heads, batch, items, d) -> (batch, items, heads, d) -> concat heads.
        mixed = torch.matmul(weights, value).permute(1, 2, 0, 3)
        mixed = mixed.reshape(batch, items, -1)
        return torch.matmul(mixed, self.nn_O)
|
|
|
|
|
class Encode(nn.Module):
    """Attention encoder over worker and task items.

    Builds one embedding per item. Each worker contributes two items (a
    "start" and an "end" projection of the same features) and each task one.
    When edge features are configured, a per-head pairwise bias is added to
    the attention scores. The encoder then runs ``encode_layers`` attention +
    feed-forward layers, appends a learned finish token and, when
    ``decode_atten_heads > 0``, precomputes per-head K/V projections of the
    result.

    Args:
        nn_args: configuration dict. Required keys: ``worker_dim``,
            ``task_dim``, ``edge_dim`` (feature name -> feature width),
            ``embed_dict`` (feature name -> number of embeddings) and
            ``feature_dict`` (feature name -> feature descriptor). Optional
            keys are filled in place with ``setdefault``, so the caller's dict
            is mutated: ``encode_layers`` (3), ``encode_atten_heads`` (8),
            ``encode_norm`` ('instance'), ``encode_hidden_dim`` (128),
            ``decode_hidden_dim`` (128), ``decode_atten_heads`` (0).
    """

    def __init__(self, nn_args):
        super(Encode, self).__init__()

        self.nn_args = nn_args
        self.worker_dim = nn_args['worker_dim']
        self.task_dim = nn_args['task_dim']
        self.edge_dim = nn_args['edge_dim']

        self.embed_dict = nn_args['embed_dict']
        self.feature_dict = nn_args['feature_dict']

        layers = nn_args.setdefault('encode_layers', 3)
        heads = nn_args.setdefault('encode_atten_heads', 8)
        norm = nn_args.setdefault('encode_norm', 'instance')
        hidden_dim = nn_args.setdefault('encode_hidden_dim', 128)
        output_dim = nn_args.setdefault('decode_hidden_dim', 128)
        output_heads = nn_args.setdefault('decode_atten_heads', 0)

        self.heads = heads
        self.layers = layers

        # max(1, ...) because an entity kind with no features is fed a constant
        # column of ones of width 1 (see encode_worker / encode_task).
        worker_dim = max(1, sum(self.worker_dim.values()))
        task_dim = max(1, sum(self.task_dim.values()))

        self.nn_dense_worker_start = Dense(worker_dim, hidden_dim)
        self.nn_dense_worker_end = Dense(worker_dim, hidden_dim)
        self.nn_dense_task = Dense(task_dim, hidden_dim)

        self.nn_norm_worker_task = Norm1D(hidden_dim, norm, True)

        if len(self.edge_dim) > 0:
            edge_dim = sum(self.edge_dim.values())
            # Edge features are projected to one scalar bias per attention head.
            self.nn_dense_edge = Dense(edge_dim, heads)
            self.nn_norm_edge = Norm2D(heads, norm, True)

        nn_embed_dict = {}
        for k, v in self.embed_dict.items():
            nn_embed_dict[k] = nn.Embedding(v, hidden_dim)
        self.nn_embed_dict = nn.ModuleDict(nn_embed_dict)

        self.nn_attens = nn.ModuleList()
        self.nn_denses = nn.ModuleList()
        self.nn_norms1 = nn.ModuleList()
        self.nn_norms2 = nn.ModuleList()
        for i in range(layers):
            self.nn_attens.append(MultiHeadAttention(heads, hidden_dim))
            self.nn_denses.append(nn.Sequential(
                Dense(hidden_dim, hidden_dim * 4),
                Dense(hidden_dim * 4, hidden_dim, act='relu'),
            ))
            self.nn_norms1.append(Norm1D(hidden_dim, norm, True))
            self.nn_norms2.append(Norm1D(hidden_dim, norm, True))

        # Learned token appended after the encoder layers (see forward).
        self.nn_finish = nn.Parameter(torch.Tensor(1, 1, hidden_dim))

        if output_dim != hidden_dim:
            self.nn_X = nn.Parameter(torch.Tensor(hidden_dim, output_dim))
        else:
            self.nn_X = None

        if output_heads > 0:
            assert output_dim % output_heads == 0
            head_dim = output_dim // output_heads
            # K/V are applied in forward to X *after* the optional nn_X
            # projection, so their input width is output_dim and they carry
            # output_heads (decoder) heads, not the encoder's hidden_dim/heads.
            self.nn_K = nn.Parameter(torch.Tensor(output_heads, output_dim, head_dim))
            self.nn_V = nn.Parameter(torch.Tensor(output_heads, output_dim, head_dim))
        else:
            self.nn_K = None
            self.nn_V = None

        # self.parameters() recurses into every submodule, so this uniform
        # init overrides the defaults of the Dense/Norm/Embedding children too.
        for param in self.parameters():
            stdv = 1 / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, problem, batch_size, worker_num, task_num, memopt=0):
        """Encode a batch of problems.

        Args:
            problem: mapping from feature name to tensor.
            batch_size, worker_num, task_num: sizes used for the featureless
                fallbacks and for laying out edge features.
            memopt: memory-optimisation level. While training, 6 checkpoints
                each attention sub-layer and values above 6 checkpoint whole
                encoder layers.

        Returns:
            ``(X, K, V)``. ``X`` holds the items in the order
            [worker_start, worker_end, task] followed by the finish token on
            axis 1, projected with ``nn_X`` when ``decode_hidden_dim`` differs
            from ``encode_hidden_dim``. ``K``/``V`` are per-head projections of
            ``X`` viewed as (decode_atten_heads, batch, items, head_dim), or
            empty tensors when ``decode_atten_heads`` is 0.
        """
        worker_start, worker_end = self.encode_worker(problem, batch_size, worker_num)
        task = self.encode_task(problem, batch_size, task_num)
        X = torch.cat([worker_start, worker_end, task], 1)
        X = self.nn_norm_worker_task(X)

        if len(self.edge_dim) > 0:
            edge = self.encode_edge(problem, batch_size, worker_num, task_num)
            edge = self.nn_norm_edge(edge)
            # Move the per-head axis first so each attention layer can view the
            # bias as (heads * batch, items, items); assumes the last axis of
            # the dense output is the head axis -- TODO confirm against Dense.
            edge = edge.permute(3, 0, 1, 2).contiguous()
        else:
            edge = None

        for i in range(self.layers):
            X = self.encode_layer(X, edge, i, memopt)

        finish = repeat(self.nn_finish, X.size(0))
        X = torch.cat([X, finish], 1)
        if self.nn_X is not None:
            X = torch.matmul(X, self.nn_X)

        if self.nn_K is not None:
            batch_size, item_num, hidden_dim = X.size()
            # Head count comes from the parameter itself so it always matches
            # the decoder heads the projection was built with.
            size = (self.nn_K.size(0), batch_size, item_num, -1)
            X2 = X.reshape(-1, hidden_dim)
            K = torch.matmul(X2, self.nn_K).view(size)
            V = torch.matmul(X2, self.nn_V).view(size)
        else:
            # Empty placeholders keep the return signature uniform; allocate
            # them like X so they live on the same device.
            K = X.new_ones(0)
            V = X.new_ones(0)
        return X, K, V

    def encode_layer(self, X, edge, i, memopt):
        """Run encoder layer ``i``, checkpointing the whole layer when
        training with ``memopt > 6``."""
        run_fn = self.encode_layer_fn(i, memopt)
        if self.training and memopt > 6:
            return checkpoint(run_fn, X, edge)
        else:
            return run_fn(X, edge)

    def encode_layer_fn(self, i, memopt):
        """Return a closure computing encoder layer ``i``.

        The layer is post-norm: residual attention -> norm, then residual
        feed-forward -> norm. With ``memopt == 6`` during training only the
        attention call is checkpointed.
        """
        def run_fn(X, edge):
            if self.training and memopt == 6:
                X = X + checkpoint(self.nn_attens[i], X, edge)
            else:
                X = X + self.nn_attens[i](X, edge)
            X = self.nn_norms1[i](X)

            X = X + self.nn_denses[i](X)
            X = self.nn_norms2[i](X)
            return X

        return run_fn

    def encode_worker(self, problem, batch_size, worker_num):
        """Embed worker features and return (start, end) projections.

        Each configured worker feature is read from ``problem``: categorical
        features go through their embedding, continuous ones are used as-is.
        2-D values gain a trailing unit axis, and the last axis must equal the
        configured width. With no worker features a constant ones tensor of
        shape (batch_size, worker_num, 1) is used instead.
        """
        feature_list = []
        for k, dim in self.worker_dim.items():
            f = self.feature_dict.get(k)
            if isinstance(f, GlobalCategory):
                v = problem[f.name]
                v = self.nn_embed_dict[k](v.long())
            elif isinstance(f, ContinuousFeature):
                v = problem[f.name]
            else:
                raise Exception("unsupported feature type: {}".format(type(f)))

            if v.dim() == 2:
                v = v[:, :, None]

            assert dim == v.size(-1), \
                "feature dim error, feature: {}, expected: {}, actual: {}".format(k, dim, v.size(-1))

            feature_list.append(v.float())

        if feature_list:
            x = torch.cat(feature_list, 2)
        else:
            x = self.nn_finish.new_ones(batch_size, worker_num, 1)
        return self.nn_dense_worker_start(x), self.nn_dense_worker_end(x)

    def encode_task(self, problem, batch_size, task_num):
        """Embed task features and project them to ``encode_hidden_dim``.

        Sparse/local (3-D) features are reduced over axis 2 after clamping
        values to [0, 1]. Categorical features are embedded and continuous
        ones are used as-is. With no task features a constant ones tensor of
        shape (batch_size, task_num, 1) is used instead.
        """
        feature_list = []
        for k, dim in self.task_dim.items():
            f = self.feature_dict.get(k)
            # Order matters: the more specific descriptor types are tested
            # first in case they subclass the later ones.
            if isinstance(f, SparseLocalFeature):
                v = problem[f.value]
                assert v.dim() == 3, \
                    "sparse local feature's dimension must 2, feature:{}".format(k)
                v = v.clamp(0, 1).sum(2, dtype=v.dtype)
            elif isinstance(f, GlobalCategory):
                v = problem[f.name]
                v = self.nn_embed_dict[k](v.long())
            elif isinstance(f, LocalFeature):
                v = problem[f.name]
                assert v.dim() == 3, \
                    "local feature's dimension must 2, feature:{}".format(k)
                v = v.clamp(0, 1).sum(2, dtype=v.dtype)
            elif isinstance(f, ContinuousFeature):
                v = problem[f.name]
            else:
                raise Exception("unsupported feature type: {}".format(type(f)))

            if v.dim() == 2:
                v = v[:, :, None]

            assert dim == v.size(-1), \
                "feature dim error, feature: {}, expected: {}, actual: {}".format(k, dim, v.size(-1))

            feature_list.append(v.float())

        if feature_list:
            x = torch.cat(feature_list, 2)
        else:
            x = self.nn_finish.new_ones(batch_size, task_num, 1)
        return self.nn_dense_task(x)

    def encode_edge(self, problem, batch_size, worker_num, task_num):
        """Build pairwise edge features over all items and project them to one
        value per attention head.

        Items are laid out as [worker_start (NW), worker_end (NW), task (NT)],
        so every per-feature tensor is (NP, 2*NW + NT, 2*NW + NT[, d]):

        * LocalCategory: task-task indicator of equal, non-negative categories.
        * LocalFeature: task-task cosine similarity of the feature vectors.
        * SparseLocalFeature: task-task cosine similarity computed with sparse
          matrices; only supported for a batch size of 1.
        * ContinuousFeature ``*_matrix``: used as provided.
        * ContinuousFeature ``worker_task_*``: a (NP, NW, NT[, d]) tensor
          written symmetrically into both worker_start-task and
          worker_end-task blocks.

        Only called when at least one edge feature is configured.
        """
        NP = batch_size
        NW = worker_num
        NT = task_num
        NWW = NW + NW  # worker_start + worker_end items
        feature_list = []
        for k, dim in self.edge_dim.items():
            f = self.feature_dict.get(k)
            if isinstance(f, LocalCategory):
                assert f.name.startswith("task_")

                v = problem[k]
                v1 = v[:, :, None]
                v2 = v[:, None, :]

                v = torch.zeros(NP, NWW + NT, NWW + NT,
                                dtype=v.dtype, device=v.device)
                v[:, NWW:, NWW:] = ((v1 == v2) & (v1 >= 0))
            elif isinstance(f, LocalFeature):
                assert f.name.startswith("task_")

                v = problem[k].float()
                dot_product = torch.matmul(v, v.transpose(-1, -2))
                v_norm = v.norm(dim=2) + 1e-10
                v1_norm = v_norm[:, :, None]
                v2_norm = v_norm[:, None, :]

                v = torch.zeros(NP, NWW + NT, NWW + NT,
                                dtype=v.dtype, device=v.device)
                v[:, NWW:, NWW:] = dot_product / v1_norm / v2_norm
            elif isinstance(f, SparseLocalFeature):
                assert NP == 1
                assert f.index.startswith("task_")
                assert f.value.startswith("task_")

                index = problem[f.index]
                value = problem[f.value].float()

                # Number of columns of the sparse (task x index) matrix.
                NV = index.max().item() + 1
                spv = value.reshape(-1).tolist()
                spi = index.reshape(-1).tolist()

                device = value.device
                # Row (task) id for every stored entry.
                spj = torch.arange(NT, device=device)
                spj = spj[:, None].expand_as(index)
                spj = spj.reshape(-1).tolist()

                value1 = torch.sparse_coo_tensor([spj, spi], spv, (NT, NV), device=device)
                value2 = torch.sparse_coo_tensor([spi, spj], spv, (NV, NT), device=device)

                value1 = value1.coalesce()
                value2 = value2.coalesce()
                cosine = torch.sparse.mm(value1, value2).to_dense()

                norm = value.norm(dim=-1).reshape(-1)
                norm1 = norm[:, None].expand(-1, NT)
                norm2 = norm[None, :].expand(NT, -1)
                cosine = cosine / (norm1 * norm2 + 1e-10)

                v = torch.zeros(NP, NWW + NT, NWW + NT,
                                dtype=value.dtype, device=value.device)
                v[:, NWW:, NWW:] = cosine

            elif isinstance(f, ContinuousFeature):
                if f.name.endswith("_matrix"):
                    v = problem[k]
                elif f.name.startswith("worker_task_"):
                    v = problem[k]
                    if v.dim() == 3:
                        new_v = torch.zeros(NP, NWW + NT, NWW + NT,
                                            dtype=v.dtype, device=v.device)
                    else:
                        new_v = torch.zeros(NP, NWW + NT, NWW + NT, v.size(3),
                                            dtype=v.dtype, device=v.device)
                    # Broadcast index grids of shape (NP, NW, NT).
                    problem_index = torch.arange(NP, device=v.device)[:, None, None]
                    worker_index = torch.arange(NW, device=v.device)[None, :, None]
                    task_index = torch.arange(NT, device=v.device)[None, None, :] + NW + NW
                    new_v[problem_index, worker_index, task_index] = v
                    new_v[problem_index, task_index, worker_index] = v
                    new_v[problem_index, worker_index + NW, task_index] = v
                    new_v[problem_index, task_index, worker_index + NW] = v
                    v = new_v
                else:
                    raise Exception("feature: {}".format(f.name))
            else:
                raise Exception("feature: {}, type: {}".format(k, type(f)))

            if v.dim() == 3:
                v = v[:, :, :, None]

            assert dim == v.size(-1), \
                "feature: {}, expected: {}, actual: {}".format(k, dim, v.size(-1))

            feature_list.append(v.float())

        x = torch.cat(feature_list, 3)
        return self.nn_dense_edge(x)
|
|
|
|