xierui.0097
Add application file
f0e9666
raw
history blame
2.9 kB
# Copyright (c) Alibaba, Inc. and its affiliates.
import importlib
import logging
from typing import Optional
from torch import distributed as dist
# Registry of logger names that get_logger() has already configured; a name
# present here means its handlers were installed and must not be added again.
init_loggers = {}

# Single formatter shared by every handler this module creates.
formatter = logging.Formatter(
    fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
def get_logger(log_file: Optional[str] = None,
               log_level: int = logging.INFO,
               file_mode: str = 'w'):
    """Return the package logger, configuring its handlers on first use.

    The logger is named after the top-level package of this module
    (``__name__.split('.')[0]``) and has propagation to the root logger
    disabled. The first call installs a console handler and, on the master
    process only, an optional file handler. Later calls return the same
    logger and only try to attach a file handler if none is present yet.

    Args:
        log_file: Log filename. If specified, a file handler is added to
            the logger (master process only).
        log_level: Logging level applied to the handlers, and to the logger
            itself on the master process. On other processes the logger is
            set to ``logging.ERROR``.
        file_mode: Mode used to open ``log_file`` (defaults to ``'w'``).

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger_name = __name__.split('.')[0]
    logger = logging.getLogger(logger_name)
    logger.propagate = False

    if logger_name in init_loggers:
        # Already configured: never add a second console handler, but allow
        # a file handler to be attached by a later call.
        add_file_handler_if_needed(logger, log_file, file_mode, log_level)
        return logger

    # Handle duplicate logs to the console.
    # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)
    # to the root logger. Such a root-level handler makes messages from
    # rank>0 processes show up on the console, creating unwanted clutter.
    # To limit this, any plain StreamHandler on the root logger is set to
    # the ERROR level.
    for handler in logger.root.handlers:
        # Exact type check on purpose: FileHandler subclasses StreamHandler
        # and must not be affected.
        if type(handler) is logging.StreamHandler:
            handler.setLevel(logging.ERROR)

    handlers = [logging.StreamHandler()]

    # A bare ``import importlib`` does not guarantee that the ``util``
    # submodule is loaded; import it explicitly so the attribute access
    # below cannot fail with AttributeError.
    import importlib.util

    if importlib.util.find_spec('torch') is not None:
        is_worker0 = is_master()
    else:
        is_worker0 = True

    # Only the master process writes the log file.
    if is_worker0 and log_file is not None:
        handlers.append(logging.FileHandler(log_file, file_mode))

    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    # Non-master processes only report errors.
    logger.setLevel(log_level if is_worker0 else logging.ERROR)

    init_loggers[logger_name] = True
    return logger
def add_file_handler_if_needed(logger, log_file, file_mode, log_level):
    """Attach a file handler to ``logger`` unless it already has one.

    Nothing is added when the logger already owns any
    ``logging.FileHandler``, when ``log_file`` is ``None``, or when the
    current process is not the master process.

    Args:
        logger: The ``logging.Logger`` to extend.
        log_file: Log filename, or ``None`` to add nothing.
        file_mode: Mode used to open ``log_file``.
        log_level: Level set on the newly created file handler.
    """
    if any(isinstance(handler, logging.FileHandler)
           for handler in logger.handlers):
        return

    # A bare ``import importlib`` does not guarantee that the ``util``
    # submodule is loaded; import it explicitly so the attribute access
    # below cannot fail with AttributeError.
    import importlib.util

    if importlib.util.find_spec('torch') is not None:
        is_worker0 = is_master()
    else:
        is_worker0 = True

    # Only the master process writes the log file.
    if is_worker0 and log_file is not None:
        file_handler = logging.FileHandler(log_file, file_mode)
        file_handler.setFormatter(formatter)
        file_handler.setLevel(log_level)
        logger.addHandler(file_handler)
def is_master(group=None):
    """Tell whether the current process is rank 0.

    Args:
        group: Process group passed through to ``dist.get_rank``.

    Returns:
        ``True`` when distributed mode is not active (per ``is_dist()``),
        otherwise whether ``dist.get_rank(group)`` equals 0.
    """
    if not is_dist():
        # Not running distributed: the only process counts as master.
        return True
    return dist.get_rank(group) == 0
def is_dist():
    """Tell whether torch.distributed is both available and initialized."""
    if not dist.is_available():
        return False
    return dist.is_initialized()