import functools

import torch

from .utils import repeat


class VarMeta(object):
    """Deferred constructor for a variable class.

    Holds the target class together with its configuration keyword
    arguments. Calling the instance with a problem builds the actual
    variable object.
    """

    def __init__(self, clazz, **kwargs):
        self.clazz = clazz
        self._kwargs = kwargs
        # Mirror every configuration value as an attribute (e.g. ``meta.name``).
        for k, v in kwargs.items():
            setattr(self, k, v)

    def __call__(self, problem, batch_size, sample_num):
        """Instantiate ``clazz`` with the stored kwargs plus problem context.

        ``problem`` must expose ``feats``, ``worker_num`` and ``task_num``.
        """
        kwargs = self._kwargs.copy()
        kwargs['problem'] = problem.feats
        kwargs['batch_size'] = batch_size
        kwargs['sample_num'] = sample_num
        kwargs['worker_num'] = problem.worker_num
        kwargs['task_num'] = problem.task_num
        return self.clazz(**kwargs)


def attribute_variable(name, attribute=None):
    """Return a ``VarMeta`` that builds an ``AttributeVariable``."""
    return VarMeta(AttributeVariable, name=name, attribute=attribute)


class AttributeVariable:
    """Exposes ``problem[attribute]`` unchanged as ``value``."""

    def __init__(self, name, attribute, problem, batch_size, sample_num,
                 worker_num, task_num):
        if attribute is None:
            attribute = name
        self.name = name
        self.value = problem[attribute]


def feature_variable(name, feature=None):
    """Return a ``VarMeta`` that builds a ``FeatureVariable``."""
    return VarMeta(FeatureVariable, name=name, feature=feature)


class FeatureVariable:
    """Exposes a problem feature passed through ``repeat(..., sample_num)``."""

    def __init__(self, name, feature, problem, batch_size, sample_num,
                 worker_num, task_num):
        if feature is None:
            feature = name
        assert feature == 'id' or feature.startswith("worker_") \
            or feature.startswith("task_")
        self.name = name
        self.feature = problem[feature]
        self.value = repeat(self.feature, sample_num)


def task_variable(name, feature=None):
    """Return a ``VarMeta`` that builds a ``TaskVariable``."""
    return VarMeta(TaskVariable, name=name, feature=feature)


class TaskVariable:
    """Holds, per batch row, the ``task_*`` feature of the selected task."""

    def __init__(self, name, feature, problem, batch_size, sample_num,
                 worker_num, task_num):
        if feature is None:
            feature = name
        assert feature.startswith("task_")
        self.name = name
        self.feature = problem[feature]

        # value has the feature's shape with dim 0 replaced by batch_size
        # and dim 1 removed.
        size = list(self.feature.size())
        size[0] = batch_size
        del size[1]
        self.value = self.feature.new_zeros(size)

    def step_task(self, b_index, p_index, t_index):
        """Copy ``feature[p_index, t_index]`` into ``value[b_index]``."""
        self.value[b_index] = self.feature[p_index, t_index]


def worker_variable(name, feature=None):
    """Return a ``VarMeta`` that builds a ``WorkerVariable``."""
    return VarMeta(WorkerVariable, name=name, feature=feature)


class WorkerVariable:
    """Holds, per batch row, the ``worker_*`` feature of the started worker."""

    def __init__(self, name, feature, problem, batch_size, sample_num,
                 worker_num, task_num):
        if feature is None:
            feature = name
        assert feature.startswith("worker_")
        self.name = name
        self.feature = problem[feature]

        # value has the feature's shape with dim 0 replaced by batch_size
        # and dim 1 removed.
        size = list(self.feature.size())
        size[0] = batch_size
        del size[1]
        self.value = self.feature.new_zeros(size)

    def step_worker_start(self, b_index, p_index, w_index):
        """Copy ``feature[p_index, w_index]`` into ``value[b_index]``."""
        self.value[b_index] = self.feature[p_index, w_index]


def worker_task_variable(name, feature=None):
    """Return a ``VarMeta`` that builds a ``WorkerTaskVariable``."""
    return VarMeta(WorkerTaskVariable, name=name, feature=feature)


class WorkerTaskVariable:
    """Tracks a ``worker_task_*`` feature in two stages.

    ``step_worker_start`` caches the slice for the started worker in
    ``_feature``; ``step_task`` then picks the entry for the selected task.
    """

    def __init__(self, name, feature, problem, batch_size, sample_num,
                 worker_num, task_num):
        if feature is None:
            feature = name
        assert feature.startswith("worker_task_")
        self.name = name
        self.feature = problem[feature]

        size = list(self.feature.size())
        size[0] = batch_size
        del size[1]
        self._feature = self.feature.new_zeros(size)
        # NOTE(review): this removes index 2 of the already-reduced size, so
        # it requires a feature with at least 4 dims and keeps the dim that
        # ``step_task`` indexes with ``t_index``. ``del size[1]`` may have
        # been intended -- confirm against callers before changing.
        del size[2]
        self.value = self.feature.new_zeros(size)

    def step_worker_start(self, b_index, p_index, w_index):
        """Cache ``feature[p_index, w_index]`` in ``_feature[b_index]``."""
        self._feature[b_index] = self.feature[p_index, w_index]

    def step_task(self, b_index, p_index, t_index):
        """Copy ``_feature[b_index, t_index]`` into ``value[b_index]``."""
        self.value[b_index] = self._feature[b_index, t_index]


def worker_task_group(name, feature=None):
    """Return a ``VarMeta`` that builds a ``WorkerTaskGroup``."""
    return VarMeta(WorkerTaskGroup, name=name, feature=feature)


class WorkerTaskGroup:
    """Counts tasks per group id; counts reset whenever a worker starts.

    Group ids are read from a ``task_*`` feature (cast to ``long``) and must
    be non-negative.
    """

    def __init__(self, name, feature, problem, batch_size, sample_num,
                 worker_num, task_num):
        if feature is None:
            feature = name
        assert feature.startswith("task_")
        self.name = name
        self.feature = problem[feature].long()

        # Use a plain int as the dimension instead of a 0-dim tensor.
        group_num = int(self.feature.max()) + 1
        assert torch.all(self.feature >= 0)
        self.value = self.feature.new_zeros(batch_size, group_num)

    def step_worker_start(self, b_index, p_index, w_index):
        """Zero the group counters of the given batch rows."""
        self.value[b_index] = 0

    def step_task(self, b_index, p_index, t_index):
        """Increment the counter of the selected task's group."""
        group = self.feature[p_index, t_index]
        self.value[b_index, group] += 1


def worker_task_item(name, item_id, item_num):
    """Return a ``VarMeta`` that builds a ``WorkerTaskItem``."""
    return VarMeta(WorkerTaskItem, name=name, item_id=item_id,
                   item_num=item_num)


class WorkerTaskItem:
    """Accumulates item amounts per item id; reset whenever a worker starts.

    ``item_id`` and ``item_num`` name ``task_*`` features; item ids must be
    non-negative.
    """

    def __init__(self, name, item_id, item_num, problem, batch_size,
                 sample_num, worker_num, task_num):
        assert item_id.startswith('task_')
        assert item_num.startswith('task_')
        self.name = name
        self.item_id = repeat(problem[item_id], sample_num).long()
        self.item_num = repeat(problem[item_num], sample_num)
        assert torch.all(self.item_id >= 0)

        # One accumulator column per item id; use a plain int as the
        # dimension instead of a 0-dim tensor.
        size = [self.item_id.size(0), int(self.item_id.max()) + 1]
        self.value = self.item_num.new_zeros(size)

    def step_worker_start(self, b_index, p_index, w_index):
        """Zero the accumulated amounts of the given batch rows."""
        self.value[b_index] = 0

    def step_task(self, b_index, p_index, t_index):
        """Add the selected task's item amounts to the accumulators."""
        item_id = self.item_id[b_index, t_index]
        item_num = self.item_num[b_index, t_index]
        self.value[b_index[:, None], item_id] += item_num

    def make_feat(self):
        """Return, per task, the count of its items not yet accumulated.

        An item slot counts when its ``item_num`` is positive and the
        accumulator for its id is still zero (both clamped to ``[0, 1]``).
        """
        NT = self.item_id.size(1)
        v = self.value[:, None, :]
        v = v.expand(-1, NT, -1)
        v = v.gather(2, self.item_id).clamp(0, 1)
        v = self.item_num.clamp(0, 1) - v
        return v.clamp(0, 1).sum(2)


def task_demand_now(name, feature=None, only_this=False):
    """Return a ``VarMeta`` that builds a ``TaskDemandNow``."""
    return VarMeta(TaskDemandNow, name=name, feature=feature,
                   only_this=only_this)


class TaskDemandNow:
    """Tracks the remaining integer demand of tasks.

    With ``only_this`` the public ``value`` holds only the remaining demand
    of the last selected task per batch row; otherwise it is the full
    remaining-demand tensor.
    """

    def __init__(self, name, feature, only_this, problem, batch_size,
                 sample_num, worker_num, task_num):
        if feature is None:
            feature = name
        assert feature.startswith("task_")
        self.name = name
        self.only_this = only_this
        self._value = repeat(problem[feature], sample_num)

        assert self._value.dtype in \
            (torch.int8, torch.int16, torch.int32, torch.int64)
        assert torch.all(self._value >= 0)

        if only_this:
            size = self._value.size(0)
            self.value = self._value.new_zeros(size)
        else:
            self.value = self._value

    def step_task(self, b_index, p_index, t_index, done):
        """Subtract ``done`` (when not ``None``) and refresh ``value``."""
        if done is not None:
            self._value[b_index, t_index] -= done

        if self.only_this:
            self.value[b_index] = self._value[b_index, t_index]
        else:
            self.value = self._value


def worker_count_now(name, feature=None):
    """Return a ``VarMeta`` that builds a ``WorkerCountNow``."""
    return VarMeta(WorkerCountNow, name=name, feature=feature)


class WorkerCountNow:
    """Tracks an integer per-worker count, decremented on each worker start."""

    def __init__(self, name, feature, problem, batch_size, sample_num,
                 worker_num, task_num):
        if feature is None:
            feature = name
        assert feature.startswith("worker_")
        self.name = name
        self.value = repeat(problem[feature], sample_num)

        assert self.value.dtype in \
            (torch.int8, torch.int16, torch.int32, torch.int64)
        assert torch.all(self.value >= 0)

    def step_worker_start(self, b_index, p_index, w_index):
        """Decrement the count of the started worker."""
        self.value[b_index, w_index] -= 1


def edge_variable(name, feature, last_to_this=False, this_to_task=False,
                  task_to_end=False, last_to_loop=False):
    """Return a ``VarMeta`` that builds an ``EdgeVariable``.

    Exactly one of the four mode flags must be true.
    """
    return VarMeta(EdgeVariable, name=name, feature=feature,
                   last_to_this=last_to_this, this_to_task=this_to_task,
                   task_to_end=task_to_end, last_to_loop=last_to_loop)


class EdgeVariable:
    """Reads entries of a ``*_matrix`` edge feature along the current route.

    Node indices into the matrix are laid out as: worker start ``w``,
    worker end ``w + worker_num``, task ``t + 2 * worker_num``.

    Modes (exactly one is active):
      * ``last_to_this``: edge from the previous node to the current node.
      * ``this_to_task``: edges from the current node to every task.
      * ``task_to_end``: edges from every task to the worker's end node.
      * ``last_to_loop``: edge from the last node back to the first task
        visited after the worker start.
    """

    def __init__(self, name, feature, last_to_this, this_to_task,
                 task_to_end, last_to_loop, problem, batch_size, sample_num,
                 worker_num, task_num):
        # Apply the default before validating; otherwise feature=None fails
        # on ``.endswith`` and the fallback is never reached.
        if feature is None:
            feature = name
        assert feature.endswith("_matrix")

        flags = [last_to_this, this_to_task, task_to_end, last_to_loop]
        assert flags.count(True) == 1 and flags.count(False) == 3

        self.name = name
        self.last_to_this = last_to_this
        self.this_to_task = this_to_task
        self.task_to_end = task_to_end
        self.last_to_loop = last_to_loop
        self.worker_num = worker_num
        self.task_num = task_num

        self.feature = problem[feature]

        # Drop the two node dims; the per-task modes get a task dim instead.
        size = list(self.feature.size())
        size[0] = batch_size
        del size[1:3]
        if self.this_to_task or self.task_to_end:
            size.insert(1, task_num)
        self.value = self.feature.new_zeros(size)

        self.end_index = self.feature.new_zeros(size[0], dtype=torch.int64)
        self.loop_index = self.feature.new_zeros(size[0], dtype=torch.int64)
        self.last_index = self.feature.new_zeros(size[0], dtype=torch.int64)
        # Node indices of all tasks, kept on the same device as the feature
        # like the other index buffers.
        self.task_index = (torch.arange(task_num, device=self.feature.device)
                           + worker_num * 2)[None, :]

    def step_worker_start(self, b_index, p_index, w_index):
        """Update state when a worker starts at its start node."""
        if self.last_to_this:
            self.value[b_index] = 0
            self.last_index[b_index] = w_index
        elif self.this_to_task:
            self.do_this_to_task(b_index, p_index, w_index)
        elif self.task_to_end:
            self.end_index[b_index] = w_index + self.worker_num
            self.do_task_to_end(b_index, p_index)
        elif self.last_to_loop:
            self.value[b_index] = 0
            self.last_index[b_index] = w_index

    def step_worker_end(self, b_index, p_index, w_index):
        """Update state when a worker reaches its end node."""
        this_index = w_index + self.worker_num
        if self.last_to_this:
            self.do_last_to_this(b_index, p_index, this_index)
        elif self.this_to_task:
            self.do_this_to_task(b_index, p_index, this_index)
        elif self.task_to_end:
            pass
        elif self.last_to_loop:
            self.do_last_to_loop(b_index, p_index)

    def step_task(self, b_index, p_index, t_index):
        """Update state when a task node is visited."""
        this_index = t_index + self.worker_num * 2
        if self.last_to_this:
            self.do_last_to_this(b_index, p_index, this_index)
            self.last_index[b_index] = this_index
        elif self.this_to_task:
            self.do_this_to_task(b_index, p_index, this_index)
        elif self.task_to_end:
            pass
        elif self.last_to_loop:
            last_index = self.last_index[b_index]
            loop_index = self.loop_index[b_index]
            # The loop target is the first task after a worker start, i.e.
            # it is only replaced while the previous node is a start node.
            self.loop_index[b_index] = torch.where(
                last_index < self.worker_num, this_index, loop_index)
            self.last_index[b_index] = this_index

    def do_last_to_this(self, b_index, p_index, this_index):
        """Store the edge from the remembered last node to ``this_index``."""
        last_index = self.last_index[b_index]
        self.value[b_index] = self.feature[p_index, last_index, this_index]

    def do_this_to_task(self, b_index, p_index, this_index):
        """Store the edges from ``this_index`` to every task node."""
        p_index2 = p_index[:, None]
        this_index2 = this_index[:, None]
        task_index2 = self.task_index
        self.value[b_index] = self.feature[p_index2, this_index2, task_index2]

    def do_task_to_end(self, b_index, p_index):
        """Store the edges from every task node to the stored end node."""
        p_index2 = p_index[:, None]
        task_index2 = self.task_index
        end_index = self.end_index[b_index]
        end_index2 = end_index[:, None]
        self.value[b_index] = self.feature[p_index2, task_index2, end_index2]

    def do_last_to_loop(self, b_index, p_index):
        """Store the edge from the last node to the stored loop node."""
        loop_index = self.loop_index[b_index]
        last_index = self.last_index[b_index]
        self.value[b_index] = self.feature[p_index, last_index, loop_index]

    def make_feat(self):
        """Return a copy of ``value``; only valid for the per-task modes."""
        assert self.this_to_task or self.task_to_end, \
            "one of [this_to_task, task_to_end] must be true"
        return self.value.clone()


def worker_used_resource(name, edge_require=None, task_require=None,
                         task_ready=None, worker_ready=None, task_due=None):
    """Return a ``VarMeta`` that builds a ``WorkerUsedResource``."""
    return VarMeta(WorkerUsedResource, name=name, edge_require=edge_require,
                   task_require=task_require, task_ready=task_ready,
                   worker_ready=worker_ready, task_due=task_due)


class WorkerUsedResource:
    """Accumulates the resource used by the current worker.

    Optional contributions:
      * ``worker_ready``: initial value when the worker starts (else 0).
      * ``edge_require``: added for every edge travelled.
      * ``task_ready``: lower bound applied when a task is reached.
      * ``task_require``: added, scaled by ``done``, when a task is served.

    ``task_due`` is validated and stored but not otherwise used here.
    """

    def __init__(self, name, edge_require, task_require, task_ready,
                 worker_ready, task_due, problem, batch_size, sample_num,
                 worker_num, task_num):
        assert edge_require is None or edge_require.endswith("_matrix"), \
            "unsupported edge: {}".format(edge_require)
        assert task_require is None or task_require.startswith("task_"), \
            "unsupported task_require: {}".format(task_require)
        assert task_ready is None or task_ready.startswith("task_"), \
            "unsupported task_ready: {}".format(task_ready)
        assert worker_ready is None or (
            worker_ready.startswith("worker_")
            and not worker_ready.startswith("worker_task_"))
        assert task_due is None or task_due.startswith("task_"), \
            "unsupported task_due: {}".format(task_due)

        self.name = name
        self.worker_num = worker_num
        self.task_num = task_num

        if edge_require is None:
            self.edge_require = None
        else:
            self.edge_require = problem[edge_require]
            self.last_index = self.edge_require.new_zeros(
                batch_size, dtype=torch.int64)

        if task_require is None:
            self.task_require = None
        else:
            self.task_require = problem[task_require]
            self.task_require2 = repeat(self.task_require, sample_num)

        if task_ready is None:
            self.task_ready = None
        else:
            self.task_ready = problem[task_ready]

        if worker_ready is None:
            self.worker_ready = None
        else:
            self.worker_ready = problem[worker_ready]

        if task_due is None:
            self.task_due = None
        else:
            self.task_due = problem[task_due]

        tenors = [self.edge_require, self.task_require,
                  self.task_ready, self.worker_ready]
        tenors = list(filter(lambda x: x is not None, tenors))
        assert tenors, "at least one of edge_require, task_require, task_ready, worker_ready is required!"

        # The first configured tensor fixes the shape and dtype of value:
        # an edge matrix loses its two node dims, any other tensor loses
        # its single worker/task dim.
        size = list(tenors[0].size())
        size[0] = batch_size
        if self.edge_require is None:
            del size[1]
        else:
            del size[1:3]
        self.value = tenors[0].new_zeros(size)

    def step_worker_start(self, b_index, p_index, w_index):
        """Reset the usage to ``worker_ready`` (or 0) for a new worker."""
        if self.worker_ready is None:
            self.value[b_index] = 0
        else:
            self.value[b_index] = self.worker_ready[p_index, w_index]

        if self.edge_require is not None:
            self.last_index[b_index] = w_index

    def step_worker_end(self, b_index, p_index, w_index):
        """Add the edge cost from the last node to the worker's end node."""
        if self.edge_require is not None:
            last_index = self.last_index[b_index]
            this_index = w_index + self.worker_num
            self.value[b_index] += \
                self.edge_require[p_index, last_index, this_index]
            self.last_index[b_index] = this_index

    def step_task(self, b_index, p_index, t_index, done):
        """Update usage for a task step.

        With ``done is None`` the edge cost to the task is added and the
        ``task_ready`` lower bound applied; otherwise ``task_require``
        scaled by ``done`` is added.
        """
        if done is None:
            if self.edge_require is not None:
                last_index = self.last_index[b_index]
                this_index = t_index + (self.worker_num * 2)
                self.value[b_index] += \
                    self.edge_require[p_index, last_index, this_index]
                self.last_index[b_index] = this_index

            if self.task_ready is not None:
                self.value[b_index] = torch.max(
                    self.value[b_index], self.task_ready[p_index, t_index])
        else:
            if self.task_require is not None:
                if self.value.dim() == 2:
                    # Add a trailing dim so done broadcasts over dim 1.
                    done = done[:, None]
                self.value[b_index] += \
                    self.task_require[p_index, t_index] * done

    def make_feat(self):
        """Return, per task, the sum over dim 2 of usage plus requirement.

        Each summed entry is first clamped to ``[0, 1]``. Requires a 2-D
        ``value`` and a configured ``task_require``.
        """
        assert self.value.dim() == 2, \
            "value's dim must be 2, actual: {}".format(self.value.dim())
        assert self.task_require is not None, "task_require is required"
        v = self.value[:, None, :] + self.task_require2
        return v.clamp(0, 1).sum(2, dtype=v.dtype)


def worker_task_sequence(name):
    """Return a ``VarMeta`` that builds a ``WorkerTaskSequence``."""
    return VarMeta(WorkerTaskSequence, name=name)


class WorkerTaskSequence:
    """Holds the finished worker/task sequence; ``None`` until finished."""

    def __init__(self, name, problem, batch_size, sample_num, worker_num,
                 task_num):
        self.name = name
        self.value = None

    def step_finish(self, worker_task_seq):
        """Store the finished sequence as ``value``."""
        self.value = worker_task_seq