import os

# Set this before importing any other modules to be on the safe side
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import argparse
import logging
import sys
import time

import psutil


def check_for_done(process_queue):
    """Checks for finished process ids.

    Args:
        process_queue: list of process ids

    Returns:
        (True, process_idx) if any process has finished
        (False, False) if no process has finished
    """

    for i, pid in enumerate(process_queue):
        zombie = False
        try:
            p = psutil.Process(pid)
            zombie = p.status() == "zombie"
        except psutil.NoSuchProcess:
            pass
        if not psutil.pid_exists(pid) or zombie:
            return True, i

    return False, False


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "-C", "--config", help="config filename", default=argparse.SUPPRESS
    )
    parser.add_argument("-Y", "--yaml", help="yaml filename", default=argparse.SUPPRESS)
    parser.add_argument(
        "-Q",
        "--process-queue",
        help="process queue to wait for",
        default=argparse.SUPPRESS,
    )
    parser_args, _ = parser.parse_known_args(sys.argv)

    process_queue = []
    if "process_queue" in parser_args and parser_args.process_queue != "":
        process_queue = [int(x) for x in parser_args.process_queue.split(",")]

    # Block until every process in the queue has finished, polling every 30s
    while True:
        if len(process_queue) == 0:
            break
        done, num = check_for_done(process_queue)
        if done:
            process_queue.pop(num)
        else:
            time.sleep(30)

    # Delayed imports from llm_studio, only once we actually want to start training
    import subprocess

    import torch

    from llm_studio.src.utils.config_utils import load_config_py, load_config_yaml
    from llm_studio.src.utils.exceptions import (
        LLMAugmentationsException,
        LLMDataException,
        LLMMetricException,
        LLMModelException,
        LLMTrainingException,
    )
    from llm_studio.src.utils.gpu_utils import is_oom_error
    from llm_studio.src.utils.logging_utils import initialize_logging, write_flag
    from llm_studio.src.utils.utils import kill_ddp_processes
    from train import run

    if "config" in parser_args:
        cfg = load_config_py(parser_args.config)
    elif "yaml" in parser_args:
        cfg = load_config_yaml(parser_args.yaml)

    # Status flags are written to flags{local_rank}.json in the output directory
    flag_path = os.path.join(cfg.output_directory, "flags{}.json")

    # Check if we are running in DDP mode
    if "WORLD_SIZE" in os.environ:
        local_rank = int(os.environ["LOCAL_RANK"])
        if local_rank == 0:
            write_flag(flag_path.format(""), "status", "running")
    else:
        write_flag(flag_path.format(""), "status", "running")
        local_rank = 0

    initialize_logging(cfg)

    try:
        run(cfg=cfg)
    except Exception as exception:
        write_flag(flag_path.format(local_rank), "status", "failed")
        if is_oom_error(exception):
            logging.error(
                "GPU Out-of-Memory (OOM) error occurred. "
                "Please reduce the batch size, input data size, or model size, "
                "or try gradient checkpointing.",
                exc_info=True,
            )
            write_flag(flag_path.format(local_rank), "info", "OOM error")
            logging.info(
                "\n"
                + subprocess.check_output(["nvidia-smi"]).decode("utf-8")
                + "\n"
            )
            if torch.cuda.is_available():
                logging.info(
                    "\n" + torch.cuda.memory_summary().replace("-", "=") + "\n"
                )
        elif isinstance(exception, LLMDataException):
            logging.error(
                "Data error occurred during H2O LLM Studio run:", exc_info=True
            )
            write_flag(flag_path.format(local_rank), "info", "Data error")
        elif isinstance(exception, LLMTrainingException):
            logging.error(
                "Training error occurred during H2O LLM Studio run:", exc_info=True
            )
            write_flag(flag_path.format(local_rank), "info", "Training error")
        elif isinstance(exception, LLMMetricException):
            logging.error(
                "Validation metric failed. Please make sure the selected validation "
                "metric is suitable for your current problem setup.",
                exc_info=True,
            )
            write_flag(flag_path.format(local_rank), "info", "Metric error")
        elif isinstance(exception, LLMAugmentationsException):
            logging.error(
                "Custom augmentations error occurred during "
                "H2O LLM Studio run:",
                exc_info=True,
            )
            write_flag(flag_path.format(local_rank), "info", "Augmentations error")
        elif isinstance(exception, LLMModelException):
            logging.error(
                "Model error occurred during H2O LLM Studio run:",
                exc_info=True,
            )
            write_flag(flag_path.format(local_rank), "info", "Model error")
        else:
            logging.error(
                "Exception occurred during H2O LLM Studio run:", exc_info=True
            )
            write_flag(flag_path.format(local_rank), "info", "See logs")
        kill_ddp_processes()