import logging
import math
from collections import OrderedDict

import mmcv
import numpy as np
import torch
from torchvision.utils import save_image

from models.archs.fcn_arch import FCNHead
from models.archs.shape_attr_embedding_arch import ShapeAttrEmbedding
from models.archs.unet_arch import ShapeUNet
from models.losses.accuracy import accuracy
from models.losses.cross_entropy_loss import CrossEntropyLoss

logger = logging.getLogger('base')


class ParsingGenModel():
    """Parsing Generation model.

    Predicts a human parsing (segmentation) map from a densepose map,
    conditioned on shape attribute labels.
    """

    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda')
        self.is_train = opt['is_train']

        # Embed the discrete shape attributes into a feature vector.
        self.attr_embedder = ShapeAttrEmbedding(
            dim=opt['embedder_dim'],
            out_dim=opt['embedder_out_dim'],
            cls_num_list=opt['attr_class_num']).to(self.device)
        # UNet encoder that fuses the densepose map with the attribute
        # embedding.
        self.parsing_encoder = ShapeUNet(
            in_channels=opt['encoder_in_channels']).to(self.device)
        # FCN head that decodes encoder features into per-class logits.
        self.parsing_decoder = FCNHead(
            in_channels=opt['fc_in_channels'],
            in_index=opt['fc_in_index'],
            channels=opt['fc_channels'],
            num_convs=opt['fc_num_convs'],
            concat_input=opt['fc_concat_input'],
            dropout_ratio=opt['fc_dropout_ratio'],
            num_classes=opt['fc_num_classes'],
            align_corners=opt['fc_align_corners'],
        ).to(self.device)

        self.init_training_settings()

        # One RGB color per parsing class, used to colorize label maps in
        # palette_result().
        self.palette = [[0, 0, 0], [255, 250, 250], [220, 220, 220],
                        [250, 235, 215], [255, 250, 205], [211, 211, 211],
                        [70, 130, 180], [127, 255, 212], [0, 100, 0],
                        [50, 205, 50], [255, 255, 0], [245, 222, 179],
                        [255, 140, 0], [255, 0, 0], [16, 78, 139],
                        [144, 238, 144], [50, 205, 174], [50, 155, 250],
                        [160, 140, 88], [213, 140, 88], [90, 140, 90],
                        [185, 210, 205], [130, 165, 180], [225, 141, 151]]

    def init_training_settings(self):
        # Gather all trainable parameters from the three sub-networks so a
        # single optimizer updates them jointly.
        optim_params = []
        for v in self.attr_embedder.parameters():
            if v.requires_grad:
                optim_params.append(v)
        for v in self.parsing_encoder.parameters():
            if v.requires_grad:
                optim_params.append(v)
        for v in self.parsing_decoder.parameters():
            if v.requires_grad:
                optim_params.append(v)

        self.optimizer = torch.optim.Adam(
            optim_params,
            self.opt['lr'],
            weight_decay=self.opt['weight_decay'])
        self.log_dict = OrderedDict()
        self.entropy_loss = CrossEntropyLoss().to(self.device)

    def feed_data(self, data):
        self.pose = data['densepose'].to(self.device)
        self.attr = data['attr'].to(self.device)
        self.segm = data['segm'].to(self.device)

    def optimize_parameters(self):
        self.attr_embedder.train()
        self.parsing_encoder.train()
        self.parsing_decoder.train()

        # Forward pass: attributes -> embedding; densepose + embedding ->
        # encoder features; encoder features -> segmentation logits.
        self.attr_embedding = self.attr_embedder(self.attr)
        self.pose_enc = self.parsing_encoder(self.pose, self.attr_embedding)
        self.seg_logits = self.parsing_decoder(self.pose_enc)

        loss = self.entropy_loss(self.seg_logits, self.segm)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Log a detached Python float rather than the loss tensor, so the
        # computation graph is not kept alive between iterations.
        self.log_dict['loss_total'] = loss.item()

    def get_vis(self, save_path):
        img_cat = torch.cat([
            self.pose,
            self.segm,
        ], dim=3).detach()
        # Map from [-1, 1] back to [0, 1] before saving.
        img_cat = ((img_cat + 1) / 2)
        img_cat = img_cat.clamp_(0, 1)

        save_image(img_cat, save_path, nrow=1, padding=4)

    def inference(self, data_loader, save_dir):
        self.attr_embedder.eval()
        self.parsing_encoder.eval()
        self.parsing_decoder.eval()

        acc = 0
        num = 0

        for _, data in enumerate(data_loader):
            pose = data['densepose'].to(self.device)
            attr = data['attr'].to(self.device)
            segm = data['segm'].to(self.device)
            img_name = data['img_name']

            num += pose.size(0)
            with torch.no_grad():
                attr_embedding = self.attr_embedder(attr)
                pose_enc = self.parsing_encoder(pose, attr_embedding)
                seg_logits = self.parsing_decoder(pose_enc)
            seg_pred = seg_logits.argmax(dim=1)
            acc += accuracy(seg_logits, segm)
            palette_label = self.palette_result(segm.cpu().numpy())
            palette_pred = self.palette_result(seg_pred.cpu().numpy())
            # Denormalize the densepose map from [-1, 1] to [0, 255] and
            # expand it to 3 channels so it can be tiled next to the
            # colorized label maps.
            pose_numpy = ((pose[0] + 1) / 2. * 255.).expand(
                3,
                pose[0].size(1),
                pose[0].size(2),
            ).cpu().numpy().clip(0, 255).astype(np.uint8).transpose(1, 2, 0)
            concat_result = np.concatenate(
                (pose_numpy, palette_pred, palette_label), axis=1)
            mmcv.imwrite(concat_result, f'{save_dir}/{img_name[0]}')

        self.attr_embedder.train()
        self.parsing_encoder.train()
        self.parsing_decoder.train()
        return (acc / num).item()

    def get_current_log(self):
        return self.log_dict

    def update_learning_rate(self, epoch):
        """Update the learning rate according to the configured decay mode.

        Args:
            epoch (int): Current epoch.

        Returns:
            float: The updated learning rate.

        A standalone sketch of the 'linear2exp' curve appears at the bottom
        of this module.
        """
        lr = self.optimizer.param_groups[0]['lr']

        if self.opt['lr_decay'] == 'step':
            lr = self.opt['lr'] * (
                self.opt['gamma']**(epoch // self.opt['step']))
        elif self.opt['lr_decay'] == 'cos':
            lr = self.opt['lr'] * (
                1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
        elif self.opt['lr_decay'] == 'linear':
            lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
        elif self.opt['lr_decay'] == 'linear2exp':
            if epoch < self.opt['turning_point'] + 1:
                # Linear decay down to about 5% of the initial lr at the
                # turning point (1 / 0.95 is roughly 1.0526), followed by
                # exponential decay by `gamma` per epoch.
                lr = self.opt['lr'] * (
                    1 - epoch / int(self.opt['turning_point'] * 1.0526))
            else:
                lr *= self.opt['gamma']
        elif self.opt['lr_decay'] == 'schedule':
            if epoch in self.opt['schedule']:
                lr *= self.opt['gamma']
        else:
            raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        return lr

    def save_network(self, save_path):
        """Save the three sub-networks into a single checkpoint file."""
        save_dict = {}
        save_dict['embedder'] = self.attr_embedder.state_dict()
        save_dict['encoder'] = self.parsing_encoder.state_dict()
        save_dict['decoder'] = self.parsing_decoder.state_dict()

        torch.save(save_dict, save_path)

    def load_network(self):
        """Load pretrained weights and switch the sub-networks to eval mode."""
        # map_location keeps loading robust if the checkpoint was saved on a
        # different device.
        checkpoint = torch.load(
            self.opt['pretrained_parsing_gen'], map_location=self.device)

        self.attr_embedder.load_state_dict(checkpoint['embedder'], strict=True)
        self.attr_embedder.eval()

        self.parsing_encoder.load_state_dict(
            checkpoint['encoder'], strict=True)
        self.parsing_encoder.eval()

        self.parsing_decoder.load_state_dict(
            checkpoint['decoder'], strict=True)
        self.parsing_decoder.eval()

    def palette_result(self, result):
        """Colorize the first label map in a batch with the class palette."""
        seg = result[0]
        palette = np.array(self.palette)
        assert palette.shape[1] == 3
        assert len(palette.shape) == 2
        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(palette):
            color_seg[seg == label, :] = color
        # Flip RGB to BGR because mmcv.imwrite uses OpenCV's channel order.
        color_seg = color_seg[..., ::-1]
        return color_seg
|
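
# Minimal, self-contained sketches of two pieces of logic above. These are
# illustrative additions, not part of the original module: every numeric
# value below (lr0, turning_point, gamma, num_epochs, the toy palette and
# label map) is an assumption chosen for demonstration.
if __name__ == '__main__':
    # 1) The 'linear2exp' schedule from update_learning_rate(): linear decay
    #    to roughly 5% of the initial lr at the turning point
    #    (1 / 0.95 is about 1.0526), then exponential decay by gamma.
    lr0, turning_point, gamma, num_epochs = 1e-4, 100, 0.9, 130
    lr = lr0
    for epoch in range(num_epochs):
        if epoch < turning_point + 1:
            lr = lr0 * (1 - epoch / int(turning_point * 1.0526))
        else:
            lr *= gamma
        if epoch % 10 == 0:
            print(f'epoch {epoch:3d}: lr = {lr:.3e}')

    # 2) The label -> color lookup used by palette_result(), on a toy 2x2
    #    label map with a 2-color palette; the final [..., ::-1] flips RGB
    #    to BGR for OpenCV-style writers such as mmcv.imwrite.
    toy_palette = np.array([[0, 0, 0], [255, 0, 0]])  # label 0 black, 1 red
    toy_seg = np.array([[0, 1], [1, 0]])
    toy_color = np.zeros((2, 2, 3), dtype=np.uint8)
    for label, color in enumerate(toy_palette):
        toy_color[toy_seg == label, :] = color
    print(toy_color[..., ::-1])  # red pixels print as [0, 0, 255] in BGR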