Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Alibaba, Inc. and its affiliates. | |
import importlib | |
import logging | |
from typing import Optional | |
from torch import distributed as dist | |
# Registry of logger names that get_logger() has already configured;
# get_logger() stores True for each name it sets up and checks membership
# to avoid attaching duplicate handlers on later calls.
init_loggers = dict()

# Shared format applied to every handler this module attaches.
formatter = logging.Formatter(
    fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
def get_logger(log_file: Optional[str] = None,
               log_level: int = logging.INFO,
               file_mode: str = 'w'):
    """Get the package-level logger, configuring it on first use.

    The logger is named after the top-level package of this module
    (``__name__.split('.')[0]``). The first call attaches a console
    ``StreamHandler`` and, on rank 0 only, a ``FileHandler`` for
    ``log_file``. Later calls do not reconfigure the logger; they only
    attach a file handler if the logger does not have one yet.

    Args:
        log_file: Log filename. If specified, a file handler is added to
            the logger (on rank 0 only).
        log_level: Logging level for the new handlers and, on rank 0, for
            the logger itself. On repeat calls it only applies to a newly
            added file handler.
        file_mode: Mode used to open ``log_file`` (defaults to ``'w'``).

    Returns:
        logging.Logger: The configured logger.
    """
    logger_name = __name__.split('.')[0]
    logger = logging.getLogger(logger_name)
    # This logger owns its handlers; never forward records to the root.
    logger.propagate = False

    if logger_name in init_loggers:
        add_file_handler_if_needed(logger, log_file, file_mode, log_level)
        return logger

    # handle duplicate logs to the console
    # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)
    # to the root logger. As logger.propagate is True by default, this root
    # level handler causes logging messages from rank>0 processes to
    # unexpectedly show up on the console, creating much unwanted clutter.
    # To fix this issue, we set the root logger's StreamHandler, if any, to log
    # at the ERROR level. (Exact type match on purpose: subclasses such as
    # FileHandler are left untouched.)
    for handler in logger.root.handlers:
        if type(handler) is logging.StreamHandler:
            handler.setLevel(logging.ERROR)

    # torch is imported unconditionally at module level, so the former
    # importlib.util.find_spec('torch') guard could never take its else
    # branch. is_master() itself returns True when torch.distributed is not
    # initialized.
    is_worker0 = is_master()

    handlers = [logging.StreamHandler()]
    if is_worker0 and log_file is not None:
        handlers.append(logging.FileHandler(log_file, file_mode))

    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    # Non-zero ranks only emit errors, so N processes don't print N copies.
    logger.setLevel(log_level if is_worker0 else logging.ERROR)
    init_loggers[logger_name] = True
    return logger
def add_file_handler_if_needed(logger, log_file, file_mode, log_level):
    """Attach a file handler to ``logger`` unless it already has one.

    Nothing is added when ``log_file`` is None, when the logger already has
    a ``logging.FileHandler``, or when this process is not rank 0 according
    to ``is_master()``.

    Args:
        logger: Logger to attach the handler to.
        log_file: Path of the log file, or None to skip.
        file_mode: Mode used to open ``log_file``.
        log_level: Level set on the new file handler.
    """
    # Cheapest check first: no file requested means nothing to do, so skip
    # the distributed rank query entirely.
    if log_file is None:
        return
    if any(isinstance(h, logging.FileHandler) for h in logger.handlers):
        return
    # torch is imported unconditionally at module level, so no
    # importlib.util.find_spec('torch') guard is needed before is_master().
    if not is_master():
        return

    file_handler = logging.FileHandler(log_file, file_mode)
    file_handler.setFormatter(formatter)
    file_handler.setLevel(log_level)
    logger.addHandler(file_handler)
def is_master(group=None):
    """Return True if this process is rank 0, or if not running distributed.

    Args:
        group: Optional process group forwarded to ``dist.get_rank``.
    """
    if not is_dist():
        # A plain single-process run always counts as the master.
        return True
    rank = dist.get_rank(group)
    return rank == 0
def is_dist():
    """Return True when torch.distributed is both available and initialized."""
    if not dist.is_available():
        return False
    return dist.is_initialized()