File size: 885 Bytes
2e23827 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
# A reimplemented version for public environments, by Xiao Fu and Mu Hu.
import json
import yaml
import logging
import os
import numpy as np
import sys
def load_loss_scheme(loss_config):
    """Parse the YAML file at *loss_config* and return the resulting object.

    The file is opened in binary mode so that PyYAML detects the encoding
    itself (UTF-8 unless a UTF-16 byte-order mark is present) instead of
    relying on the platform's locale-dependent default text encoding.

    Args:
        loss_config: Path of the YAML loss-scheme file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document, e.g. a dict for
        a top-level mapping, or ``None`` for an empty file.
    """
    # safe_load only constructs plain Python objects (no arbitrary classes).
    with open(loss_config, 'rb') as f:
        scheme = yaml.safe_load(f)
    return scheme
# Module-wide logging configuration, executed at import time.
# Set DEBUG to a truthy value to lower the threshold from INFO to DEBUG.
DEBUG = 0

# NOTE: getLogger() with no name returns the root logger, so the level and
# handler configured below affect every logger that propagates to root.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG if DEBUG else logging.INFO)

# A StreamHandler (stderr by default) rendering records as
# "<time> [<file>:<line>] <LEVEL> <message>". Each execution of this block
# attaches one more handler to the root logger.
formatter = logging.Formatter('%(asctime)s [%(filename)s:%(lineno)d] %(levelname)s %(message)s')
strhdlr = logging.StreamHandler()
strhdlr.setFormatter(formatter)
logger.addHandler(strhdlr)
def count_parameters(model):
    """Return the number of trainable elements in *model*.

    Iterates ``model.parameters()`` and adds up ``numel()`` for every
    parameter whose ``requires_grad`` flag is true; frozen parameters are
    skipped. Returns 0 when there are no trainable parameters.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def check_path(path):
    """Make sure something exists at *path*, creating directories if needed.

    If *path* already exists (directory or otherwise) nothing is done.
    Otherwise the directory is created together with any missing parents.
    Returns ``None``.
    """
    if os.path.exists(path):
        return
    os.makedirs(path, exist_ok=True)
|