Spaces:
Running
Running
import os.path | |
import torch | |
import torchvision | |
import torch.nn.functional as F | |
from torch.utils.data import DataLoader | |
import pytorch_lightning as pl | |
import numpy as np | |
import sklearn | |
from sklearn.metrics import roc_curve, auc | |
from scipy.spatial.distance import cdist | |
from third_party.arcface.mouth_net import MouthNet | |
from third_party.arcface.margin_loss import Softmax, AMArcFace, AMCosFace | |
from third_party.arcface.load_dataset import MXFaceDataset, EvalDataset | |
from third_party.bisenet.bisenet import BiSeNet | |
class MouthNetPL(pl.LightningModule): | |
def __init__( | |
self, | |
num_classes: int, | |
batch_size: int = 256, | |
dim_feature: int = 128, | |
header_type: str = 'AMArcFace', | |
header_params: tuple = (64.0, 0.5, 0.0, 0.0), # (s, m, a, k) | |
rec_folder: str = "/gavin/datasets/msml/ms1m-retinaface", | |
learning_rate: int = 0.1, | |
crop: tuple = (0, 0, 112, 112), # (w1,h1,w2,h2) | |
): | |
super(MouthNetPL, self).__init__() | |
# self.img_size = (112, 112) | |
''' mouth feature extractor ''' | |
bisenet = BiSeNet(19) | |
bisenet.load_state_dict( | |
torch.load( | |
"/gavin/datasets/hanbang/79999_iter.pth", | |
map_location="cpu", | |
) | |
) | |
bisenet.eval() | |
bisenet.requires_grad_(False) | |
self.mouth_net = MouthNet( | |
bisenet=None, | |
feature_dim=dim_feature, | |
crop_param=crop, | |
iresnet_pretrained=False, | |
) | |
''' head & loss ''' | |
self.automatic_optimization = False | |
self.dim_feature = dim_feature | |
self.num_classes = num_classes | |
self._prepare_header(header_type, header_params) | |
self.cls_criterion = torch.nn.CrossEntropyLoss() | |
self.learning_rate = learning_rate | |
''' dataset ''' | |
assert os.path.exists(rec_folder) | |
self.rec_folder = rec_folder | |
self.batch_size = batch_size | |
self.crop_param = crop | |
''' validation ''' | |
def _prepare_header(self, head_type, header_params): | |
dim_in = self.dim_feature | |
dim_out = self.num_classes | |
""" Get hyper-params of header """ | |
s, m, a, k = header_params | |
""" Choose the header """ | |
if 'Softmax' in head_type: | |
self.classification = Softmax(dim_in, dim_out, device_id=None) | |
elif 'AMCosFace' in head_type: | |
self.classification = AMCosFace(dim_in, dim_out, | |
device_id=None, | |
s=s, m=m, | |
a=a, k=k, | |
) | |
elif 'AMArcFace' in head_type: | |
self.classification = AMArcFace(dim_in, dim_out, | |
device_id=None, | |
s=s, m=m, | |
a=a, k=k, | |
) | |
else: | |
raise ValueError('Header type error!') | |
def forward(self, x, label=None): | |
feat = self.mouth_net(x) | |
if self.training: | |
assert label is not None | |
cls = self.classification(feat, label) | |
return feat, cls | |
else: | |
return feat | |
def training_step(self, batch, batch_idx): | |
opt = self.optimizers(use_pl_optimizer=True) | |
img, label = batch | |
mouth_feat, final_cls = self(img, label) | |
cls_loss = self.cls_criterion(final_cls, label) | |
opt.zero_grad() | |
self.manual_backward(cls_loss) | |
torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=5, norm_type=2) | |
opt.step() | |
''' loss logging ''' | |
self.logging_dict({"cls_loss": cls_loss}, prefix="train / ") | |
self.logging_lr() | |
if batch_idx % 50 == 0 and self.local_rank == 0: | |
print('loss=', cls_loss) | |
return cls_loss | |
def training_epoch_end(self, outputs): | |
sch = self.lr_schedulers() | |
sch.step() | |
lr = -1 | |
opts = self.trainer.optimizers | |
for opt in opts: | |
for param_group in opt.param_groups: | |
lr = param_group["lr"] | |
break | |
print('learning rate changed to %.6f' % lr) | |
# def validation_step(self, batch, batch_idx): | |
# return self.test_step(batch, batch_idx) | |
# | |
# def validation_step_end(self, outputs): | |
# return self.test_step_end(outputs) | |
# | |
# def validation_epoch_end(self, outputs): | |
# return self.test_step_end(outputs) | |
def save_tensor(tensor: torch.Tensor, path: str, b_idx: int = 0): | |
tensor = (tensor + 1.) * 127.5 | |
img = tensor.permute(0, 2, 3, 1)[b_idx].cpu().numpy() | |
from PIL import Image | |
img_pil = Image.fromarray(img.astype(np.uint8)) | |
img_pil.save(path) | |
def test_step(self, batch, batch_idx): | |
img1, img2, same = batch | |
feat1 = self.mouth_net(img1) | |
feat2 = self.mouth_net(img2) | |
return feat1, feat2, same | |
def test_step_end(self, outputs): | |
feat1, feat2, same = outputs | |
feat1 = feat1.cpu().numpy() | |
feat2 = feat2.cpu().numpy() | |
same = same.cpu().numpy() | |
feat1 = sklearn.preprocessing.normalize(feat1) | |
feat2 = sklearn.preprocessing.normalize(feat2) | |
predict_label = [] | |
num = feat1.shape[0] | |
for i in range(num): | |
dis_cos = cdist(feat1[i, None], feat2[i, None], metric='cosine') | |
predict_label.append(dis_cos[0, 0]) | |
predict_label = np.array(predict_label) | |
return { | |
"pred": predict_label, | |
"gt": same, | |
} | |
def test_epoch_end(self, outputs): | |
print(outputs) | |
pred, same = None, None | |
for batch_output in outputs: | |
if pred is None and same is None: | |
pred = batch_output["pred"] | |
same = batch_output["gt"] | |
else: | |
pred = np.concatenate([pred, batch_output["pred"]]) | |
same = np.concatenate([same, batch_output["gt"]]) | |
print(pred.shape, same.shape) | |
fpr, tpr, threshold = roc_curve(same, pred) | |
acc = tpr[np.argmin(np.abs(tpr - (1 - fpr)))] # choose proper threshold | |
print("=> verification finished, acc=%.4f" % (acc)) | |
''' save pth ''' | |
pth_path = "./weights/fixer_net_casia_%s.pth" % ('_'.join((str(x) for x in self.crop_param))) | |
self.mouth_net.save_backbone(pth_path) | |
print("=> model save to %s" % pth_path) | |
mouth_net = MouthNet( | |
bisenet=None, | |
feature_dim=self.dim_feature, | |
crop_param=self.crop_param | |
) | |
mouth_net.load_backbone(pth_path) | |
print("=> MouthNet pth checked") | |
return acc | |
def logging_dict(self, log_dict, prefix=None): | |
for key, val in log_dict.items(): | |
if prefix is not None: | |
key = prefix + key | |
self.log(key, val) | |
def logging_lr(self): | |
opts = self.trainer.optimizers | |
for idx, opt in enumerate(opts): | |
lr = None | |
for param_group in opt.param_groups: | |
lr = param_group["lr"] | |
break | |
self.log(f"lr_{idx}", lr) | |
def configure_optimizers(self): | |
params = list(self.parameters()) | |
learning_rate = self.learning_rate / 512 * self.batch_size * torch.cuda.device_count() | |
optimizer = torch.optim.SGD(params, lr=learning_rate, | |
momentum=0.9, weight_decay=5e-4) | |
print('lr is set as %.5f due to the global batch_size %d' % (learning_rate, | |
self.batch_size * torch.cuda.device_count())) | |
def lr_step_func(epoch): | |
return ((epoch + 1) / (4 + 1)) ** 2 if epoch < 0 else 0.1 ** len( | |
[m for m in [11, 17, 22] if m - 1 <= epoch]) # 0.1, 0.01, 0.001, 0.0001 | |
scheduler= torch.optim.lr_scheduler.LambdaLR( | |
optimizer=optimizer, lr_lambda=lr_step_func) | |
return [optimizer], [scheduler] | |
def train_dataloader(self): | |
dataset = MXFaceDataset( | |
root_dir=self.rec_folder, | |
crop_param=self.crop_param, | |
) | |
train_loader = DataLoader( | |
dataset, self.batch_size, num_workers=24, shuffle=True, drop_last=True | |
) | |
return train_loader | |
def val_dataloader(self): | |
return self.test_dataloader() | |
def test_dataloader(self): | |
dataset = EvalDataset( | |
rec_folder=self.rec_folder, | |
target='lfw', | |
crop_param=self.crop_param | |
) | |
test_loader = DataLoader( | |
dataset, 20, num_workers=12, shuffle=False, drop_last=False | |
) | |
return test_loader | |
def start_train(): | |
import os | |
import argparse | |
import torch | |
import pytorch_lightning as pl | |
from pytorch_lightning.callbacks import ModelCheckpoint | |
import wandb | |
from pytorch_lightning.loggers import WandbLogger | |
parser = argparse.ArgumentParser() | |
parser.add_argument( | |
"-g", | |
"--gpus", | |
type=str, | |
default=None, | |
help="Number of gpus to use (e.g. '0,1,2,3'). Will use all if not given.", | |
) | |
parser.add_argument("-n", "--name", type=str, required=True, help="Name of the run.") | |
parser.add_argument("-pj", "--project", type=str, default="mouthnet", help="Name of the project.") | |
parser.add_argument("-rp", "--resume_checkpoint_path", | |
type=str, default=None, help="path of checkpoint for resuming", ) | |
parser.add_argument("-p", "--saving_folder", | |
type=str, default="/apdcephfs/share_1290939/gavinyuan/out", help="saving folder", ) | |
parser.add_argument("--wandb_resume", | |
type=str, default=None, help="resume wandb logging from the input id", ) | |
parser.add_argument("--header_type", type=str, default="AMArcFace", help="loss type.") | |
parser.add_argument("-bs", "--batch_size", type=int, default=128, help="bs.") | |
parser.add_argument("-fs", "--fast_dev_run", type=bool, default=False, help="pytorch.lightning fast_dev_run") | |
args = parser.parse_args() | |
args.val_targets = [] | |
# args.rec_folder = "/gavin/datasets/msml/ms1m-retinaface" | |
# num_classes = 93431 | |
args.rec_folder = "/gavin/datasets/msml/casia" | |
num_classes = 10572 | |
save_path = os.path.join(args.saving_folder, args.name) | |
os.makedirs(save_path, exist_ok=True) | |
checkpoint_callback = ModelCheckpoint( | |
dirpath=save_path, | |
monitor="train / cls_loss", | |
save_top_k=10, | |
verbose=True, | |
every_n_train_steps=200, | |
) | |
torch.cuda.empty_cache() | |
mouth_net = MouthNetPL( | |
num_classes=num_classes, | |
batch_size=args.batch_size, | |
dim_feature=128, | |
rec_folder=args.rec_folder, | |
header_type=args.header_type, | |
crop=(28, 56, 84, 112) | |
) | |
if args.wandb_resume == None: | |
resume = "allow" | |
wandb_id = wandb.util.generate_id() | |
else: | |
resume = True | |
wandb_id = args.wandb_resume | |
logger = WandbLogger( | |
project=args.project, | |
entity="gavinyuan", | |
name=args.name, | |
resume=resume, | |
id=wandb_id, | |
) | |
trainer = pl.Trainer( | |
gpus=-1 if args.gpus is None else torch.cuda.device_count(), | |
callbacks=[checkpoint_callback], | |
logger=logger, | |
weights_save_path=save_path, | |
resume_from_checkpoint=args.resume_checkpoint_path, | |
gradient_clip_val=0, | |
max_epochs=25, | |
num_sanity_val_steps=1, | |
fast_dev_run=args.fast_dev_run, | |
val_check_interval=50, | |
progress_bar_refresh_rate=1, | |
distributed_backend="ddp", | |
benchmark=True, | |
) | |
trainer.fit(mouth_net) | |
if __name__ == "__main__": | |
start_train() | |