import logging
import os
import pickle
import random
import zipfile
from typing import Any
import numpy as np
import psutil
import torch
logger = logging.getLogger(__name__)
def set_seed(seed: int = 1234) -> None:
    """Sets the random seed for all relevant libraries.

    Args:
        seed: seed value
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # benchmark mode lets cudnn pick the fastest kernels at the cost of
    # bitwise reproducibility of GPU ops; the seeds above still fix data
    # order and weight initialization.
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
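
# Illustrative sketch (an assumption, not called anywhere in this module):
# re-seeding with the same value makes random draws repeat exactly, while
# cudnn.benchmark=True above still permits non-determinism in some GPU ops.
def _demo_set_seed() -> None:
    set_seed(1234)
    first = torch.rand(3)
    set_seed(1234)
    assert torch.equal(first, torch.rand(3))
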
def set_environment(cfg):
    """Sets and checks environment settings.

    Args:
        cfg: config object

    Returns:
        The config, with the metric falling back to BLEU if no API key is set.
    """
    if "GPT" in cfg.prediction.metric and os.getenv("OPENAI_API_KEY", "") == "":
        logger.warning("No OpenAI API key set. Setting metric to BLEU.")
        cfg.prediction.metric = "BLEU"
    return cfg
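
# Minimal sketch of the BLEU fallback, using types.SimpleNamespace to stand
# in for the real cfg object (an assumption for illustration only).
def _demo_set_environment() -> None:
    from types import SimpleNamespace

    os.environ.pop("OPENAI_API_KEY", None)
    cfg = SimpleNamespace(prediction=SimpleNamespace(metric="GPT"))
    cfg = set_environment(cfg)
    assert cfg.prediction.metric == "BLEU"
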
def kill_child_processes(parent_pid: int) -> bool:
    """Kills a process and all of its child processes.

    Args:
        parent_pid: process id of the parent

    Returns:
        True on success, False on failure
    """
    logger.debug(f"Killing process id: {parent_pid}")
    try:
        parent = psutil.Process(parent_pid)
        if parent.status() == psutil.STATUS_ZOMBIE:
            return False
        children = parent.children(recursive=True)
        for child in children:
            child.kill()
        parent.kill()
        return True
    except psutil.NoSuchProcess:
        logger.warning(f"Cannot kill process id: {parent_pid}. No such process.")
        return False
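
# Hedged usage sketch: spawn a throwaway child process and reap it via
# kill_child_processes. The sleeping subprocess is an assumption for
# illustration; any long-running child would do.
def _demo_kill_child_processes() -> None:
    import subprocess
    import sys

    child = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"])
    assert kill_child_processes(child.pid)
    child.wait()  # reap the killed child so it does not linger as a zombie
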
def kill_ddp_processes() -> None:
    """
    Kills all DDP processes from a single process.

    First kills all children of the calling DDP process (dataloader workers),
    then kills the sibling DDP processes,
    then kills the main parent DDP process and finally itself.
    """
    pid = os.getpid()
    parent_pid = os.getppid()
    current_process = psutil.Process(pid)
    children = current_process.children(recursive=True)
    for child in children:
        child.kill()
    parent_process = psutil.Process(parent_pid)
    # iterate in reverse so that grandchildren are killed before their parents
    children = parent_process.children(recursive=True)[::-1]
    for child in children:
        if child.pid == pid:
            continue
        child.kill()
    parent_process.kill()
    current_process.kill()
def add_file_to_zip(zf: zipfile.ZipFile, path: str) -> None:
    """Adds a file to an existing zip under its basename.

    Logs a warning if the file cannot be added (e.g. it does not exist).

    Args:
        zf: zipfile object to add to
        path: path of the file to add
    """
    try:
        zf.write(path, os.path.basename(path))
    except Exception:
        logger.warning(f"File {path} could not be added to zip.")
def save_pickle(path: str, obj: Any, protocol: int = 4) -> None:
"""Saves object as pickle file
Args:
path: path of file to save
obj: object to save
protocol: protocol to use when saving pickle
"""
with open(path, "wb") as pickle_file:
pickle.dump(obj, pickle_file, protocol=protocol)
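
# Illustrative counterpart (an assumption, not part of this module): load an
# object saved by save_pickle and verify the round trip.
def _demo_save_pickle() -> None:
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "obj.pkl")
        save_pickle(path, {"a": 1})
        with open(path, "rb") as pickle_file:
            assert pickle.load(pickle_file) == {"a": 1}
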
class DisableLogger:
    """Context manager that disables all log messages up to `level`."""

    def __init__(self, level: int = logging.INFO):
        self.level = level
def __enter__(self):
logging.disable(self.level)
def __exit__(self, exit_type, exit_value, exit_traceback):
logging.disable(logging.NOTSET)
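
# Minimal usage sketch: with the default level, INFO messages are suppressed
# inside the context and emitted again afterwards (assuming a handler has
# been configured, e.g. via logging.basicConfig).
def _demo_disable_logger() -> None:
    with DisableLogger():
        logger.info("hidden while the context is active")
    logger.info("visible again")
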
class PatchedAttribute:
"""
Patches an attribute of an object for the duration of this context manager.
    Similar to unittest.mock.patch, but also works for attributes
    that are not present in the original class.
>>> class MyObj:
... attr = 'original'
>>> my_obj = MyObj()
>>> with PatchedAttribute(my_obj, 'attr', 'patched'):
... print(my_obj.attr)
patched
>>> print(my_obj.attr)
original
>>> with PatchedAttribute(my_obj, 'new_attr', 'new_patched'):
... print(my_obj.new_attr)
new_patched
>>> assert not hasattr(my_obj, 'new_attr')
"""
def __init__(self, obj, attribute, new_value):
self.obj = obj
self.attribute = attribute
self.new_value = new_value
self.original_exists = hasattr(obj, attribute)
if self.original_exists:
self.original_value = getattr(obj, attribute)
def __enter__(self):
setattr(self.obj, self.attribute, self.new_value)
def __exit__(self, exc_type, exc_val, exc_tb):
if self.original_exists:
setattr(self.obj, self.attribute, self.original_value)
else:
delattr(self.obj, self.attribute)
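
# Hedged usage sketch: temporarily stub out a method during a test or dry
# run. The Model class and generate method are assumptions for illustration
# only, not part of this module.
def _demo_patched_attribute() -> None:
    class Model:
        def generate(self) -> str:
            return "real output"

    model = Model()
    with PatchedAttribute(model, "generate", lambda: "stub output"):
        assert model.generate() == "stub output"
    assert model.generate() == "real output"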