Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*- | |
# Utility functions | |
# | |
# @ Fabian Hörst, fabian.hoerst@uk-essen.de | |
# Institute for Artificial Intelligence in Medicine, | |
# University Medicine Essen | |
import importlib | |
import logging | |
import sys | |
import types | |
from datetime import timedelta | |
from timeit import default_timer as timer | |
from typing import Dict, List, Optional, Tuple, Union | |
from utils.__init__ import logger | |
# Helper timing functions | |
def start_timer() -> float:
    """Take a reading of the high-resolution timer used for measuring durations.

    The value comes from ``timeit.default_timer`` (``time.perf_counter``). Its
    reference point is undefined, so the reading is only meaningful when
    subtracted from another reading of the same clock (as ``end_timer`` does).

    Returns:
        float: Current timer reading, in seconds.
    """
    reading: float = timer()
    return reading
def end_timer(start_time: float, timed_event: str = "Time usage") -> None:
    """Log, at INFO level, how much time has elapsed since ``start_time``.

    The message has the form ``"<timed_event>: <H:MM:SS.ffffff>"``.

    Args:
        start_time (float): Earlier reading of the same timer, e.g. the value
            returned by ``start_timer``.
        timed_event (str, optional): Label placed in front of the duration.
            Defaults to "Time usage".
    """
    elapsed = timedelta(seconds=timer() - start_time)
    logger.info(f"{timed_event}: {elapsed}")
def module_exists(
    *names: Union[List[str], str],
    error: str = "ignore",
    warn_every_time: bool = False,
    __INSTALLED_OPTIONAL_MODULES: Dict[str, bool] = {},
) -> Optional[Union[Tuple[types.ModuleType, ...], types.ModuleType]]:
    """Try to import optional dependencies.

    Ref: https://stackoverflow.com/a/73838546/4900327

    Module names may be given as separate positional strings
    (``module_exists("a", "b")``), as one list/tuple/set of strings
    (``module_exists(["a", "b"])``), or as a mixture of both.

    Args:
        names (Union[List[str], str]): The module name(s) to import.
        error (str, optional): What to do when a dependency is not found:
            * raise: Raise an ImportError.
            * warn: Log a warning (once per module unless ``warn_every_time``).
            * ignore: If any module is not installed, return None, otherwise,
              return the module(s).
            Defaults to "ignore".
        warn_every_time (bool, optional): Whether to warn every time an import is
            tried. Only applies when error="warn". Setting this to True will result
            in multiple warnings if you try to import the same library multiple
            times. Defaults to False.
        __INSTALLED_OPTIONAL_MODULES (Dict[str, bool]): Internal cache, do not
            pass. The mutable default is intentional: it persists across calls
            and records which modules were found (True) or reported missing
            (False), so "warn" only fires once per module.

    Raises:
        ValueError: If ``error`` is not one of "raise", "warn", "ignore".
        TypeError: If a module name is not a string.
        ImportError: If any module is missing and ``error="raise"``.

    Returns:
        Optional[ModuleType, Tuple[ModuleType...]]: The imported module if exactly
        one name was given, otherwise a tuple of modules in the order given.
        None is returned if any module is not found and ``error!="raise"``.
    """
    if error not in {"raise", "warn", "ignore"}:
        raise ValueError(
            f'error must be one of "raise", "warn" or "ignore", got {error!r}'
        )

    # ``*names`` is always a tuple, so a caller passing a list arrives here as a
    # tuple containing that list. Flatten one level so both call styles work.
    module_names: List[str] = []
    for entry in names:
        if isinstance(entry, (list, tuple, set)):
            module_names.extend(entry)
        else:
            module_names.append(entry)
    for name in module_names:
        if not isinstance(name, str):
            raise TypeError(
                f"Module names must be strings, got {type(name).__name__}"
            )

    modules: List[Optional[types.ModuleType]] = []
    for name in module_names:
        try:
            module = importlib.import_module(name)
        except ImportError:
            modules.append(None)
        else:
            modules.append(module)
            __INSTALLED_OPTIONAL_MODULES[name] = True

    def error_msg(missing: Union[str, List[str]]) -> str:
        """Build the human-readable "missing dependency" message."""
        if not isinstance(missing, (list, tuple)):
            missing = [missing]
        missing_str: str = " ".join([f'"{name}"' for name in missing])
        dep_str = "dependency" if len(missing) == 1 else "dependencies"
        return f"Missing optional {dep_str} {missing_str}. Use pip or conda to install."

    missing_modules: List[str] = [
        name for name, module in zip(module_names, modules) if module is None
    ]
    if missing_modules:
        if error == "raise":
            raise ImportError(error_msg(missing_modules))
        if error == "warn":
            for name in missing_modules:
                # Warn only the first time a module is seen missing, unless
                # the caller asked for a warning on every attempt.
                if warn_every_time or name not in __INSTALLED_OPTIONAL_MODULES:
                    logger.warning(f"Warning: {error_msg(name)}")
                __INSTALLED_OPTIONAL_MODULES[name] = False
        return None
    if len(modules) == 1:
        return modules[0]
    return tuple(modules)
def close_logger(logger: logging.Logger) -> None:
    """Close a logger safely by detaching and closing all of its handlers.

    After the handlers are released, ``logging.shutdown()`` is called.

    Args:
        logger (logging.Logger): Logger whose handlers should be released.
    """
    # Walk a snapshot: removeHandler mutates ``logger.handlers`` while we iterate.
    for attached in tuple(logger.handlers):
        logger.removeHandler(attached)
        attached.close()
    logger.handlers.clear()
    # NOTE(review): logging.shutdown() flushes and closes *every* handler known
    # to the logging module, not only this logger's. Confirm this is intended.
    logging.shutdown()
class AverageMeter:
    """Computes and stores the average and current value.

    Original-Code: https://github.com/facebookresearch/simsiam

    Args:
        name (str): Label shown by ``__str__``.
        fmt (str, optional): Format spec, including the leading colon, applied to
            both the current and the average value in ``__str__``. Defaults to ":f".
    """

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the current value, running sum, sample count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted as ``n`` samples.

        Args:
            val: Latest value; it is added to the running sum ``n`` times.
            n (int, optional): Weight (number of samples) of ``val``. Defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # An empty update (n=0) before any samples would otherwise divide by
        # zero. In that case, keep the previous average.
        if self.count:
            self.avg = self.sum / self.count

    def __str__(self):
        """Render as ``"<name> <val> (<avg>)"`` using ``fmt`` for both numbers."""
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)
def flatten_dict(d: dict, parent_key: str = "", sep: str = ".") -> dict:
    """Flatten a nested dictionary, joining nested keys with ``sep``.

    Example: ``{"a": {"b": 1}, "c": 2}`` -> ``{"a.b": 1, "c": 2}``.
    Keys that are not strings (e.g. ints) are converted with ``str`` when they
    are joined to a prefix. A top-level key whose value is not a dict is kept as-is.

    Args:
        d (dict): dict to flatten
        parent_key (str, optional): Prefix prepended (followed by ``sep``) to every
            key. Defaults to ''.
        sep (str, optional): Separator. Defaults to '.'.

    Returns:
        dict: Flattened dict. Empty nested dicts contribute no entries. If two
        paths flatten to the same key, the later one wins.
    """
    flat: dict = {}
    for key, value in d.items():
        # An f-string (rather than ``+``) lets non-string keys such as ints be joined.
        new_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            # Pass the prefix down as a string so a falsy key such as 0 is not
            # mistaken for "no parent" in the recursive call.
            flat.update(flatten_dict(value, str(new_key), sep=sep))
        else:
            flat[new_key] = value
    return flat
def unflatten_dict(d: dict, sep: str = ".") -> dict:
    """Unflatten a flattened dictionary (create a nested dictionary).

    Each key is split on ``sep``, and intermediate dictionaries are created as
    needed, e.g. ``{"a.b": 1}`` -> ``{"a": {"b": 1}}``. Keys that are not strings
    cannot be split and are copied to the top level unchanged.

    Args:
        d (dict): Flattened dict to nest. It is not modified.
        sep (str, optional): Separator of flattened keys. Defaults to '.'.

    Returns:
        dict: Nested dict
    """
    output_dict: dict = {}
    for key, value in d.items():
        if not isinstance(key, str):
            output_dict[key] = value
            continue
        *parents, leaf = key.split(sep)
        # Walk/create the intermediate dicts without rebinding the input ``d``.
        node = output_dict
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return output_dict
def remove_parameter_tag(d: dict, sep: str = ".") -> dict:
    """Remove all parameter tags from a flattened dictionary and regroup its leaves.

    For every flattened key, the last segment is the leaf name. All preceding
    segments, except any literal ``"parameters"`` segment, joined with ``sep``,
    form the group key. Example:
    ``{"model.parameters.lr": 0.1}`` -> ``{"model": {"lr": 0.1}}``.
    Single-segment keys are grouped under the empty string ``""``.

    Args:
        d (dict): Dict must be flattened with the given separator.
        sep (str, optional): Separator used during flattening. Defaults to ".".

    Returns:
        dict: Mapping of group key -> {leaf name: value}, in first-seen order.
        If two entries map to the same group and leaf, the later one wins.
    """
    param_dict: dict = {}
    for flat_key, value in d.items():
        *prefix, leaf = flat_key.split(sep)
        group_key = sep.join(part for part in prefix if part != "parameters")
        param_dict.setdefault(group_key, {})[leaf] = value
    return param_dict
def get_size_of_dict(d: dict) -> int:
    """Return the shallow memory footprint of ``d`` in bytes.

    The result is ``sys.getsizeof`` of the dict itself plus ``sys.getsizeof`` of
    each key and each value. Objects referenced by those keys/values (e.g. the
    elements of a nested list or dict) are not followed, so nested containers
    are undercounted.

    Args:
        d (dict): Dictionary to measure.

    Returns:
        int: Approximate size in bytes.
    """
    entry_bytes = sum(
        sys.getsizeof(key) + sys.getsizeof(value) for key, value in d.items()
    )
    return sys.getsizeof(d) + entry_bytes