szukevin's picture
upload
7900c16
raw
history blame
773 Bytes
import torch.nn as nn
class Target(nn.Module):
def __init__(self):
super(Target, self).__init__()
self.target_name_list = []
self.loss_info = {}
def update(self, target, target_name):
setattr(self, target_name, target)
self.target_name_list.append(target_name)
def forward(self, memory_bank, tgt, seg):
self.loss_info = {}
for i, target_name in enumerate(self.target_name_list):
target = getattr(self, target_name)
if len(self.target_name_list) > 1:
self.loss_info[self.target_name_list[i]] = target(memory_bank, tgt[self.target_name_list[i]], seg)
else:
self.loss_info = target(memory_bank, tgt, seg)
return self.loss_info