|
import logging |
|
import os |
|
import pickle |
|
import random |
|
import zipfile |
|
from typing import Any |
|
|
|
import numpy as np |
|
import psutil |
|
import torch |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def set_seed(seed: int = 1234) -> None:
    """Seed the random number generators of Python, NumPy and PyTorch.

    Besides seeding ``random``, ``numpy.random`` and ``torch`` (CPU and all
    CUDA devices), the seed is written to the ``PYTHONHASHSEED`` environment
    variable. cuDNN deterministic mode is switched off and cuDNN
    benchmarking is switched on.

    Args:
        seed: seed value
    """

    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)

    # CPU generator first, then every CUDA device.
    for seed_torch in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_torch(seed)

    # NOTE(review): with these flags GPU results are presumably not
    # bit-for-bit reproducible across runs - confirm this trade-off is wanted.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True
|
|
|
|
|
def set_environment(cfg):
    """Check environment-dependent settings and adjust ``cfg`` accordingly.

    When the configured prediction metric contains ``"GPT"`` but the
    ``OPENAI_API_KEY`` environment variable is unset or empty, a warning is
    logged and the metric is replaced by ``"BLEU"``.

    Args:
        cfg: configuration object; ``cfg.prediction.metric`` is read and may
            be overwritten in place

    Returns:
        The same ``cfg`` object that was passed in
    """
    # Metrics without "GPT" in their name need no API key.
    if "GPT" not in cfg.prediction.metric:
        return cfg

    # A non-empty key means the GPT metric can stay.
    if os.getenv("OPENAI_API_KEY", ""):
        return cfg

    logger.warning("No OpenAI API Key set. Setting metric to BLEU. ")
    cfg.prediction.metric = "BLEU"
    return cfg
|
|
|
|
|
def kill_child_processes(parent_pid: int) -> bool:
    """Kill a process and all of its child processes.

    All descendants of ``parent_pid`` are killed first, then the parent
    itself. A descendant that has already exited by the time it is signalled
    is skipped, so that the remaining descendants and the parent are still
    killed.

    Args:
        parent_pid: process id of parent

    Returns:
        True if the parent was killed; False if the parent is a zombie or
        does not exist (any more)
    """

    logger.debug(f"Killing process id: {parent_pid}")

    try:
        parent = psutil.Process(parent_pid)
        if parent.status() == psutil.STATUS_ZOMBIE:
            return False

        for child in parent.children(recursive=True):
            try:
                child.kill()
            except psutil.NoSuchProcess:
                # The child exited between enumeration and kill; that must
                # not stop us from killing its siblings and the parent.
                continue

        parent.kill()
        return True
    except psutil.NoSuchProcess:
        logger.warning(f"Cannot kill process id: {parent_pid}. No such process.")
        return False
|
|
|
|
|
def kill_ddp_processes() -> None:
    """
    Killing all DDP processes from a single process.
    Firstly kills all children of a single DDP process (dataloader workers)
    Then kills all other DDP processes
    Then kills main parent DDP process
    Finally kills the calling process itself.

    A process that has already exited when it is signalled is skipped, so a
    single vanished process does not abort the rest of the teardown.
    """

    def _kill_if_running(process) -> None:
        # Kill ``process``; ignore it if it is already gone.
        try:
            process.kill()
        except psutil.NoSuchProcess:
            pass

    pid = os.getpid()
    parent_pid = os.getppid()

    # 1) Descendants of the current process.
    current_process = psutil.Process(pid)
    for child in current_process.children(recursive=True):
        _kill_if_running(child)

    # 2) Every other descendant of the parent, visited in the reverse of the
    #    order psutil reports them; the current process is kept for last.
    parent_process = psutil.Process(parent_pid)
    for child in reversed(parent_process.children(recursive=True)):
        if child.pid == pid:
            continue
        _kill_if_running(child)

    # 3) The parent, then 4) the current process.
    _kill_if_running(parent_process)
    current_process.kill()
|
|
|
|
|
def add_file_to_zip(zf: zipfile.ZipFile, path: str) -> None:
    """Add a file to an open zip archive on a best-effort basis.

    The file is stored under its base name (without directories). If it
    cannot be added, e.g. because it does not exist, a warning is logged and
    no exception is propagated.

    Args:
        zf: zipfile object to add to
        path: path to the file to add
    """

    try:
        archive_name = os.path.basename(path)
        zf.write(path, arcname=archive_name)
    except Exception:
        # Deliberately broad: a missing or unreadable file must not abort
        # the creation of the archive.
        logger.warning(f"File {path} could not be added to zip.")
|
|
|
|
|
def save_pickle(path: str, obj: Any, protocol: int = 4) -> None:
    """Serialize an object to a pickle file.

    The file at ``path`` is opened for binary writing and closed again once
    the object has been written.

    Args:
        path: path of file to save
        obj: object to save
        protocol: protocol to use when saving pickle
    """

    with open(path, mode="wb") as sink:
        pickle.Pickler(sink, protocol=protocol).dump(obj)
|
|
|
|
|
class DisableLogger:
    """Context manager that temporarily disables logging up to a level.

    On entry ``logging.disable(level)`` is called, which suppresses all log
    records of severity ``level`` and below for the whole process. On exit
    the threshold that was in effect on entry is restored, so an earlier
    ``logging.disable(...)`` call or an enclosing ``DisableLogger`` block
    stays in force.

    Args:
        level: highest severity that is suppressed inside the block
    """

    def __init__(self, level: int = logging.INFO):
        self.level = level
        # Thresholds to restore, one entry per active ``with`` block, so the
        # same instance can be entered more than once.
        self._previous_levels = []

    def __enter__(self):
        # ``logging.disable`` stores its threshold on the root manager; read
        # it back so that it can be restored on exit.
        self._previous_levels.append(logging.root.manager.disable)
        logging.disable(self.level)

    def __exit__(self, exit_type, exit_value, exit_traceback):
        if self._previous_levels:
            previous_level = self._previous_levels.pop()
        else:
            previous_level = logging.NOTSET
        logging.disable(previous_level)
|
|
|
|
|
class PatchedAttribute:
    """
    Patches an attribute of an object for the duration of this context manager.
    Similar to unittest.mock.patch,
    but works also for properties that are not present in the original class

    The state to restore (whether the attribute exists, and its value) is
    captured when the ``with`` block is entered, so changes made to the
    object between construction and entry are preserved on exit and the
    same instance can be used for several ``with`` blocks.

    >>> class MyObj:
    ...     attr = 'original'
    >>> my_obj = MyObj()
    >>> with PatchedAttribute(my_obj, 'attr', 'patched'):
    ...     print(my_obj.attr)
    patched
    >>> print(my_obj.attr)
    original
    >>> with PatchedAttribute(my_obj, 'new_attr', 'new_patched'):
    ...     print(my_obj.new_attr)
    new_patched
    >>> assert not hasattr(my_obj, 'new_attr')

    Args:
        obj: object whose attribute is patched
        attribute: name of the attribute
        new_value: value the attribute has inside the ``with`` block
    """

    def __init__(self, obj, attribute, new_value):
        self.obj = obj
        self.attribute = attribute
        self.new_value = new_value
        # Initial snapshot, kept so that ``original_exists`` and
        # ``original_value`` are available right after construction.
        self._capture_original()

    def _capture_original(self):
        # Record whether the attribute currently exists and, if so, its value.
        self.original_exists = hasattr(self.obj, self.attribute)
        if self.original_exists:
            self.original_value = getattr(self.obj, self.attribute)

    def __enter__(self):
        # Refresh the snapshot: the object may have changed since __init__.
        self._capture_original()
        setattr(self.obj, self.attribute, self.new_value)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.original_exists:
            setattr(self.obj, self.attribute, self.original_value)
        else:
            # The attribute did not exist on entry, so remove it again.
            delattr(self.obj, self.attribute)
|
|